mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-29 11:24:19 +00:00
feat: small ollama package
This commit is contained in:
commit
2d5d05a2b4
103 changed files with 7262 additions and 7422 deletions
10
.github/PULL_REQUEST_TEMPLATE.md
vendored
10
.github/PULL_REQUEST_TEMPLATE.md
vendored
|
@ -1,10 +1,8 @@
|
||||||
# What does this PR do?
|
# What does this PR do?
|
||||||
[Provide a short summary of what this PR does and why. Link to relevant issues if applicable.]
|
<!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. -->
|
||||||
|
|
||||||
[//]: # (If resolving an issue, uncomment and update the line below)
|
<!-- If resolving an issue, uncomment and update the line below -->
|
||||||
[//]: # (Closes #[issue-number])
|
<!-- Closes #[issue-number] -->
|
||||||
|
|
||||||
## Test Plan
|
## Test Plan
|
||||||
[Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*]
|
<!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* -->
|
||||||
|
|
||||||
[//]: # (## Documentation)
|
|
||||||
|
|
2
.github/actions/setup-runner/action.yml
vendored
2
.github/actions/setup-runner/action.yml
vendored
|
@ -13,7 +13,7 @@ runs:
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
uv sync --all-extras
|
uv sync --all-groups
|
||||||
uv pip install ollama faiss-cpu
|
uv pip install ollama faiss-cpu
|
||||||
# always test against the latest version of the client
|
# always test against the latest version of the client
|
||||||
# TODO: this is not necessarily a good idea. we need to test against both published and latest
|
# TODO: this is not necessarily a good idea. we need to test against both published and latest
|
||||||
|
|
|
@ -53,7 +53,7 @@ repos:
|
||||||
- black==24.3.0
|
- black==24.3.0
|
||||||
|
|
||||||
- repo: https://github.com/astral-sh/uv-pre-commit
|
- repo: https://github.com/astral-sh/uv-pre-commit
|
||||||
rev: 0.6.3
|
rev: 0.7.8
|
||||||
hooks:
|
hooks:
|
||||||
- id: uv-lock
|
- id: uv-lock
|
||||||
- id: uv-export
|
- id: uv-export
|
||||||
|
@ -61,6 +61,7 @@ repos:
|
||||||
"--frozen",
|
"--frozen",
|
||||||
"--no-hashes",
|
"--no-hashes",
|
||||||
"--no-emit-project",
|
"--no-emit-project",
|
||||||
|
"--no-default-groups",
|
||||||
"--output-file=requirements.txt"
|
"--output-file=requirements.txt"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -88,8 +89,8 @@ repos:
|
||||||
- id: distro-codegen
|
- id: distro-codegen
|
||||||
name: Distribution Template Codegen
|
name: Distribution Template Codegen
|
||||||
additional_dependencies:
|
additional_dependencies:
|
||||||
- uv==0.6.0
|
- uv==0.7.8
|
||||||
entry: uv run --extra codegen ./scripts/distro_codegen.py
|
entry: uv run --group codegen ./scripts/distro_codegen.py
|
||||||
language: python
|
language: python
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
require_serial: true
|
require_serial: true
|
||||||
|
@ -97,8 +98,8 @@ repos:
|
||||||
- id: openapi-codegen
|
- id: openapi-codegen
|
||||||
name: API Spec Codegen
|
name: API Spec Codegen
|
||||||
additional_dependencies:
|
additional_dependencies:
|
||||||
- uv==0.6.2
|
- uv==0.7.8
|
||||||
entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
|
entry: sh -c 'uv run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
|
||||||
language: python
|
language: python
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
require_serial: true
|
require_serial: true
|
||||||
|
|
|
@ -5,28 +5,21 @@
|
||||||
# Required
|
# Required
|
||||||
version: 2
|
version: 2
|
||||||
|
|
||||||
|
# Build documentation in the "docs/" directory with Sphinx
|
||||||
|
sphinx:
|
||||||
|
configuration: docs/source/conf.py
|
||||||
|
|
||||||
# Set the OS, Python version and other tools you might need
|
# Set the OS, Python version and other tools you might need
|
||||||
build:
|
build:
|
||||||
os: ubuntu-22.04
|
os: ubuntu-22.04
|
||||||
tools:
|
tools:
|
||||||
python: "3.12"
|
python: "3.12"
|
||||||
# You can also specify other tool versions:
|
jobs:
|
||||||
# nodejs: "19"
|
pre_create_environment:
|
||||||
# rust: "1.64"
|
- asdf plugin add uv
|
||||||
# golang: "1.19"
|
- asdf install uv latest
|
||||||
|
- asdf global uv latest
|
||||||
# Build documentation in the "docs/" directory with Sphinx
|
create_environment:
|
||||||
sphinx:
|
- uv venv "${READTHEDOCS_VIRTUALENV_PATH}"
|
||||||
configuration: docs/source/conf.py
|
|
||||||
|
|
||||||
# Optionally build your docs in additional formats such as PDF and ePub
|
|
||||||
# formats:
|
|
||||||
# - pdf
|
|
||||||
# - epub
|
|
||||||
|
|
||||||
# Optional but recommended, declare the Python requirements required
|
|
||||||
# to build your documentation
|
|
||||||
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
|
|
||||||
python:
|
|
||||||
install:
|
install:
|
||||||
- requirements: docs/requirements.txt
|
- UV_PROJECT_ENVIRONMENT="${READTHEDOCS_VIRTUALENV_PATH}" uv sync --frozen --group docs
|
||||||
|
|
|
@ -168,10 +168,10 @@ If you are making changes to the documentation at [https://llama-stack.readthedo
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# This rebuilds the documentation pages.
|
# This rebuilds the documentation pages.
|
||||||
uv run --with ".[docs]" make -C docs/ html
|
uv run --group docs make -C docs/ html
|
||||||
|
|
||||||
# This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
|
# This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
|
||||||
uv run --with ".[docs]" sphinx-autobuild docs/source docs/build/html --write-all
|
uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all
|
||||||
```
|
```
|
||||||
|
|
||||||
### Update API Documentation
|
### Update API Documentation
|
||||||
|
@ -179,7 +179,7 @@ uv run --with ".[docs]" sphinx-autobuild docs/source docs/build/html --write-all
|
||||||
If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following command:
|
If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following command:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh
|
uv run ./docs/openapi_generator/run_openapi_generator.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.
|
The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
include pyproject.toml
|
include pyproject.toml
|
||||||
include llama_stack/templates/dependencies.json
|
|
||||||
include llama_stack/models/llama/llama3/tokenizer.model
|
include llama_stack/models/llama/llama3/tokenizer.model
|
||||||
include llama_stack/models/llama/llama4/tokenizer.model
|
include llama_stack/models/llama/llama4/tokenizer.model
|
||||||
include llama_stack/distribution/*.sh
|
include llama_stack/distribution/*.sh
|
||||||
|
|
43
README.md
43
README.md
|
@ -107,26 +107,29 @@ By reducing friction and complexity, Llama Stack empowers developers to focus on
|
||||||
### API Providers
|
### API Providers
|
||||||
Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack.
|
Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack.
|
||||||
|
|
||||||
| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
|
| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | **Post Training** |
|
||||||
|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|
|
|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-----------------:|
|
||||||
| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ | |
|
||||||
| SambaNova | Hosted | | ✅ | | ✅ | |
|
| SambaNova | Hosted | | ✅ | | ✅ | | |
|
||||||
| Cerebras | Hosted | | ✅ | | | |
|
| Cerebras | Hosted | | ✅ | | | | |
|
||||||
| Fireworks | Hosted | ✅ | ✅ | ✅ | | |
|
| Fireworks | Hosted | ✅ | ✅ | ✅ | | | |
|
||||||
| AWS Bedrock | Hosted | | ✅ | | ✅ | |
|
| AWS Bedrock | Hosted | | ✅ | | ✅ | | |
|
||||||
| Together | Hosted | ✅ | ✅ | | ✅ | |
|
| Together | Hosted | ✅ | ✅ | | ✅ | | |
|
||||||
| Groq | Hosted | | ✅ | | | |
|
| Groq | Hosted | | ✅ | | | | |
|
||||||
| Ollama | Single Node | | ✅ | | | |
|
| Ollama | Single Node | | ✅ | | | | |
|
||||||
| TGI | Hosted and Single Node | | ✅ | | | |
|
| TGI | Hosted and Single Node | | ✅ | | | | |
|
||||||
| NVIDIA NIM | Hosted and Single Node | | ✅ | | | |
|
| NVIDIA NIM | Hosted and Single Node | | ✅ | | | | |
|
||||||
| Chroma | Single Node | | | ✅ | | |
|
| Chroma | Single Node | | | ✅ | | | |
|
||||||
| PG Vector | Single Node | | | ✅ | | |
|
| PG Vector | Single Node | | | ✅ | | | |
|
||||||
| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
|
| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | | |
|
||||||
| vLLM | Hosted and Single Node | | ✅ | | | |
|
| vLLM | Hosted and Single Node | | ✅ | | | | |
|
||||||
| OpenAI | Hosted | | ✅ | | | |
|
| OpenAI | Hosted | | ✅ | | | | |
|
||||||
| Anthropic | Hosted | | ✅ | | | |
|
| Anthropic | Hosted | | ✅ | | | | |
|
||||||
| Gemini | Hosted | | ✅ | | | |
|
| Gemini | Hosted | | ✅ | | | | |
|
||||||
| watsonx | Hosted | | ✅ | | | |
|
| watsonx | Hosted | | ✅ | | | | |
|
||||||
|
| HuggingFace | Single Node | | | | | | ✅ |
|
||||||
|
| TorchTune | Single Node | | | | | | ✅ |
|
||||||
|
| NVIDIA NEMO | Hosted | | | | | | ✅ |
|
||||||
|
|
||||||
|
|
||||||
### Distributions
|
### Distributions
|
||||||
|
|
52
docs/_static/llama-stack-spec.html
vendored
52
docs/_static/llama-stack-spec.html
vendored
|
@ -7540,6 +7540,9 @@
|
||||||
{
|
{
|
||||||
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated"
|
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
|
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
|
||||||
}
|
}
|
||||||
|
@ -7548,6 +7551,7 @@
|
||||||
"propertyName": "type",
|
"propertyName": "type",
|
||||||
"mapping": {
|
"mapping": {
|
||||||
"response.created": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated",
|
"response.created": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated",
|
||||||
|
"response.output_text.delta": "#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta",
|
||||||
"response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
|
"response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7590,6 +7594,41 @@
|
||||||
],
|
],
|
||||||
"title": "OpenAIResponseObjectStreamResponseCreated"
|
"title": "OpenAIResponseObjectStreamResponseCreated"
|
||||||
},
|
},
|
||||||
|
"OpenAIResponseObjectStreamResponseOutputTextDelta": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"content_index": {
|
||||||
|
"type": "integer"
|
||||||
|
},
|
||||||
|
"delta": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"item_id": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"output_index": {
|
||||||
|
"type": "integer"
|
||||||
|
},
|
||||||
|
"sequence_number": {
|
||||||
|
"type": "integer"
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"type": "string",
|
||||||
|
"const": "response.output_text.delta",
|
||||||
|
"default": "response.output_text.delta"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"content_index",
|
||||||
|
"delta",
|
||||||
|
"item_id",
|
||||||
|
"output_index",
|
||||||
|
"sequence_number",
|
||||||
|
"type"
|
||||||
|
],
|
||||||
|
"title": "OpenAIResponseObjectStreamResponseOutputTextDelta"
|
||||||
|
},
|
||||||
"CreateUploadSessionRequest": {
|
"CreateUploadSessionRequest": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
@ -9555,9 +9594,6 @@
|
||||||
"toolgroup_id": {
|
"toolgroup_id": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"tool_host": {
|
|
||||||
"$ref": "#/components/schemas/ToolHost"
|
|
||||||
},
|
|
||||||
"description": {
|
"description": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
|
@ -9599,21 +9635,11 @@
|
||||||
"provider_id",
|
"provider_id",
|
||||||
"type",
|
"type",
|
||||||
"toolgroup_id",
|
"toolgroup_id",
|
||||||
"tool_host",
|
|
||||||
"description",
|
"description",
|
||||||
"parameters"
|
"parameters"
|
||||||
],
|
],
|
||||||
"title": "Tool"
|
"title": "Tool"
|
||||||
},
|
},
|
||||||
"ToolHost": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": [
|
|
||||||
"distribution",
|
|
||||||
"client",
|
|
||||||
"model_context_protocol"
|
|
||||||
],
|
|
||||||
"title": "ToolHost"
|
|
||||||
},
|
|
||||||
"ToolGroup": {
|
"ToolGroup": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|
39
docs/_static/llama-stack-spec.yaml
vendored
39
docs/_static/llama-stack-spec.yaml
vendored
|
@ -5294,11 +5294,13 @@ components:
|
||||||
OpenAIResponseObjectStream:
|
OpenAIResponseObjectStream:
|
||||||
oneOf:
|
oneOf:
|
||||||
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
|
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
|
||||||
|
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta'
|
||||||
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
|
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
|
||||||
discriminator:
|
discriminator:
|
||||||
propertyName: type
|
propertyName: type
|
||||||
mapping:
|
mapping:
|
||||||
response.created: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
|
response.created: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
|
||||||
|
response.output_text.delta: '#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta'
|
||||||
response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
|
response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
|
||||||
"OpenAIResponseObjectStreamResponseCompleted":
|
"OpenAIResponseObjectStreamResponseCompleted":
|
||||||
type: object
|
type: object
|
||||||
|
@ -5330,6 +5332,33 @@ components:
|
||||||
- type
|
- type
|
||||||
title: >-
|
title: >-
|
||||||
OpenAIResponseObjectStreamResponseCreated
|
OpenAIResponseObjectStreamResponseCreated
|
||||||
|
"OpenAIResponseObjectStreamResponseOutputTextDelta":
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
content_index:
|
||||||
|
type: integer
|
||||||
|
delta:
|
||||||
|
type: string
|
||||||
|
item_id:
|
||||||
|
type: string
|
||||||
|
output_index:
|
||||||
|
type: integer
|
||||||
|
sequence_number:
|
||||||
|
type: integer
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
const: response.output_text.delta
|
||||||
|
default: response.output_text.delta
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- content_index
|
||||||
|
- delta
|
||||||
|
- item_id
|
||||||
|
- output_index
|
||||||
|
- sequence_number
|
||||||
|
- type
|
||||||
|
title: >-
|
||||||
|
OpenAIResponseObjectStreamResponseOutputTextDelta
|
||||||
CreateUploadSessionRequest:
|
CreateUploadSessionRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
@ -6713,8 +6742,6 @@ components:
|
||||||
default: tool
|
default: tool
|
||||||
toolgroup_id:
|
toolgroup_id:
|
||||||
type: string
|
type: string
|
||||||
tool_host:
|
|
||||||
$ref: '#/components/schemas/ToolHost'
|
|
||||||
description:
|
description:
|
||||||
type: string
|
type: string
|
||||||
parameters:
|
parameters:
|
||||||
|
@ -6737,17 +6764,9 @@ components:
|
||||||
- provider_id
|
- provider_id
|
||||||
- type
|
- type
|
||||||
- toolgroup_id
|
- toolgroup_id
|
||||||
- tool_host
|
|
||||||
- description
|
- description
|
||||||
- parameters
|
- parameters
|
||||||
title: Tool
|
title: Tool
|
||||||
ToolHost:
|
|
||||||
type: string
|
|
||||||
enum:
|
|
||||||
- distribution
|
|
||||||
- client
|
|
||||||
- model_context_protocol
|
|
||||||
title: ToolHost
|
|
||||||
ToolGroup:
|
ToolGroup:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -6,7 +6,7 @@ Here's a collection of comprehensive guides, examples, and resources for buildin
|
||||||
|
|
||||||
From the llama-stack root directory, run the following command to render the docs locally:
|
From the llama-stack root directory, run the following command to render the docs locally:
|
||||||
```bash
|
```bash
|
||||||
uv run --with ".[docs]" sphinx-autobuild docs/source docs/build/html --write-all
|
uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all
|
||||||
```
|
```
|
||||||
You can open up the docs in your browser at http://localhost:8000
|
You can open up the docs in your browser at http://localhost:8000
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,18 @@ Runs inference with an LLM.
|
||||||
## Post Training
|
## Post Training
|
||||||
Fine-tunes a model.
|
Fine-tunes a model.
|
||||||
|
|
||||||
|
#### Post Training Providers
|
||||||
|
The following providers are available for Post Training:
|
||||||
|
|
||||||
|
```{toctree}
|
||||||
|
:maxdepth: 1
|
||||||
|
|
||||||
|
external
|
||||||
|
post_training/huggingface
|
||||||
|
post_training/torchtune
|
||||||
|
post_training/nvidia_nemo
|
||||||
|
```
|
||||||
|
|
||||||
## Safety
|
## Safety
|
||||||
Applies safety policies to the output at a Systems (not only model) level.
|
Applies safety policies to the output at a Systems (not only model) level.
|
||||||
|
|
||||||
|
|
122
docs/source/providers/post_training/huggingface.md
Normal file
122
docs/source/providers/post_training/huggingface.md
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
---
|
||||||
|
orphan: true
|
||||||
|
---
|
||||||
|
# HuggingFace SFTTrainer
|
||||||
|
|
||||||
|
[HuggingFace SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) is an inline post training provider for Llama Stack. It allows you to run supervised fine tuning on a variety of models using many datasets
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Simple access through the post_training API
|
||||||
|
- Fully integrated with Llama Stack
|
||||||
|
- GPU support, CPU support, and MPS support (MacOS Metal Performance Shaders)
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use the HF SFTTrainer in your Llama Stack project, follow these steps:
|
||||||
|
|
||||||
|
1. Configure your Llama Stack project to use this provider.
|
||||||
|
2. Kick off a SFT job using the Llama Stack post_training API.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
You can access the HuggingFace trainer via the `ollama` distribution:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llama stack build --template ollama --image-type venv
|
||||||
|
llama stack run --image-type venv ~/.llama/distributions/ollama/ollama-run.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
## Run Training
|
||||||
|
|
||||||
|
You can access the provider and the `supervised_fine_tune` method via the post_training API:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
from llama_stack_client.types import (
|
||||||
|
post_training_supervised_fine_tune_params,
|
||||||
|
algorithm_config_param,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_http_client():
|
||||||
|
from llama_stack_client import LlamaStackClient
|
||||||
|
|
||||||
|
return LlamaStackClient(base_url="http://localhost:8321")
|
||||||
|
|
||||||
|
|
||||||
|
client = create_http_client()
|
||||||
|
|
||||||
|
# Example Dataset
|
||||||
|
client.datasets.register(
|
||||||
|
purpose="post-training/messages",
|
||||||
|
source={
|
||||||
|
"type": "uri",
|
||||||
|
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
|
||||||
|
},
|
||||||
|
dataset_id="simpleqa",
|
||||||
|
)
|
||||||
|
|
||||||
|
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
|
||||||
|
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
|
||||||
|
batch_size=32,
|
||||||
|
data_format="instruct",
|
||||||
|
dataset_id="simpleqa",
|
||||||
|
shuffle=True,
|
||||||
|
),
|
||||||
|
gradient_accumulation_steps=1,
|
||||||
|
max_steps_per_epoch=0,
|
||||||
|
max_validation_steps=1,
|
||||||
|
n_epochs=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
algorithm_config = algorithm_config_param.LoraFinetuningConfig( # this config is also currently mandatory but should not be
|
||||||
|
alpha=1,
|
||||||
|
apply_lora_to_mlp=True,
|
||||||
|
apply_lora_to_output=False,
|
||||||
|
lora_attn_modules=["q_proj"],
|
||||||
|
rank=1,
|
||||||
|
type="LoRA",
|
||||||
|
)
|
||||||
|
|
||||||
|
job_uuid = f"test-job{uuid.uuid4()}"
|
||||||
|
|
||||||
|
# Example Model
|
||||||
|
training_model = "ibm-granite/granite-3.3-8b-instruct"
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
response = client.post_training.supervised_fine_tune(
|
||||||
|
job_uuid=job_uuid,
|
||||||
|
logger_config={},
|
||||||
|
model=training_model,
|
||||||
|
hyperparam_search_config={},
|
||||||
|
training_config=training_config,
|
||||||
|
algorithm_config=algorithm_config,
|
||||||
|
checkpoint_dir="output",
|
||||||
|
)
|
||||||
|
print("Job: ", job_uuid)
|
||||||
|
|
||||||
|
|
||||||
|
# Wait for the job to complete!
|
||||||
|
while True:
|
||||||
|
status = client.post_training.job.status(job_uuid=job_uuid)
|
||||||
|
if not status:
|
||||||
|
print("Job not found")
|
||||||
|
break
|
||||||
|
|
||||||
|
print(status)
|
||||||
|
if status.status == "completed":
|
||||||
|
break
|
||||||
|
|
||||||
|
print("Waiting for job to complete...")
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
print("Job completed in", end_time - start_time, "seconds!")
|
||||||
|
|
||||||
|
print("Artifacts:")
|
||||||
|
print(client.post_training.job.artifacts(job_uuid=job_uuid))
|
||||||
|
```
|
163
docs/source/providers/post_training/nvidia_nemo.md
Normal file
163
docs/source/providers/post_training/nvidia_nemo.md
Normal file
|
@ -0,0 +1,163 @@
|
||||||
|
---
|
||||||
|
orphan: true
|
||||||
|
---
|
||||||
|
# NVIDIA NEMO
|
||||||
|
|
||||||
|
[NVIDIA NEMO](https://developer.nvidia.com/nemo-framework) is a remote post training provider for Llama Stack. It provides enterprise-grade fine-tuning capabilities through NVIDIA's NeMo Customizer service.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Enterprise-grade fine-tuning capabilities
|
||||||
|
- Support for LoRA and SFT fine-tuning
|
||||||
|
- Integration with NVIDIA's NeMo Customizer service
|
||||||
|
- Support for various NVIDIA-optimized models
|
||||||
|
- Efficient training with NVIDIA hardware acceleration
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use NVIDIA NEMO in your Llama Stack project, follow these steps:
|
||||||
|
|
||||||
|
1. Configure your Llama Stack project to use this provider.
|
||||||
|
2. Set up your NVIDIA API credentials.
|
||||||
|
3. Kick off a fine-tuning job using the Llama Stack post_training API.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
You'll need to set the following environment variables:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export NVIDIA_API_KEY="your-api-key"
|
||||||
|
export NVIDIA_DATASET_NAMESPACE="default"
|
||||||
|
export NVIDIA_CUSTOMIZER_URL="your-customizer-url"
|
||||||
|
export NVIDIA_PROJECT_ID="your-project-id"
|
||||||
|
export NVIDIA_OUTPUT_MODEL_DIR="your-output-model-dir"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Run Training
|
||||||
|
|
||||||
|
You can access the provider and the `supervised_fine_tune` method via the post_training API:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from llama_stack_client.types import (
|
||||||
|
post_training_supervised_fine_tune_params,
|
||||||
|
algorithm_config_param,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_http_client():
|
||||||
|
from llama_stack_client import LlamaStackClient
|
||||||
|
|
||||||
|
return LlamaStackClient(base_url="http://localhost:8321")
|
||||||
|
|
||||||
|
|
||||||
|
client = create_http_client()
|
||||||
|
|
||||||
|
# Example Dataset
|
||||||
|
client.datasets.register(
|
||||||
|
purpose="post-training/messages",
|
||||||
|
source={
|
||||||
|
"type": "uri",
|
||||||
|
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
|
||||||
|
},
|
||||||
|
dataset_id="simpleqa",
|
||||||
|
)
|
||||||
|
|
||||||
|
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
|
||||||
|
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
|
||||||
|
batch_size=8, # Default batch size for NEMO
|
||||||
|
data_format="instruct",
|
||||||
|
dataset_id="simpleqa",
|
||||||
|
shuffle=True,
|
||||||
|
),
|
||||||
|
n_epochs=50, # Default epochs for NEMO
|
||||||
|
optimizer_config=post_training_supervised_fine_tune_params.TrainingConfigOptimizerConfig(
|
||||||
|
lr=0.0001, # Default learning rate
|
||||||
|
weight_decay=0.01, # NEMO-specific parameter
|
||||||
|
),
|
||||||
|
# NEMO-specific parameters
|
||||||
|
log_every_n_steps=None,
|
||||||
|
val_check_interval=0.25,
|
||||||
|
sequence_packing_enabled=False,
|
||||||
|
hidden_dropout=None,
|
||||||
|
attention_dropout=None,
|
||||||
|
ffn_dropout=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
|
||||||
|
alpha=16, # Default alpha for NEMO
|
||||||
|
type="LoRA",
|
||||||
|
)
|
||||||
|
|
||||||
|
job_uuid = f"test-job{uuid.uuid4()}"
|
||||||
|
|
||||||
|
# Example Model - must be a supported NEMO model
|
||||||
|
training_model = "meta/llama-3.1-8b-instruct"
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
response = client.post_training.supervised_fine_tune(
|
||||||
|
job_uuid=job_uuid,
|
||||||
|
logger_config={},
|
||||||
|
model=training_model,
|
||||||
|
hyperparam_search_config={},
|
||||||
|
training_config=training_config,
|
||||||
|
algorithm_config=algorithm_config,
|
||||||
|
checkpoint_dir="output",
|
||||||
|
)
|
||||||
|
print("Job: ", job_uuid)
|
||||||
|
|
||||||
|
# Wait for the job to complete!
|
||||||
|
while True:
|
||||||
|
status = client.post_training.job.status(job_uuid=job_uuid)
|
||||||
|
if not status:
|
||||||
|
print("Job not found")
|
||||||
|
break
|
||||||
|
|
||||||
|
print(status)
|
||||||
|
if status.status == "completed":
|
||||||
|
break
|
||||||
|
|
||||||
|
print("Waiting for job to complete...")
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
print("Job completed in", end_time - start_time, "seconds!")
|
||||||
|
|
||||||
|
print("Artifacts:")
|
||||||
|
print(client.post_training.job.artifacts(job_uuid=job_uuid))
|
||||||
|
```
|
||||||
|
|
||||||
|
## Supported Models
|
||||||
|
|
||||||
|
Currently supports the following models:
|
||||||
|
- meta/llama-3.1-8b-instruct
|
||||||
|
- meta/llama-3.2-1b-instruct
|
||||||
|
|
||||||
|
## Supported Parameters
|
||||||
|
|
||||||
|
### TrainingConfig
|
||||||
|
- n_epochs (default: 50)
|
||||||
|
- data_config
|
||||||
|
- optimizer_config
|
||||||
|
- log_every_n_steps
|
||||||
|
- val_check_interval (default: 0.25)
|
||||||
|
- sequence_packing_enabled (default: False)
|
||||||
|
- hidden_dropout (0.0-1.0)
|
||||||
|
- attention_dropout (0.0-1.0)
|
||||||
|
- ffn_dropout (0.0-1.0)
|
||||||
|
|
||||||
|
### DataConfig
|
||||||
|
- dataset_id
|
||||||
|
- batch_size (default: 8)
|
||||||
|
|
||||||
|
### OptimizerConfig
|
||||||
|
- lr (default: 0.0001)
|
||||||
|
- weight_decay (default: 0.01)
|
||||||
|
|
||||||
|
### LoRA Config
|
||||||
|
- alpha (default: 16)
|
||||||
|
- type (must be "LoRA")
|
||||||
|
|
||||||
|
Note: Some parameters from the standard Llama Stack API are not supported and will be ignored with a warning.
|
125
docs/source/providers/post_training/torchtune.md
Normal file
125
docs/source/providers/post_training/torchtune.md
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
---
|
||||||
|
orphan: true
|
||||||
|
---
|
||||||
|
# TorchTune
|
||||||
|
|
||||||
|
[TorchTune](https://github.com/pytorch/torchtune) is an inline post training provider for Llama Stack. It provides a simple and efficient way to fine-tune language models using PyTorch.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Simple access through the post_training API
|
||||||
|
- Fully integrated with Llama Stack
|
||||||
|
- GPU support and single device capabilities.
|
||||||
|
- Support for LoRA
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use TorchTune in your Llama Stack project, follow these steps:
|
||||||
|
|
||||||
|
1. Configure your Llama Stack project to use this provider.
|
||||||
|
2. Kick off a fine-tuning job using the Llama Stack post_training API.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
You can access the TorchTune trainer by writing your own yaml pointing to the provider:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
post_training:
|
||||||
|
- provider_id: torchtune
|
||||||
|
provider_type: inline::torchtune
|
||||||
|
config: {}
|
||||||
|
```
|
||||||
|
|
||||||
|
you can then build and run your own stack with this provider.
|
||||||
|
|
||||||
|
## Run Training
|
||||||
|
|
||||||
|
You can access the provider and the `supervised_fine_tune` method via the post_training API:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from llama_stack_client.types import (
|
||||||
|
post_training_supervised_fine_tune_params,
|
||||||
|
algorithm_config_param,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_http_client():
|
||||||
|
from llama_stack_client import LlamaStackClient
|
||||||
|
|
||||||
|
return LlamaStackClient(base_url="http://localhost:8321")
|
||||||
|
|
||||||
|
|
||||||
|
client = create_http_client()
|
||||||
|
|
||||||
|
# Example Dataset
|
||||||
|
client.datasets.register(
|
||||||
|
purpose="post-training/messages",
|
||||||
|
source={
|
||||||
|
"type": "uri",
|
||||||
|
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
|
||||||
|
},
|
||||||
|
dataset_id="simpleqa",
|
||||||
|
)
|
||||||
|
|
||||||
|
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
|
||||||
|
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
|
||||||
|
batch_size=32,
|
||||||
|
data_format="instruct",
|
||||||
|
dataset_id="simpleqa",
|
||||||
|
shuffle=True,
|
||||||
|
),
|
||||||
|
gradient_accumulation_steps=1,
|
||||||
|
max_steps_per_epoch=0,
|
||||||
|
max_validation_steps=1,
|
||||||
|
n_epochs=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
|
||||||
|
alpha=1,
|
||||||
|
apply_lora_to_mlp=True,
|
||||||
|
apply_lora_to_output=False,
|
||||||
|
lora_attn_modules=["q_proj"],
|
||||||
|
rank=1,
|
||||||
|
type="LoRA",
|
||||||
|
)
|
||||||
|
|
||||||
|
job_uuid = f"test-job{uuid.uuid4()}"
|
||||||
|
|
||||||
|
# Example Model
|
||||||
|
training_model = "meta-llama/Llama-2-7b-hf"
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
response = client.post_training.supervised_fine_tune(
|
||||||
|
job_uuid=job_uuid,
|
||||||
|
logger_config={},
|
||||||
|
model=training_model,
|
||||||
|
hyperparam_search_config={},
|
||||||
|
training_config=training_config,
|
||||||
|
algorithm_config=algorithm_config,
|
||||||
|
checkpoint_dir="output",
|
||||||
|
)
|
||||||
|
print("Job: ", job_uuid)
|
||||||
|
|
||||||
|
# Wait for the job to complete!
|
||||||
|
while True:
|
||||||
|
status = client.post_training.job.status(job_uuid=job_uuid)
|
||||||
|
if not status:
|
||||||
|
print("Job not found")
|
||||||
|
break
|
||||||
|
|
||||||
|
print(status)
|
||||||
|
if status.status == "completed":
|
||||||
|
break
|
||||||
|
|
||||||
|
print("Waiting for job to complete...")
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
print("Job completed in", end_time - start_time, "seconds!")
|
||||||
|
|
||||||
|
print("Artifacts:")
|
||||||
|
print(client.post_training.job.artifacts(job_uuid=job_uuid))
|
||||||
|
```
|
|
@ -149,6 +149,16 @@ class OpenAIResponseObjectStreamResponseCreated(BaseModel):
|
||||||
type: Literal["response.created"] = "response.created"
|
type: Literal["response.created"] = "response.created"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
|
||||||
|
content_index: int
|
||||||
|
delta: str
|
||||||
|
item_id: str
|
||||||
|
output_index: int
|
||||||
|
sequence_number: int
|
||||||
|
type: Literal["response.output_text.delta"] = "response.output_text.delta"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
||||||
response: OpenAIResponseObject
|
response: OpenAIResponseObject
|
||||||
|
@ -156,7 +166,9 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
OpenAIResponseObjectStream = Annotated[
|
OpenAIResponseObjectStream = Annotated[
|
||||||
OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted,
|
OpenAIResponseObjectStreamResponseCreated
|
||||||
|
| OpenAIResponseObjectStreamResponseOutputTextDelta
|
||||||
|
| OpenAIResponseObjectStreamResponseCompleted,
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
|
register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
|
||||||
|
|
|
@ -1,30 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class RestAPIMethod(Enum):
|
|
||||||
GET = "GET"
|
|
||||||
POST = "POST"
|
|
||||||
PUT = "PUT"
|
|
||||||
DELETE = "DELETE"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class RestAPIExecutionConfig(BaseModel):
|
|
||||||
url: URL
|
|
||||||
method: RestAPIMethod
|
|
||||||
params: dict[str, Any] | None = None
|
|
||||||
headers: dict[str, Any] | None = None
|
|
||||||
body: dict[str, Any] | None = None
|
|
|
@ -27,18 +27,10 @@ class ToolParameter(BaseModel):
|
||||||
default: Any | None = None
|
default: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ToolHost(Enum):
|
|
||||||
distribution = "distribution"
|
|
||||||
client = "client"
|
|
||||||
model_context_protocol = "model_context_protocol"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class Tool(Resource):
|
class Tool(Resource):
|
||||||
type: Literal[ResourceType.tool] = ResourceType.tool
|
type: Literal[ResourceType.tool] = ResourceType.tool
|
||||||
toolgroup_id: str
|
toolgroup_id: str
|
||||||
tool_host: ToolHost
|
|
||||||
description: str
|
description: str
|
||||||
parameters: list[ToolParameter]
|
parameters: list[ToolParameter]
|
||||||
metadata: dict[str, Any] | None = None
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
|
@ -267,8 +267,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
if args.run:
|
if args.run:
|
||||||
config_dict = yaml.safe_load(run_config.read_text())
|
config_dict = yaml.safe_load(run_config.read_text())
|
||||||
config = parse_and_maybe_upgrade_config(config_dict)
|
config = parse_and_maybe_upgrade_config(config_dict)
|
||||||
if not os.path.exists(config.external_providers_dir):
|
if config.external_providers_dir and not config.external_providers_dir.exists():
|
||||||
os.makedirs(config.external_providers_dir, exist_ok=True)
|
config.external_providers_dir.mkdir(exist_ok=True)
|
||||||
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
|
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
|
||||||
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
|
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
|
||||||
run_command(run_args)
|
run_command(run_args)
|
||||||
|
|
|
@ -125,7 +125,6 @@ RUN apt-get update && apt-get install -y \
|
||||||
curl wget telnet git\
|
curl wget telnet git\
|
||||||
procps psmisc lsof \
|
procps psmisc lsof \
|
||||||
traceroute \
|
traceroute \
|
||||||
bubblewrap \
|
|
||||||
gcc \
|
gcc \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ from llama_stack.apis.inspect import (
|
||||||
VersionInfo,
|
VersionInfo,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.datatypes import StackRunConfig
|
from llama_stack.distribution.datatypes import StackRunConfig
|
||||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
from llama_stack.distribution.server.routes import get_all_api_routes
|
||||||
from llama_stack.providers.datatypes import HealthStatus
|
from llama_stack.providers.datatypes import HealthStatus
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,15 +42,15 @@ class DistributionInspectImpl(Inspect):
|
||||||
run_config: StackRunConfig = self.config.run_config
|
run_config: StackRunConfig = self.config.run_config
|
||||||
|
|
||||||
ret = []
|
ret = []
|
||||||
all_endpoints = get_all_api_endpoints()
|
all_endpoints = get_all_api_routes()
|
||||||
for api, endpoints in all_endpoints.items():
|
for api, endpoints in all_endpoints.items():
|
||||||
# Always include provider and inspect APIs, filter others based on run config
|
# Always include provider and inspect APIs, filter others based on run config
|
||||||
if api.value in ["providers", "inspect"]:
|
if api.value in ["providers", "inspect"]:
|
||||||
ret.extend(
|
ret.extend(
|
||||||
[
|
[
|
||||||
RouteInfo(
|
RouteInfo(
|
||||||
route=e.route,
|
route=e.path,
|
||||||
method=e.method,
|
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||||
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
|
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
|
||||||
)
|
)
|
||||||
for e in endpoints
|
for e in endpoints
|
||||||
|
@ -62,8 +62,8 @@ class DistributionInspectImpl(Inspect):
|
||||||
ret.extend(
|
ret.extend(
|
||||||
[
|
[
|
||||||
RouteInfo(
|
RouteInfo(
|
||||||
route=e.route,
|
route=e.path,
|
||||||
method=e.method,
|
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||||
provider_types=[p.provider_type for p in providers],
|
provider_types=[p.provider_type for p in providers],
|
||||||
)
|
)
|
||||||
for e in endpoints
|
for e in endpoints
|
||||||
|
|
|
@ -37,10 +37,7 @@ from llama_stack.distribution.request_headers import (
|
||||||
request_provider_data_context,
|
request_provider_data_context,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.resolver import ProviderRegistry
|
from llama_stack.distribution.resolver import ProviderRegistry
|
||||||
from llama_stack.distribution.server.endpoints import (
|
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
|
||||||
find_matching_endpoint,
|
|
||||||
initialize_endpoint_impls,
|
|
||||||
)
|
|
||||||
from llama_stack.distribution.stack import (
|
from llama_stack.distribution.stack import (
|
||||||
construct_stack,
|
construct_stack,
|
||||||
get_stack_run_config_from_template,
|
get_stack_run_config_from_template,
|
||||||
|
@ -208,7 +205,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
|
|
||||||
async def initialize(self) -> bool:
|
async def initialize(self) -> bool:
|
||||||
try:
|
try:
|
||||||
self.endpoint_impls = None
|
self.route_impls = None
|
||||||
self.impls = await construct_stack(self.config, self.custom_provider_registry)
|
self.impls = await construct_stack(self.config, self.custom_provider_registry)
|
||||||
except ModuleNotFoundError as _e:
|
except ModuleNotFoundError as _e:
|
||||||
cprint(_e.msg, color="red", file=sys.stderr)
|
cprint(_e.msg, color="red", file=sys.stderr)
|
||||||
|
@ -254,7 +251,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
safe_config = redact_sensitive_fields(self.config.model_dump())
|
safe_config = redact_sensitive_fields(self.config.model_dump())
|
||||||
console.print(yaml.dump(safe_config, indent=2))
|
console.print(yaml.dump(safe_config, indent=2))
|
||||||
|
|
||||||
self.endpoint_impls = initialize_endpoint_impls(self.impls)
|
self.route_impls = initialize_route_impls(self.impls)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def request(
|
async def request(
|
||||||
|
@ -265,7 +262,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
stream=False,
|
stream=False,
|
||||||
stream_cls=None,
|
stream_cls=None,
|
||||||
):
|
):
|
||||||
if not self.endpoint_impls:
|
if not self.route_impls:
|
||||||
raise ValueError("Client not initialized")
|
raise ValueError("Client not initialized")
|
||||||
|
|
||||||
# Create headers with provider data if available
|
# Create headers with provider data if available
|
||||||
|
@ -296,11 +293,14 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
cast_to: Any,
|
cast_to: Any,
|
||||||
options: Any,
|
options: Any,
|
||||||
):
|
):
|
||||||
|
if self.route_impls is None:
|
||||||
|
raise ValueError("Client not initialized")
|
||||||
|
|
||||||
path = options.url
|
path = options.url
|
||||||
body = options.params or {}
|
body = options.params or {}
|
||||||
body |= options.json_data or {}
|
body |= options.json_data or {}
|
||||||
|
|
||||||
matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
|
matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
|
||||||
body |= path_params
|
body |= path_params
|
||||||
body = self._convert_body(path, options.method, body)
|
body = self._convert_body(path, options.method, body)
|
||||||
await start_trace(route, {"__location__": "library_client"})
|
await start_trace(route, {"__location__": "library_client"})
|
||||||
|
@ -342,10 +342,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
options: Any,
|
options: Any,
|
||||||
stream_cls: Any,
|
stream_cls: Any,
|
||||||
):
|
):
|
||||||
|
if self.route_impls is None:
|
||||||
|
raise ValueError("Client not initialized")
|
||||||
|
|
||||||
path = options.url
|
path = options.url
|
||||||
body = options.params or {}
|
body = options.params or {}
|
||||||
body |= options.json_data or {}
|
body |= options.json_data or {}
|
||||||
func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
|
func, path_params, route = find_matching_route(options.method, path, self.route_impls)
|
||||||
body |= path_params
|
body |= path_params
|
||||||
|
|
||||||
body = self._convert_body(path, options.method, body)
|
body = self._convert_body(path, options.method, body)
|
||||||
|
@ -397,7 +400,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
if not body:
|
if not body:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls)
|
if self.route_impls is None:
|
||||||
|
raise ValueError("Client not initialized")
|
||||||
|
|
||||||
|
func, _, _ = find_matching_route(method, path, self.route_impls)
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
# Strip NOT_GIVENs to use the defaults in signature
|
# Strip NOT_GIVENs to use the defaults in signature
|
||||||
|
|
|
@ -47,7 +47,7 @@ from llama_stack.providers.datatypes import (
|
||||||
RemoteProviderSpec,
|
RemoteProviderSpec,
|
||||||
ScoringFunctionsProtocolPrivate,
|
ScoringFunctionsProtocolPrivate,
|
||||||
ShieldsProtocolPrivate,
|
ShieldsProtocolPrivate,
|
||||||
ToolsProtocolPrivate,
|
ToolGroupsProtocolPrivate,
|
||||||
VectorDBsProtocolPrivate,
|
VectorDBsProtocolPrivate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -93,7 +93,7 @@ def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
|
||||||
def additional_protocols_map() -> dict[Api, Any]:
|
def additional_protocols_map() -> dict[Api, Any]:
|
||||||
return {
|
return {
|
||||||
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
||||||
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
|
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
|
||||||
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
|
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
|
||||||
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
||||||
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
|
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
|
||||||
|
|
|
@ -11,7 +11,7 @@ from llama_stack.apis.common.content_types import (
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolsResponse,
|
||||||
RAGDocument,
|
RAGDocument,
|
||||||
RAGQueryConfig,
|
RAGQueryConfig,
|
||||||
RAGQueryResult,
|
RAGQueryResult,
|
||||||
|
@ -19,7 +19,8 @@ from llama_stack.apis.tools import (
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import RoutingTable
|
|
||||||
|
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
@ -28,7 +29,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
class RagToolImpl(RAGToolRuntime):
|
class RagToolImpl(RAGToolRuntime):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: ToolGroupsRoutingTable,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
|
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
@ -59,7 +60,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: ToolGroupsRoutingTable,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug("Initializing ToolRuntimeRouter")
|
logger.debug("Initializing ToolRuntimeRouter")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
@ -86,6 +87,6 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||||
) -> ListToolDefsResponse:
|
) -> ListToolsResponse:
|
||||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
|
return await self.routing_table.list_tools(tool_group_id)
|
||||||
|
|
|
@ -46,7 +46,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
||||||
elif api == Api.eval:
|
elif api == Api.eval:
|
||||||
return await p.register_benchmark(obj)
|
return await p.register_benchmark(obj)
|
||||||
elif api == Api.tool_runtime:
|
elif api == Api.tool_runtime:
|
||||||
return await p.register_tool(obj)
|
return await p.register_toolgroup(obj)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||||
elif api == Api.datasetio:
|
elif api == Api.datasetio:
|
||||||
return await p.unregister_dataset(obj.identifier)
|
return await p.unregister_dataset(obj.identifier)
|
||||||
elif api == Api.tool_runtime:
|
elif api == Api.tool_runtime:
|
||||||
return await p.unregister_tool(obj.identifier)
|
return await p.unregister_toolgroup(obj.identifier)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unregister not supported for {api}")
|
raise ValueError(f"Unregister not supported for {api}")
|
||||||
|
|
||||||
|
@ -136,7 +136,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
elif isinstance(self, BenchmarksRoutingTable):
|
elif isinstance(self, BenchmarksRoutingTable):
|
||||||
return ("Eval", "benchmark")
|
return ("Eval", "benchmark")
|
||||||
elif isinstance(self, ToolGroupsRoutingTable):
|
elif isinstance(self, ToolGroupsRoutingTable):
|
||||||
return ("Tools", "tool")
|
return ("ToolGroups", "tool_group")
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown routing table type")
|
raise ValueError("Unknown routing table type")
|
||||||
|
|
||||||
|
|
|
@ -7,11 +7,8 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups, ToolHost
|
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import ToolGroupWithACL
|
||||||
ToolGroupWithACL,
|
|
||||||
ToolWithACL,
|
|
||||||
)
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl
|
from .common import CommonRoutingTableImpl
|
||||||
|
@ -19,12 +16,70 @@ from .common import CommonRoutingTableImpl
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
|
||||||
|
# handle the funny case like "builtin::rag/knowledge_search"
|
||||||
|
parts = toolgroup_name_with_maybe_tool_name.split("/")
|
||||||
|
if len(parts) == 2:
|
||||||
|
return parts[0]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
toolgroups_to_tools: dict[str, list[Tool]] = {}
|
||||||
tools = await self.get_all_with_type("tool")
|
tool_to_toolgroup: dict[str, str] = {}
|
||||||
|
|
||||||
|
# overridden
|
||||||
|
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||||
|
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
|
||||||
|
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
|
||||||
|
|
||||||
|
toolgroup_id = parse_toolgroup_from_toolgroup_name_pair(routing_key)
|
||||||
if toolgroup_id:
|
if toolgroup_id:
|
||||||
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
|
routing_key = toolgroup_id
|
||||||
return ListToolsResponse(data=tools)
|
|
||||||
|
if routing_key in self.tool_to_toolgroup:
|
||||||
|
routing_key = self.tool_to_toolgroup[routing_key]
|
||||||
|
return super().get_provider_impl(routing_key, provider_id)
|
||||||
|
|
||||||
|
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||||
|
if toolgroup_id:
|
||||||
|
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
|
||||||
|
toolgroup_id = group_id
|
||||||
|
toolgroups = [await self.get_tool_group(toolgroup_id)]
|
||||||
|
else:
|
||||||
|
toolgroups = await self.get_all_with_type("tool_group")
|
||||||
|
|
||||||
|
all_tools = []
|
||||||
|
for toolgroup in toolgroups:
|
||||||
|
if toolgroup.identifier not in self.toolgroups_to_tools:
|
||||||
|
await self._index_tools(toolgroup)
|
||||||
|
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
|
||||||
|
|
||||||
|
return ListToolsResponse(data=all_tools)
|
||||||
|
|
||||||
|
async def _index_tools(self, toolgroup: ToolGroup):
|
||||||
|
provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
|
||||||
|
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
|
||||||
|
|
||||||
|
# TODO: kill this Tool vs ToolDef distinction
|
||||||
|
tooldefs = tooldefs_response.data
|
||||||
|
tools = []
|
||||||
|
for t in tooldefs:
|
||||||
|
tools.append(
|
||||||
|
Tool(
|
||||||
|
identifier=t.name,
|
||||||
|
toolgroup_id=toolgroup.identifier,
|
||||||
|
description=t.description or "",
|
||||||
|
parameters=t.parameters or [],
|
||||||
|
metadata=t.metadata,
|
||||||
|
provider_id=toolgroup.provider_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.toolgroups_to_tools[toolgroup.identifier] = tools
|
||||||
|
for tool in tools:
|
||||||
|
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
|
||||||
|
|
||||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||||
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
|
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
|
||||||
|
@ -36,7 +91,13 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
return tool_group
|
return tool_group
|
||||||
|
|
||||||
async def get_tool(self, tool_name: str) -> Tool:
|
async def get_tool(self, tool_name: str) -> Tool:
|
||||||
return await self.get_object_by_identifier("tool", tool_name)
|
if tool_name in self.tool_to_toolgroup:
|
||||||
|
toolgroup_id = self.tool_to_toolgroup[tool_name]
|
||||||
|
tools = self.toolgroups_to_tools[toolgroup_id]
|
||||||
|
for tool in tools:
|
||||||
|
if tool.identifier == tool_name:
|
||||||
|
return tool
|
||||||
|
raise ValueError(f"Tool '{tool_name}' not found")
|
||||||
|
|
||||||
async def register_tool_group(
|
async def register_tool_group(
|
||||||
self,
|
self,
|
||||||
|
@ -45,53 +106,26 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
mcp_endpoint: URL | None = None,
|
mcp_endpoint: URL | None = None,
|
||||||
args: dict[str, Any] | None = None,
|
args: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
tools = []
|
toolgroup = ToolGroupWithACL(
|
||||||
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
|
|
||||||
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
|
|
||||||
|
|
||||||
for tool_def in tool_defs.data:
|
|
||||||
tools.append(
|
|
||||||
ToolWithACL(
|
|
||||||
identifier=tool_def.name,
|
|
||||||
toolgroup_id=toolgroup_id,
|
|
||||||
description=tool_def.description or "",
|
|
||||||
parameters=tool_def.parameters or [],
|
|
||||||
provider_id=provider_id,
|
|
||||||
provider_resource_id=tool_def.name,
|
|
||||||
metadata=tool_def.metadata,
|
|
||||||
tool_host=tool_host,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for tool in tools:
|
|
||||||
existing_tool = await self.get_tool(tool.identifier)
|
|
||||||
# Compare existing and new object if one exists
|
|
||||||
if existing_tool:
|
|
||||||
existing_dict = existing_tool.model_dump()
|
|
||||||
new_dict = tool.model_dump()
|
|
||||||
|
|
||||||
if existing_dict != new_dict:
|
|
||||||
raise ValueError(
|
|
||||||
f"Object {tool.identifier} already exists in registry. Please use a different identifier."
|
|
||||||
)
|
|
||||||
await self.register_object(tool)
|
|
||||||
|
|
||||||
await self.dist_registry.register(
|
|
||||||
ToolGroupWithACL(
|
|
||||||
identifier=toolgroup_id,
|
identifier=toolgroup_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
provider_resource_id=toolgroup_id,
|
provider_resource_id=toolgroup_id,
|
||||||
mcp_endpoint=mcp_endpoint,
|
mcp_endpoint=mcp_endpoint,
|
||||||
args=args,
|
args=args,
|
||||||
)
|
)
|
||||||
)
|
await self.register_object(toolgroup)
|
||||||
|
|
||||||
|
# ideally, indexing of the tools should not be necessary because anyone using
|
||||||
|
# the tools should first list the tools and then use them. but there are assumptions
|
||||||
|
# baked in some of the code and tests right now.
|
||||||
|
if not toolgroup.mcp_endpoint:
|
||||||
|
await self._index_tools(toolgroup)
|
||||||
|
return toolgroup
|
||||||
|
|
||||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
tool_group = await self.get_tool_group(toolgroup_id)
|
tool_group = await self.get_tool_group(toolgroup_id)
|
||||||
if tool_group is None:
|
if tool_group is None:
|
||||||
raise ValueError(f"Tool group {toolgroup_id} not found")
|
raise ValueError(f"Tool group {toolgroup_id} not found")
|
||||||
tools = await self.list_tools(toolgroup_id)
|
|
||||||
for tool in getattr(tools, "data", []):
|
|
||||||
await self.unregister_object(tool)
|
|
||||||
await self.unregister_object(tool_group)
|
await self.unregister_object(tool_group)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
|
|
@ -6,20 +6,23 @@
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import re
|
import re
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from aiohttp import hdrs
|
||||||
|
from starlette.routing import Route
|
||||||
|
|
||||||
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
|
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
|
||||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||||
from llama_stack.distribution.resolver import api_protocol_map
|
from llama_stack.distribution.resolver import api_protocol_map
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
EndpointFunc = Callable[..., Any]
|
||||||
class ApiEndpoint(BaseModel):
|
PathParams = dict[str, str]
|
||||||
route: str
|
RouteInfo = tuple[EndpointFunc, str]
|
||||||
method: str
|
PathImpl = dict[str, RouteInfo]
|
||||||
name: str
|
RouteImpls = dict[str, PathImpl]
|
||||||
descriptive_name: str | None = None
|
RouteMatch = tuple[EndpointFunc, PathParams, str]
|
||||||
|
|
||||||
|
|
||||||
def toolgroup_protocol_map():
|
def toolgroup_protocol_map():
|
||||||
|
@ -28,13 +31,13 @@ def toolgroup_protocol_map():
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
|
def get_all_api_routes() -> dict[Api, list[Route]]:
|
||||||
apis = {}
|
apis = {}
|
||||||
|
|
||||||
protocols = api_protocol_map()
|
protocols = api_protocol_map()
|
||||||
toolgroup_protocols = toolgroup_protocol_map()
|
toolgroup_protocols = toolgroup_protocol_map()
|
||||||
for api, protocol in protocols.items():
|
for api, protocol in protocols.items():
|
||||||
endpoints = []
|
routes = []
|
||||||
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
||||||
|
|
||||||
# HACK ALERT
|
# HACK ALERT
|
||||||
|
@ -51,26 +54,28 @@ def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
|
||||||
if not hasattr(method, "__webmethod__"):
|
if not hasattr(method, "__webmethod__"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
webmethod = method.__webmethod__
|
# The __webmethod__ attribute is dynamically added by the @webmethod decorator
|
||||||
route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
|
# mypy doesn't know about this dynamic attribute, so we ignore the attr-defined error
|
||||||
if webmethod.method == "GET":
|
webmethod = method.__webmethod__ # type: ignore[attr-defined]
|
||||||
method = "get"
|
path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
|
||||||
elif webmethod.method == "DELETE":
|
if webmethod.method == hdrs.METH_GET:
|
||||||
method = "delete"
|
http_method = hdrs.METH_GET
|
||||||
|
elif webmethod.method == hdrs.METH_DELETE:
|
||||||
|
http_method = hdrs.METH_DELETE
|
||||||
else:
|
else:
|
||||||
method = "post"
|
http_method = hdrs.METH_POST
|
||||||
endpoints.append(
|
routes.append(
|
||||||
ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name)
|
Route(path=path, methods=[http_method], name=name, endpoint=None)
|
||||||
)
|
) # setting endpoint to None since don't use a Router object
|
||||||
|
|
||||||
apis[api] = endpoints
|
apis[api] = routes
|
||||||
|
|
||||||
return apis
|
return apis
|
||||||
|
|
||||||
|
|
||||||
def initialize_endpoint_impls(impls):
|
def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
|
||||||
endpoints = get_all_api_endpoints()
|
routes = get_all_api_routes()
|
||||||
endpoint_impls = {}
|
route_impls: RouteImpls = {}
|
||||||
|
|
||||||
def _convert_path_to_regex(path: str) -> str:
|
def _convert_path_to_regex(path: str) -> str:
|
||||||
# Convert {param} to named capture groups
|
# Convert {param} to named capture groups
|
||||||
|
@ -83,29 +88,34 @@ def initialize_endpoint_impls(impls):
|
||||||
|
|
||||||
return f"^{pattern}$"
|
return f"^{pattern}$"
|
||||||
|
|
||||||
for api, api_endpoints in endpoints.items():
|
for api, api_routes in routes.items():
|
||||||
if api not in impls:
|
if api not in impls:
|
||||||
continue
|
continue
|
||||||
for endpoint in api_endpoints:
|
for route in api_routes:
|
||||||
impl = impls[api]
|
impl = impls[api]
|
||||||
func = getattr(impl, endpoint.name)
|
func = getattr(impl, route.name)
|
||||||
if endpoint.method not in endpoint_impls:
|
# Get the first (and typically only) method from the set, filtering out HEAD
|
||||||
endpoint_impls[endpoint.method] = {}
|
available_methods = [m for m in route.methods if m != "HEAD"]
|
||||||
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (
|
if not available_methods:
|
||||||
|
continue # Skip if only HEAD method is available
|
||||||
|
method = available_methods[0].lower()
|
||||||
|
if method not in route_impls:
|
||||||
|
route_impls[method] = {}
|
||||||
|
route_impls[method][_convert_path_to_regex(route.path)] = (
|
||||||
func,
|
func,
|
||||||
endpoint.descriptive_name or endpoint.route,
|
route.path,
|
||||||
)
|
)
|
||||||
|
|
||||||
return endpoint_impls
|
return route_impls
|
||||||
|
|
||||||
|
|
||||||
def find_matching_endpoint(method, path, endpoint_impls):
|
def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch:
|
||||||
"""Find the matching endpoint implementation for a given method and path.
|
"""Find the matching endpoint implementation for a given method and path.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
method: HTTP method (GET, POST, etc.)
|
method: HTTP method (GET, POST, etc.)
|
||||||
path: URL path to match against
|
path: URL path to match against
|
||||||
endpoint_impls: A dictionary of endpoint implementations
|
route_impls: A dictionary of endpoint implementations
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (endpoint_function, path_params, descriptive_name)
|
A tuple of (endpoint_function, path_params, descriptive_name)
|
||||||
|
@ -113,7 +123,7 @@ def find_matching_endpoint(method, path, endpoint_impls):
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If no matching endpoint is found
|
ValueError: If no matching endpoint is found
|
||||||
"""
|
"""
|
||||||
impls = endpoint_impls.get(method.lower())
|
impls = route_impls.get(method.lower())
|
||||||
if not impls:
|
if not impls:
|
||||||
raise ValueError(f"No endpoint found for {path}")
|
raise ValueError(f"No endpoint found for {path}")
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
@ -13,6 +14,7 @@ import ssl
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections.abc import Callable
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from importlib.metadata import version as parse_version
|
from importlib.metadata import version as parse_version
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -20,6 +22,7 @@ from typing import Annotated, Any
|
||||||
|
|
||||||
import rich.pretty
|
import rich.pretty
|
||||||
import yaml
|
import yaml
|
||||||
|
from aiohttp import hdrs
|
||||||
from fastapi import Body, FastAPI, HTTPException, Request
|
from fastapi import Body, FastAPI, HTTPException, Request
|
||||||
from fastapi import Path as FastapiPath
|
from fastapi import Path as FastapiPath
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
@ -35,9 +38,10 @@ from llama_stack.distribution.request_headers import (
|
||||||
request_provider_data_context,
|
request_provider_data_context,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.resolver import InvalidProviderError
|
from llama_stack.distribution.resolver import InvalidProviderError
|
||||||
from llama_stack.distribution.server.endpoints import (
|
from llama_stack.distribution.server.routes import (
|
||||||
find_matching_endpoint,
|
find_matching_route,
|
||||||
initialize_endpoint_impls,
|
get_all_api_routes,
|
||||||
|
initialize_route_impls,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.stack import (
|
from llama_stack.distribution.stack import (
|
||||||
construct_stack,
|
construct_stack,
|
||||||
|
@ -60,7 +64,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from .auth import AuthenticationMiddleware
|
from .auth import AuthenticationMiddleware
|
||||||
from .endpoints import get_all_api_endpoints
|
|
||||||
from .quota import QuotaMiddleware
|
from .quota import QuotaMiddleware
|
||||||
|
|
||||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
@ -209,8 +212,9 @@ async def log_request_pre_validation(request: Request):
|
||||||
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
|
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
|
||||||
|
|
||||||
|
|
||||||
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
||||||
async def endpoint(request: Request, **kwargs):
|
@functools.wraps(func)
|
||||||
|
async def route_handler(request: Request, **kwargs):
|
||||||
# Get auth attributes from the request scope
|
# Get auth attributes from the request scope
|
||||||
user_attributes = request.scope.get("user_attributes", {})
|
user_attributes = request.scope.get("user_attributes", {})
|
||||||
|
|
||||||
|
@ -250,9 +254,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||||
for param in new_params[1:]
|
for param in new_params[1:]
|
||||||
]
|
]
|
||||||
|
|
||||||
endpoint.__signature__ = sig.replace(parameters=new_params)
|
route_handler.__signature__ = sig.replace(parameters=new_params)
|
||||||
|
|
||||||
return endpoint
|
return route_handler
|
||||||
|
|
||||||
|
|
||||||
class TracingMiddleware:
|
class TracingMiddleware:
|
||||||
|
@ -274,14 +278,14 @@ class TracingMiddleware:
|
||||||
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
||||||
return await self.app(scope, receive, send)
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
if not hasattr(self, "endpoint_impls"):
|
if not hasattr(self, "route_impls"):
|
||||||
self.endpoint_impls = initialize_endpoint_impls(self.impls)
|
self.route_impls = initialize_route_impls(self.impls)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
|
_, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# If no matching endpoint is found, pass through to FastAPI
|
# If no matching endpoint is found, pass through to FastAPI
|
||||||
logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
|
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
||||||
return await self.app(scope, receive, send)
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
trace_attributes = {"__location__": "server", "raw_path": path}
|
trace_attributes = {"__location__": "server", "raw_path": path}
|
||||||
|
@ -423,7 +427,7 @@ def main(args: argparse.Namespace | None = None):
|
||||||
|
|
||||||
logger.info("Run configuration:")
|
logger.info("Run configuration:")
|
||||||
safe_config = redact_sensitive_fields(config.model_dump())
|
safe_config = redact_sensitive_fields(config.model_dump())
|
||||||
logger.info(yaml.dump(safe_config, indent=2))
|
logger.info(yaml.dump(safe_config, indent=2, default_style=None))
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
|
@ -490,7 +494,7 @@ def main(args: argparse.Namespace | None = None):
|
||||||
else:
|
else:
|
||||||
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
|
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
|
||||||
|
|
||||||
all_endpoints = get_all_api_endpoints()
|
all_routes = get_all_api_routes()
|
||||||
|
|
||||||
if config.apis:
|
if config.apis:
|
||||||
apis_to_serve = set(config.apis)
|
apis_to_serve = set(config.apis)
|
||||||
|
@ -508,24 +512,29 @@ def main(args: argparse.Namespace | None = None):
|
||||||
for api_str in apis_to_serve:
|
for api_str in apis_to_serve:
|
||||||
api = Api(api_str)
|
api = Api(api_str)
|
||||||
|
|
||||||
endpoints = all_endpoints[api]
|
routes = all_routes[api]
|
||||||
impl = impls[api]
|
impl = impls[api]
|
||||||
|
|
||||||
for endpoint in endpoints:
|
for route in routes:
|
||||||
if not hasattr(impl, endpoint.name):
|
if not hasattr(impl, route.name):
|
||||||
# ideally this should be a typing violation already
|
# ideally this should be a typing violation already
|
||||||
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
|
raise ValueError(f"Could not find method {route.name} on {impl}!")
|
||||||
|
|
||||||
impl_method = getattr(impl, endpoint.name)
|
impl_method = getattr(impl, route.name)
|
||||||
logger.debug(f"{endpoint.method.upper()} {endpoint.route}")
|
# Filter out HEAD method since it's automatically handled by FastAPI for GET routes
|
||||||
|
available_methods = [m for m in route.methods if m != "HEAD"]
|
||||||
|
if not available_methods:
|
||||||
|
raise ValueError(f"No methods found for {route.name} on {impl}")
|
||||||
|
method = available_methods[0]
|
||||||
|
logger.debug(f"{method} {route.path}")
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
|
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
|
||||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
getattr(app, method.lower())(route.path, response_model=None)(
|
||||||
create_dynamic_typed_route(
|
create_dynamic_typed_route(
|
||||||
impl_method,
|
impl_method,
|
||||||
endpoint.method,
|
method.lower(),
|
||||||
endpoint.route,
|
route.path,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):
|
||||||
|
|
||||||
|
|
||||||
REGISTER_PREFIX = "distributions:registry"
|
REGISTER_PREFIX = "distributions:registry"
|
||||||
KEY_VERSION = "v8"
|
KEY_VERSION = "v9"
|
||||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ from llama_stack.apis.datatypes import Api
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
from llama_stack.apis.scoring_functions import ScoringFn
|
from llama_stack.apis.scoring_functions import ScoringFn
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
from llama_stack.apis.tools import Tool
|
from llama_stack.apis.tools import ToolGroup
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
@ -74,10 +74,10 @@ class BenchmarksProtocolPrivate(Protocol):
|
||||||
async def register_benchmark(self, benchmark: Benchmark) -> None: ...
|
async def register_benchmark(self, benchmark: Benchmark) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class ToolsProtocolPrivate(Protocol):
|
class ToolGroupsProtocolPrivate(Protocol):
|
||||||
async def register_tool(self, tool: Tool) -> None: ...
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None: ...
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None: ...
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
@ -29,10 +30,12 @@ from llama_stack.apis.agents.openai_responses import (
|
||||||
OpenAIResponseObjectStream,
|
OpenAIResponseObjectStream,
|
||||||
OpenAIResponseObjectStreamResponseCompleted,
|
OpenAIResponseObjectStreamResponseCompleted,
|
||||||
OpenAIResponseObjectStreamResponseCreated,
|
OpenAIResponseObjectStreamResponseCreated,
|
||||||
|
OpenAIResponseObjectStreamResponseOutputTextDelta,
|
||||||
OpenAIResponseOutput,
|
OpenAIResponseOutput,
|
||||||
OpenAIResponseOutputMessageContent,
|
OpenAIResponseOutputMessageContent,
|
||||||
OpenAIResponseOutputMessageContentOutputText,
|
OpenAIResponseOutputMessageContentOutputText,
|
||||||
OpenAIResponseOutputMessageFunctionToolCall,
|
OpenAIResponseOutputMessageFunctionToolCall,
|
||||||
|
OpenAIResponseOutputMessageMCPListTools,
|
||||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference.inference import (
|
from llama_stack.apis.inference.inference import (
|
||||||
|
@ -255,110 +258,14 @@ class OpenAIResponsesImpl:
|
||||||
"""
|
"""
|
||||||
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
|
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
|
||||||
|
|
||||||
async def create_openai_response(
|
async def _process_response_choices(
|
||||||
self,
|
self,
|
||||||
input: str | list[OpenAIResponseInput],
|
chat_response: OpenAIChatCompletion,
|
||||||
model: str,
|
ctx: ChatCompletionContext,
|
||||||
instructions: str | None = None,
|
tools: list[OpenAIResponseInputTool] | None,
|
||||||
previous_response_id: str | None = None,
|
) -> list[OpenAIResponseOutput]:
|
||||||
store: bool | None = True,
|
"""Handle tool execution and response message creation."""
|
||||||
stream: bool | None = False,
|
|
||||||
temperature: float | None = None,
|
|
||||||
tools: list[OpenAIResponseInputTool] | None = None,
|
|
||||||
):
|
|
||||||
output_messages: list[OpenAIResponseOutput] = []
|
output_messages: list[OpenAIResponseOutput] = []
|
||||||
|
|
||||||
stream = False if stream is None else stream
|
|
||||||
|
|
||||||
# Huge TODO: we need to run this in a loop, until morale improves
|
|
||||||
|
|
||||||
# Create context to run "chat completion"
|
|
||||||
input = await self._prepend_previous_response(input, previous_response_id)
|
|
||||||
messages = await _convert_response_input_to_chat_messages(input)
|
|
||||||
await self._prepend_instructions(messages, instructions)
|
|
||||||
chat_tools, mcp_tool_to_server, mcp_list_message = (
|
|
||||||
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
|
|
||||||
)
|
|
||||||
if mcp_list_message:
|
|
||||||
output_messages.append(mcp_list_message)
|
|
||||||
|
|
||||||
ctx = ChatCompletionContext(
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
tools=chat_tools,
|
|
||||||
mcp_tool_to_server=mcp_tool_to_server,
|
|
||||||
stream=stream,
|
|
||||||
temperature=temperature,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run inference
|
|
||||||
chat_response = await self.inference_api.openai_chat_completion(
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
tools=chat_tools,
|
|
||||||
stream=stream,
|
|
||||||
temperature=temperature,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Collect output
|
|
||||||
if stream:
|
|
||||||
# TODO: refactor this into a separate method that handles streaming
|
|
||||||
chat_response_id = ""
|
|
||||||
chat_response_content = []
|
|
||||||
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
|
|
||||||
# TODO: these chunk_ fields are hacky and only take the last chunk into account
|
|
||||||
chunk_created = 0
|
|
||||||
chunk_model = ""
|
|
||||||
chunk_finish_reason = ""
|
|
||||||
async for chunk in chat_response:
|
|
||||||
chat_response_id = chunk.id
|
|
||||||
chunk_created = chunk.created
|
|
||||||
chunk_model = chunk.model
|
|
||||||
for chunk_choice in chunk.choices:
|
|
||||||
# TODO: this only works for text content
|
|
||||||
chat_response_content.append(chunk_choice.delta.content or "")
|
|
||||||
if chunk_choice.finish_reason:
|
|
||||||
chunk_finish_reason = chunk_choice.finish_reason
|
|
||||||
|
|
||||||
# Aggregate tool call arguments across chunks, using their index as the aggregation key
|
|
||||||
if chunk_choice.delta.tool_calls:
|
|
||||||
for tool_call in chunk_choice.delta.tool_calls:
|
|
||||||
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
|
|
||||||
if response_tool_call:
|
|
||||||
response_tool_call.function.arguments += tool_call.function.arguments
|
|
||||||
else:
|
|
||||||
tool_call_dict: dict[str, Any] = tool_call.model_dump()
|
|
||||||
# Ensure we don't have any empty type field in the tool call dict.
|
|
||||||
# The OpenAI client used by providers often returns a type=None here.
|
|
||||||
tool_call_dict.pop("type", None)
|
|
||||||
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
|
|
||||||
chat_response_tool_calls[tool_call.index] = response_tool_call
|
|
||||||
|
|
||||||
# Convert the dict of tool calls by index to a list of tool calls to pass back in our response
|
|
||||||
if chat_response_tool_calls:
|
|
||||||
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
|
|
||||||
else:
|
|
||||||
tool_calls = None
|
|
||||||
assistant_message = OpenAIAssistantMessageParam(
|
|
||||||
content="".join(chat_response_content),
|
|
||||||
tool_calls=tool_calls,
|
|
||||||
)
|
|
||||||
chat_response = OpenAIChatCompletion(
|
|
||||||
id=chat_response_id,
|
|
||||||
choices=[
|
|
||||||
OpenAIChoice(
|
|
||||||
message=assistant_message,
|
|
||||||
finish_reason=chunk_finish_reason,
|
|
||||||
index=0,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
created=chunk_created,
|
|
||||||
model=chunk_model,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# dump and reload to map to our pydantic types
|
|
||||||
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
|
|
||||||
|
|
||||||
# Execute tool calls if any
|
# Execute tool calls if any
|
||||||
for choice in chat_response.choices:
|
for choice in chat_response.choices:
|
||||||
if choice.message.tool_calls and tools:
|
if choice.message.tool_calls and tools:
|
||||||
|
@ -380,19 +287,13 @@ class OpenAIResponsesImpl:
|
||||||
else:
|
else:
|
||||||
output_messages.append(await _convert_chat_choice_to_response_message(choice))
|
output_messages.append(await _convert_chat_choice_to_response_message(choice))
|
||||||
|
|
||||||
# Create response object
|
return output_messages
|
||||||
response = OpenAIResponseObject(
|
|
||||||
created_at=chat_response.created,
|
|
||||||
id=f"resp-{uuid.uuid4()}",
|
|
||||||
model=model,
|
|
||||||
object="response",
|
|
||||||
status="completed",
|
|
||||||
output=output_messages,
|
|
||||||
)
|
|
||||||
logger.debug(f"OpenAI Responses response: {response}")
|
|
||||||
|
|
||||||
# Store response if requested
|
async def _store_response(
|
||||||
if store:
|
self,
|
||||||
|
response: OpenAIResponseObject,
|
||||||
|
input: str | list[OpenAIResponseInput],
|
||||||
|
) -> None:
|
||||||
new_input_id = f"msg_{uuid.uuid4()}"
|
new_input_id = f"msg_{uuid.uuid4()}"
|
||||||
if isinstance(input, str):
|
if isinstance(input, str):
|
||||||
# synthesize a message from the input string
|
# synthesize a message from the input string
|
||||||
|
@ -421,17 +322,233 @@ class OpenAIResponsesImpl:
|
||||||
input=input_items_data,
|
input=input_items_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def create_openai_response(
|
||||||
|
self,
|
||||||
|
input: str | list[OpenAIResponseInput],
|
||||||
|
model: str,
|
||||||
|
instructions: str | None = None,
|
||||||
|
previous_response_id: str | None = None,
|
||||||
|
store: bool | None = True,
|
||||||
|
stream: bool | None = False,
|
||||||
|
temperature: float | None = None,
|
||||||
|
tools: list[OpenAIResponseInputTool] | None = None,
|
||||||
|
):
|
||||||
|
stream = False if stream is None else stream
|
||||||
|
|
||||||
|
output_messages: list[OpenAIResponseOutput] = []
|
||||||
|
|
||||||
|
# Input preprocessing
|
||||||
|
input = await self._prepend_previous_response(input, previous_response_id)
|
||||||
|
messages = await _convert_response_input_to_chat_messages(input)
|
||||||
|
await self._prepend_instructions(messages, instructions)
|
||||||
|
|
||||||
|
# Tool setup
|
||||||
|
chat_tools, mcp_tool_to_server, mcp_list_message = (
|
||||||
|
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
|
||||||
|
)
|
||||||
|
if mcp_list_message:
|
||||||
|
output_messages.append(mcp_list_message)
|
||||||
|
|
||||||
|
ctx = ChatCompletionContext(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
tools=chat_tools,
|
||||||
|
mcp_tool_to_server=mcp_tool_to_server,
|
||||||
|
stream=stream,
|
||||||
|
temperature=temperature,
|
||||||
|
)
|
||||||
|
|
||||||
|
inference_result = await self.inference_api.openai_chat_completion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
tools=chat_tools,
|
||||||
|
stream=stream,
|
||||||
|
temperature=temperature,
|
||||||
|
)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
|
return self._create_streaming_response(
|
||||||
|
inference_result=inference_result,
|
||||||
|
ctx=ctx,
|
||||||
|
output_messages=output_messages,
|
||||||
|
input=input,
|
||||||
|
model=model,
|
||||||
|
store=store,
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return await self._create_non_streaming_response(
|
||||||
|
inference_result=inference_result,
|
||||||
|
ctx=ctx,
|
||||||
|
output_messages=output_messages,
|
||||||
|
input=input,
|
||||||
|
model=model,
|
||||||
|
store=store,
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
|
||||||
async def async_response() -> AsyncIterator[OpenAIResponseObjectStream]:
|
async def _create_non_streaming_response(
|
||||||
# TODO: response created should actually get emitted much earlier in the process
|
self,
|
||||||
yield OpenAIResponseObjectStreamResponseCreated(response=response)
|
inference_result: Any,
|
||||||
yield OpenAIResponseObjectStreamResponseCompleted(response=response)
|
ctx: ChatCompletionContext,
|
||||||
|
output_messages: list[OpenAIResponseOutput],
|
||||||
|
input: str | list[OpenAIResponseInput],
|
||||||
|
model: str,
|
||||||
|
store: bool | None,
|
||||||
|
tools: list[OpenAIResponseInputTool] | None,
|
||||||
|
) -> OpenAIResponseObject:
|
||||||
|
chat_response = OpenAIChatCompletion(**inference_result.model_dump())
|
||||||
|
|
||||||
return async_response()
|
# Process response choices (tool execution and message creation)
|
||||||
|
output_messages.extend(
|
||||||
|
await self._process_response_choices(
|
||||||
|
chat_response=chat_response,
|
||||||
|
ctx=ctx,
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = OpenAIResponseObject(
|
||||||
|
created_at=chat_response.created,
|
||||||
|
id=f"resp-{uuid.uuid4()}",
|
||||||
|
model=model,
|
||||||
|
object="response",
|
||||||
|
status="completed",
|
||||||
|
output=output_messages,
|
||||||
|
)
|
||||||
|
logger.debug(f"OpenAI Responses response: {response}")
|
||||||
|
|
||||||
|
# Store response if requested
|
||||||
|
if store:
|
||||||
|
await self._store_response(
|
||||||
|
response=response,
|
||||||
|
input=input,
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
async def _create_streaming_response(
|
||||||
|
self,
|
||||||
|
inference_result: Any,
|
||||||
|
ctx: ChatCompletionContext,
|
||||||
|
output_messages: list[OpenAIResponseOutput],
|
||||||
|
input: str | list[OpenAIResponseInput],
|
||||||
|
model: str,
|
||||||
|
store: bool | None,
|
||||||
|
tools: list[OpenAIResponseInputTool] | None,
|
||||||
|
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||||
|
# Create initial response and emit response.created immediately
|
||||||
|
response_id = f"resp-{uuid.uuid4()}"
|
||||||
|
created_at = int(time.time())
|
||||||
|
|
||||||
|
initial_response = OpenAIResponseObject(
|
||||||
|
created_at=created_at,
|
||||||
|
id=response_id,
|
||||||
|
model=model,
|
||||||
|
object="response",
|
||||||
|
status="in_progress",
|
||||||
|
output=output_messages.copy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Emit response.created immediately
|
||||||
|
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
|
||||||
|
|
||||||
|
# For streaming, inference_result is an async iterator of chunks
|
||||||
|
# Stream chunks and emit delta events as they arrive
|
||||||
|
chat_response_id = ""
|
||||||
|
chat_response_content = []
|
||||||
|
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
|
||||||
|
chunk_created = 0
|
||||||
|
chunk_model = ""
|
||||||
|
chunk_finish_reason = ""
|
||||||
|
sequence_number = 0
|
||||||
|
|
||||||
|
# Create a placeholder message item for delta events
|
||||||
|
message_item_id = f"msg_{uuid.uuid4()}"
|
||||||
|
|
||||||
|
async for chunk in inference_result:
|
||||||
|
chat_response_id = chunk.id
|
||||||
|
chunk_created = chunk.created
|
||||||
|
chunk_model = chunk.model
|
||||||
|
for chunk_choice in chunk.choices:
|
||||||
|
# Emit incremental text content as delta events
|
||||||
|
if chunk_choice.delta.content:
|
||||||
|
sequence_number += 1
|
||||||
|
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
|
||||||
|
content_index=0,
|
||||||
|
delta=chunk_choice.delta.content,
|
||||||
|
item_id=message_item_id,
|
||||||
|
output_index=0,
|
||||||
|
sequence_number=sequence_number,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect content for final response
|
||||||
|
chat_response_content.append(chunk_choice.delta.content or "")
|
||||||
|
if chunk_choice.finish_reason:
|
||||||
|
chunk_finish_reason = chunk_choice.finish_reason
|
||||||
|
|
||||||
|
# Aggregate tool call arguments across chunks, using their index as the aggregation key
|
||||||
|
if chunk_choice.delta.tool_calls:
|
||||||
|
for tool_call in chunk_choice.delta.tool_calls:
|
||||||
|
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
|
||||||
|
if response_tool_call:
|
||||||
|
response_tool_call.function.arguments += tool_call.function.arguments
|
||||||
|
else:
|
||||||
|
tool_call_dict: dict[str, Any] = tool_call.model_dump()
|
||||||
|
tool_call_dict.pop("type", None)
|
||||||
|
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
|
||||||
|
chat_response_tool_calls[tool_call.index] = response_tool_call
|
||||||
|
|
||||||
|
# Convert collected chunks to complete response
|
||||||
|
if chat_response_tool_calls:
|
||||||
|
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
|
||||||
|
else:
|
||||||
|
tool_calls = None
|
||||||
|
assistant_message = OpenAIAssistantMessageParam(
|
||||||
|
content="".join(chat_response_content),
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
)
|
||||||
|
chat_response_obj = OpenAIChatCompletion(
|
||||||
|
id=chat_response_id,
|
||||||
|
choices=[
|
||||||
|
OpenAIChoice(
|
||||||
|
message=assistant_message,
|
||||||
|
finish_reason=chunk_finish_reason,
|
||||||
|
index=0,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
created=chunk_created,
|
||||||
|
model=chunk_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process response choices (tool execution and message creation)
|
||||||
|
output_messages.extend(
|
||||||
|
await self._process_response_choices(
|
||||||
|
chat_response=chat_response_obj,
|
||||||
|
ctx=ctx,
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create final response
|
||||||
|
final_response = OpenAIResponseObject(
|
||||||
|
created_at=created_at,
|
||||||
|
id=response_id,
|
||||||
|
model=model,
|
||||||
|
object="response",
|
||||||
|
status="completed",
|
||||||
|
output=output_messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
if store:
|
||||||
|
await self._store_response(
|
||||||
|
response=final_response,
|
||||||
|
input=input,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Emit response.completed
|
||||||
|
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
|
||||||
|
|
||||||
async def _convert_response_tools_to_chat_tools(
|
async def _convert_response_tools_to_chat_tools(
|
||||||
self, tools: list[OpenAIResponseInputTool]
|
self, tools: list[OpenAIResponseInputTool]
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
|
@ -441,7 +558,6 @@ class OpenAIResponsesImpl:
|
||||||
]:
|
]:
|
||||||
from llama_stack.apis.agents.openai_responses import (
|
from llama_stack.apis.agents.openai_responses import (
|
||||||
MCPListToolsTool,
|
MCPListToolsTool,
|
||||||
OpenAIResponseOutputMessageMCPListTools,
|
|
||||||
)
|
)
|
||||||
from llama_stack.apis.tools.tools import Tool
|
from llama_stack.apis.tools.tools import Tool
|
||||||
|
|
||||||
|
|
|
@ -75,6 +75,8 @@ class PromptGuardShield:
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
|
|
||||||
|
self.device = "cpu"
|
||||||
|
if torch.cuda.is_available():
|
||||||
self.device = "cuda"
|
self.device = "cuda"
|
||||||
|
|
||||||
# load model and tokenizer
|
# load model and tokenizer
|
||||||
|
|
|
@ -25,14 +25,14 @@ from llama_stack.apis.tools import (
|
||||||
RAGQueryConfig,
|
RAGQueryConfig,
|
||||||
RAGQueryResult,
|
RAGQueryResult,
|
||||||
RAGToolRuntime,
|
RAGToolRuntime,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
content_from_doc,
|
content_from_doc,
|
||||||
|
@ -49,7 +49,7 @@ def make_random_string(length: int = 8):
|
||||||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
||||||
|
|
||||||
|
|
||||||
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: RagToolRuntimeConfig,
|
config: RagToolRuntimeConfig,
|
||||||
|
@ -66,10 +66,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||||
async def shutdown(self):
|
async def shutdown(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
async def insert(
|
async def insert(
|
||||||
|
|
|
@ -19,10 +19,10 @@ def available_providers() -> list[ProviderSpec]:
|
||||||
api=Api.agents,
|
api=Api.agents,
|
||||||
provider_type="inline::meta-reference",
|
provider_type="inline::meta-reference",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"matplotlib",
|
# "matplotlib",
|
||||||
"pillow",
|
# "pillow",
|
||||||
"pandas",
|
# "pandas",
|
||||||
"scikit-learn",
|
# "scikit-learn",
|
||||||
]
|
]
|
||||||
+ kvstore_dependencies(),
|
+ kvstore_dependencies(),
|
||||||
module="llama_stack.providers.inline.agents.meta_reference",
|
module="llama_stack.providers.inline.agents.meta_reference",
|
||||||
|
|
|
@ -13,7 +13,7 @@ def available_providers() -> list[ProviderSpec]:
|
||||||
InlineProviderSpec(
|
InlineProviderSpec(
|
||||||
api=Api.eval,
|
api=Api.eval,
|
||||||
provider_type="inline::meta-reference",
|
provider_type="inline::meta-reference",
|
||||||
pip_packages=["tree_sitter", "pythainlp", "langdetect", "emoji", "nltk"],
|
# pip_packages=["tree_sitter", "pythainlp", "langdetect", "emoji", "nltk"],
|
||||||
module="llama_stack.providers.inline.eval.meta_reference",
|
module="llama_stack.providers.inline.eval.meta_reference",
|
||||||
config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig",
|
config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig",
|
||||||
api_dependencies=[
|
api_dependencies=[
|
||||||
|
|
|
@ -20,16 +20,16 @@ def available_providers() -> list[ProviderSpec]:
|
||||||
api=Api.tool_runtime,
|
api=Api.tool_runtime,
|
||||||
provider_type="inline::rag-runtime",
|
provider_type="inline::rag-runtime",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"blobfile",
|
# "blobfile",
|
||||||
"chardet",
|
# "chardet",
|
||||||
"pypdf",
|
# "pypdf",
|
||||||
"tqdm",
|
# "tqdm",
|
||||||
"numpy",
|
# "numpy",
|
||||||
"scikit-learn",
|
# "scikit-learn",
|
||||||
"scipy",
|
# "scipy",
|
||||||
"nltk",
|
# "nltk",
|
||||||
"sentencepiece",
|
# "sentencepiece",
|
||||||
"transformers",
|
# "transformers",
|
||||||
],
|
],
|
||||||
module="llama_stack.providers.inline.tool_runtime.rag",
|
module="llama_stack.providers.inline.tool_runtime.rag",
|
||||||
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
|
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
|
||||||
|
|
|
@ -4,8 +4,9 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
@ -24,11 +25,27 @@ class VLLMInferenceAdapterConfig(BaseModel):
|
||||||
default="fake",
|
default="fake",
|
||||||
description="The API token",
|
description="The API token",
|
||||||
)
|
)
|
||||||
tls_verify: bool = Field(
|
tls_verify: bool | str = Field(
|
||||||
default=True,
|
default=True,
|
||||||
description="Whether to verify TLS certificates",
|
description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@field_validator("tls_verify")
|
||||||
|
@classmethod
|
||||||
|
def validate_tls_verify(cls, v):
|
||||||
|
if isinstance(v, str):
|
||||||
|
# Check if it's a boolean string
|
||||||
|
if v.lower() in ("true", "false"):
|
||||||
|
return v.lower() == "true"
|
||||||
|
# Otherwise, treat it as a cert path
|
||||||
|
cert_path = Path(v).expanduser().resolve()
|
||||||
|
if not cert_path.exists():
|
||||||
|
raise ValueError(f"TLS certificate file does not exist: {v}")
|
||||||
|
if not cert_path.is_file():
|
||||||
|
raise ValueError(f"TLS certificate path is not a file: {v}")
|
||||||
|
return v
|
||||||
|
return v
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(
|
def sample_run_config(
|
||||||
cls,
|
cls,
|
||||||
|
|
|
@ -313,7 +313,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
return AsyncOpenAI(
|
return AsyncOpenAI(
|
||||||
base_url=self.config.url,
|
base_url=self.config.url,
|
||||||
api_key=self.config.api_token,
|
api_key=self.config.api_token,
|
||||||
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
|
http_client=httpx.AsyncClient(verify=self.config.tls_verify),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
|
|
|
@ -12,19 +12,19 @@ import httpx
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
|
|
||||||
from .config import BingSearchToolConfig
|
from .config import BingSearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: BingSearchToolConfig):
|
def __init__(self, config: BingSearchToolConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.url = "https://api.bing.microsoft.com/v7.0/search"
|
self.url = "https://api.bing.microsoft.com/v7.0/search"
|
||||||
|
@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
|
|
|
@ -11,30 +11,30 @@ import httpx
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
|
|
||||||
from .config import BraveSearchToolConfig
|
from .config import BraveSearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: BraveSearchToolConfig):
|
def __init__(self, config: BraveSearchToolConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
|
|
|
@ -10,8 +10,8 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class MCPProviderDataValidator(BaseModel):
|
class MCPProviderDataValidator(BaseModel):
|
||||||
# mcp_endpoint => list of headers to send
|
# mcp_endpoint => dict of headers to send
|
||||||
mcp_headers: dict[str, list[str]] | None = None
|
mcp_headers: dict[str, dict[str, str]] | None = None
|
||||||
|
|
||||||
|
|
||||||
class MCPProviderConfig(BaseModel):
|
class MCPProviderConfig(BaseModel):
|
||||||
|
|
|
@ -11,26 +11,33 @@ from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.datatypes import Api
|
from llama_stack.apis.datatypes import Api
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
|
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools
|
||||||
|
|
||||||
from .config import MCPProviderConfig
|
from .config import MCPProviderConfig
|
||||||
|
|
||||||
logger = get_logger(__name__, category="tools")
|
logger = get_logger(__name__, category="tools")
|
||||||
|
|
||||||
|
|
||||||
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
|
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||||
) -> ListToolDefsResponse:
|
) -> ListToolDefsResponse:
|
||||||
|
@ -62,5 +69,5 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
|
||||||
for uri, values in provider_data.mcp_headers.items():
|
for uri, values in provider_data.mcp_headers.items():
|
||||||
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
|
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
|
||||||
continue
|
continue
|
||||||
headers.update(convert_header_list_to_dict(values))
|
headers.update(values)
|
||||||
return headers
|
return headers
|
||||||
|
|
|
@ -12,29 +12,29 @@ import httpx
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
|
|
||||||
from .config import TavilySearchToolConfig
|
from .config import TavilySearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: TavilySearchToolConfig):
|
def __init__(self, config: TavilySearchToolConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
|
|
|
@ -12,19 +12,19 @@ import httpx
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
Tool,
|
|
||||||
ToolDef,
|
ToolDef,
|
||||||
|
ToolGroup,
|
||||||
ToolInvocationResult,
|
ToolInvocationResult,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
|
|
||||||
from .config import WolframAlphaToolConfig
|
from .config import WolframAlphaToolConfig
|
||||||
|
|
||||||
|
|
||||||
class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
|
||||||
def __init__(self, config: WolframAlphaToolConfig):
|
def __init__(self, config: WolframAlphaToolConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.url = "https://api.wolframalpha.com/v2/query"
|
self.url = "https://api.wolframalpha.com/v2/query"
|
||||||
|
@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_tool(self, tool: Tool) -> None:
|
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def unregister_tool(self, tool_id: str) -> None:
|
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _get_api_key(self) -> str:
|
def _get_api_key(self) -> str:
|
||||||
|
|
|
@ -1402,9 +1402,8 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
||||||
outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
|
outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
|
||||||
):
|
):
|
||||||
id = f"chatcmpl-{uuid.uuid4()}"
|
id = f"chatcmpl-{uuid.uuid4()}"
|
||||||
for outstanding_response in outstanding_responses:
|
for i, outstanding_response in enumerate(outstanding_responses):
|
||||||
response = await outstanding_response
|
response = await outstanding_response
|
||||||
i = 0
|
|
||||||
async for chunk in response:
|
async for chunk in response:
|
||||||
event = chunk.event
|
event = chunk.event
|
||||||
finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
|
finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
|
||||||
|
@ -1459,7 +1458,6 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
||||||
model=model,
|
model=model,
|
||||||
object="chat.completion.chunk",
|
object="chat.completion.chunk",
|
||||||
)
|
)
|
||||||
i = i + 1
|
|
||||||
|
|
||||||
async def _process_non_stream_response(
|
async def _process_non_stream_response(
|
||||||
self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
|
self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
|
||||||
|
|
|
@ -51,16 +51,6 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def convert_header_list_to_dict(header_list: list[str]) -> dict[str, str]:
|
|
||||||
headers = {}
|
|
||||||
for header in header_list:
|
|
||||||
parts = header.split(":")
|
|
||||||
if len(parts) == 2:
|
|
||||||
k, v = parts
|
|
||||||
headers[k.strip()] = v.strip()
|
|
||||||
return headers
|
|
||||||
|
|
||||||
|
|
||||||
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
|
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
|
||||||
tools = []
|
tools = []
|
||||||
async with sse_client_wrapper(endpoint, headers) as session:
|
async with sse_client_wrapper(endpoint, headers) as session:
|
||||||
|
|
|
@ -1,855 +0,0 @@
|
||||||
{
|
|
||||||
"bedrock": [
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"boto3",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"faiss-cpu",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"langdetect",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn"
|
|
||||||
],
|
|
||||||
"cerebras": [
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"cerebras_cloud_sdk",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"faiss-cpu",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"langdetect",
|
|
||||||
"matplotlib",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn",
|
|
||||||
"sentence-transformers --no-deps",
|
|
||||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
|
||||||
],
|
|
||||||
"ci-tests": [
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"fireworks-ai",
|
|
||||||
"httpx",
|
|
||||||
"langdetect",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"sqlite-vec",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn",
|
|
||||||
"sentence-transformers --no-deps",
|
|
||||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
|
||||||
],
|
|
||||||
"dell": [
|
|
||||||
"aiohttp",
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"faiss-cpu",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"huggingface_hub",
|
|
||||||
"langdetect",
|
|
||||||
"matplotlib",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn",
|
|
||||||
"sentence-transformers --no-deps",
|
|
||||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
|
||||||
],
|
|
||||||
"fireworks": [
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"faiss-cpu",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"fireworks-ai",
|
|
||||||
"httpx",
|
|
||||||
"langdetect",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn",
|
|
||||||
"sentence-transformers --no-deps",
|
|
||||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
|
||||||
],
|
|
||||||
"groq": [
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"faiss-cpu",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"langdetect",
|
|
||||||
"litellm",
|
|
||||||
"matplotlib",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn"
|
|
||||||
],
|
|
||||||
"hf-endpoint": [
|
|
||||||
"aiohttp",
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"faiss-cpu",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"huggingface_hub",
|
|
||||||
"langdetect",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn"
|
|
||||||
],
|
|
||||||
"hf-serverless": [
|
|
||||||
"aiohttp",
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"faiss-cpu",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"huggingface_hub",
|
|
||||||
"langdetect",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn",
|
|
||||||
"sentence-transformers --no-deps",
|
|
||||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
|
||||||
],
|
|
||||||
"llama_api": [
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"langdetect",
|
|
||||||
"litellm",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"sqlite-vec",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn",
|
|
||||||
"sentence-transformers --no-deps",
|
|
||||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
|
||||||
],
|
|
||||||
"meta-reference-gpu": [
|
|
||||||
"accelerate",
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"fairscale",
|
|
||||||
"faiss-cpu",
|
|
||||||
"fastapi",
|
|
||||||
"fbgemm-gpu-genai==1.1.2",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"langdetect",
|
|
||||||
"lm-format-enforcer",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentence-transformers",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"torch",
|
|
||||||
"torchao==0.8.0",
|
|
||||||
"torchvision",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn",
|
|
||||||
"zmq"
|
|
||||||
],
|
|
||||||
"nvidia": [
|
|
||||||
"aiohttp",
|
|
||||||
"aiosqlite",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"datasets",
|
|
||||||
"faiss-cpu",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"matplotlib",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"uvicorn"
|
|
||||||
],
|
|
||||||
"ollama": [
|
|
||||||
"aiohttp",
|
|
||||||
"aiosqlite",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"emoji",
|
|
||||||
"faiss-cpu",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"langdetect",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"ollama",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"peft",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"requests",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"tqdm",
|
|
||||||
"tree_sitter",
|
|
||||||
"trl",
|
|
||||||
"uvicorn"
|
|
||||||
],
|
|
||||||
"open-benchmark": [
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"langdetect",
|
|
||||||
"litellm",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"sqlite-vec",
|
|
||||||
"together",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn"
|
|
||||||
],
|
|
||||||
"passthrough": [
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"faiss-cpu",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"langdetect",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn",
|
|
||||||
"sentence-transformers --no-deps",
|
|
||||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
|
||||||
],
|
|
||||||
"remote-vllm": [
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"faiss-cpu",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"langdetect",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn",
|
|
||||||
"sentence-transformers --no-deps",
|
|
||||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
|
||||||
],
|
|
||||||
"sambanova": [
|
|
||||||
"aiosqlite",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"faiss-cpu",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"litellm",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"uvicorn",
|
|
||||||
"sentence-transformers --no-deps",
|
|
||||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
|
||||||
],
|
|
||||||
"starter": [
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"fireworks-ai",
|
|
||||||
"httpx",
|
|
||||||
"langdetect",
|
|
||||||
"litellm",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"sqlite-vec",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn",
|
|
||||||
"sentence-transformers --no-deps",
|
|
||||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
|
||||||
],
|
|
||||||
"tgi": [
|
|
||||||
"aiohttp",
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"faiss-cpu",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"huggingface_hub",
|
|
||||||
"langdetect",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn",
|
|
||||||
"sentence-transformers --no-deps",
|
|
||||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
|
||||||
],
|
|
||||||
"together": [
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"faiss-cpu",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"langdetect",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"together",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn",
|
|
||||||
"sentence-transformers --no-deps",
|
|
||||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
|
||||||
],
|
|
||||||
"verification": [
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"langdetect",
|
|
||||||
"litellm",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"sqlite-vec",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn",
|
|
||||||
"sentence-transformers --no-deps",
|
|
||||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
|
||||||
],
|
|
||||||
"vllm-gpu": [
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"chromadb-client",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"faiss-cpu",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"langdetect",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn",
|
|
||||||
"vllm",
|
|
||||||
"sentence-transformers --no-deps",
|
|
||||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
|
||||||
],
|
|
||||||
"watsonx": [
|
|
||||||
"aiosqlite",
|
|
||||||
"autoevals",
|
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
|
||||||
"datasets",
|
|
||||||
"emoji",
|
|
||||||
"faiss-cpu",
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"ibm_watson_machine_learning",
|
|
||||||
"langdetect",
|
|
||||||
"matplotlib",
|
|
||||||
"mcp",
|
|
||||||
"nltk",
|
|
||||||
"numpy",
|
|
||||||
"openai",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"pandas",
|
|
||||||
"pillow",
|
|
||||||
"psycopg2-binary",
|
|
||||||
"pymongo",
|
|
||||||
"pypdf",
|
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
|
||||||
"requests",
|
|
||||||
"scikit-learn",
|
|
||||||
"scipy",
|
|
||||||
"sentencepiece",
|
|
||||||
"sqlalchemy[asyncio]",
|
|
||||||
"tqdm",
|
|
||||||
"transformers",
|
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn",
|
|
||||||
"sentence-transformers --no-deps",
|
|
||||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
|
||||||
]
|
|
||||||
}
|
|
|
@ -25,23 +25,7 @@ distribution_spec:
|
||||||
- inline::rag-runtime
|
- inline::rag-runtime
|
||||||
- remote::model-context-protocol
|
- remote::model-context-protocol
|
||||||
- remote::wolfram-alpha
|
- remote::wolfram-alpha
|
||||||
metadata_store:
|
image_type: conda
|
||||||
type: sqlite
|
|
||||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
|
|
||||||
inference_store:
|
|
||||||
type: sqlite
|
|
||||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/inference_store.db
|
|
||||||
models:
|
|
||||||
- metadata: {}
|
|
||||||
model_id: ${env.INFERENCE_MODEL}
|
|
||||||
provider_id: remote::ollama
|
|
||||||
model_type: llm
|
|
||||||
- metadata:
|
|
||||||
embedding_dimension: 384
|
|
||||||
model_id: all-MiniLM-L6-v2
|
|
||||||
provider_id: remote::ollama
|
|
||||||
provider_model_id: all-minilm:latest
|
|
||||||
model_type: embedding
|
|
||||||
image_type: container
|
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
- blobfile
|
||||||
|
|
|
@ -13,8 +13,8 @@ from llama_stack.distribution.datatypes import (
|
||||||
ShieldInput,
|
ShieldInput,
|
||||||
ToolGroupInput,
|
ToolGroupInput,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
|
#from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
|
||||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
#from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||||
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
||||||
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
|
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
|
||||||
|
|
||||||
|
@ -32,7 +32,6 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
"tool_runtime": [
|
"tool_runtime": [
|
||||||
"remote::brave-search",
|
"remote::brave-search",
|
||||||
"remote::tavily-search",
|
"remote::tavily-search",
|
||||||
"inline::rag-runtime",
|
|
||||||
"remote::model-context-protocol",
|
"remote::model-context-protocol",
|
||||||
"remote::wolfram-alpha",
|
"remote::wolfram-alpha",
|
||||||
],
|
],
|
||||||
|
@ -43,11 +42,11 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
provider_type="remote::ollama",
|
provider_type="remote::ollama",
|
||||||
config=OllamaImplConfig.sample_run_config(),
|
config=OllamaImplConfig.sample_run_config(),
|
||||||
)
|
)
|
||||||
vector_io_provider_faiss = Provider(
|
#vector_io_provider_faiss = Provider(
|
||||||
provider_id="faiss",
|
# provider_id="faiss",
|
||||||
provider_type="inline::faiss",
|
# provider_type="inline::faiss",
|
||||||
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
# config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||||
)
|
#)
|
||||||
inference_model = ModelInput(
|
inference_model = ModelInput(
|
||||||
model_id="${env.INFERENCE_MODEL}",
|
model_id="${env.INFERENCE_MODEL}",
|
||||||
provider_id="ollama",
|
provider_id="ollama",
|
||||||
|
@ -70,10 +69,6 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
toolgroup_id="builtin::websearch",
|
toolgroup_id="builtin::websearch",
|
||||||
provider_id="tavily-search",
|
provider_id="tavily-search",
|
||||||
),
|
),
|
||||||
ToolGroupInput(
|
|
||||||
toolgroup_id="builtin::rag",
|
|
||||||
provider_id="rag-runtime",
|
|
||||||
),
|
|
||||||
ToolGroupInput(
|
ToolGroupInput(
|
||||||
toolgroup_id="builtin::wolfram_alpha",
|
toolgroup_id="builtin::wolfram_alpha",
|
||||||
provider_id="wolfram-alpha",
|
provider_id="wolfram-alpha",
|
||||||
|
|
|
@ -24,6 +24,10 @@ providers:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
namespace: null
|
namespace: null
|
||||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
|
||||||
|
- provider_id: chromadb
|
||||||
|
provider_type: remote::chromadb
|
||||||
|
config:
|
||||||
|
url: ${env.CHROMADB_URL:http://host.docker.internal:8000}
|
||||||
safety:
|
safety:
|
||||||
- provider_id: llama-guard
|
- provider_id: llama-guard
|
||||||
provider_type: inline::llama-guard
|
provider_type: inline::llama-guard
|
||||||
|
|
|
@ -2,9 +2,9 @@
|
||||||
|
|
||||||
import { useEffect, useState } from "react";
|
import { useEffect, useState } from "react";
|
||||||
import { useParams } from "next/navigation";
|
import { useParams } from "next/navigation";
|
||||||
import LlamaStackClient from "llama-stack-client";
|
|
||||||
import { ChatCompletion } from "@/lib/types";
|
import { ChatCompletion } from "@/lib/types";
|
||||||
import { ChatCompletionDetailView } from "@/components/chat-completions/chat-completion-detail";
|
import { ChatCompletionDetailView } from "@/components/chat-completions/chat-completion-detail";
|
||||||
|
import { client } from "@/lib/client";
|
||||||
|
|
||||||
export default function ChatCompletionDetailPage() {
|
export default function ChatCompletionDetailPage() {
|
||||||
const params = useParams();
|
const params = useParams();
|
||||||
|
@ -22,10 +22,6 @@ export default function ChatCompletionDetailPage() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const client = new LlamaStackClient({
|
|
||||||
baseURL: process.env.NEXT_PUBLIC_LLAMA_STACK_BASE_URL,
|
|
||||||
});
|
|
||||||
|
|
||||||
const fetchCompletionDetail = async () => {
|
const fetchCompletionDetail = async () => {
|
||||||
setIsLoading(true);
|
setIsLoading(true);
|
||||||
setError(null);
|
setError(null);
|
||||||
|
|
|
@ -1,45 +1,19 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import React from "react";
|
import React from "react";
|
||||||
import { usePathname, useParams } from "next/navigation";
|
import LogsLayout from "@/components/layout/logs-layout";
|
||||||
import {
|
|
||||||
PageBreadcrumb,
|
|
||||||
BreadcrumbSegment,
|
|
||||||
} from "@/components/layout/page-breadcrumb";
|
|
||||||
import { truncateText } from "@/lib/truncate-text";
|
|
||||||
|
|
||||||
export default function ChatCompletionsLayout({
|
export default function ChatCompletionsLayout({
|
||||||
children,
|
children,
|
||||||
}: {
|
}: {
|
||||||
children: React.ReactNode;
|
children: React.ReactNode;
|
||||||
}) {
|
}) {
|
||||||
const pathname = usePathname();
|
|
||||||
const params = useParams();
|
|
||||||
|
|
||||||
let segments: BreadcrumbSegment[] = [];
|
|
||||||
|
|
||||||
// Default for /logs/chat-completions
|
|
||||||
if (pathname === "/logs/chat-completions") {
|
|
||||||
segments = [{ label: "Chat Completions" }];
|
|
||||||
}
|
|
||||||
|
|
||||||
// For /logs/chat-completions/[id]
|
|
||||||
const idParam = params?.id;
|
|
||||||
if (idParam && typeof idParam === "string") {
|
|
||||||
segments = [
|
|
||||||
{ label: "Chat Completions", href: "/logs/chat-completions" },
|
|
||||||
{ label: `Details (${truncateText(idParam, 20)})` },
|
|
||||||
];
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="container mx-auto p-4">
|
<LogsLayout
|
||||||
<>
|
sectionLabel="Chat Completions"
|
||||||
{segments.length > 0 && (
|
basePath="/logs/chat-completions"
|
||||||
<PageBreadcrumb segments={segments} className="mb-4" />
|
>
|
||||||
)}
|
|
||||||
{children}
|
{children}
|
||||||
</>
|
</LogsLayout>
|
||||||
</div>
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { useEffect, useState } from "react";
|
import { useEffect, useState } from "react";
|
||||||
import LlamaStackClient from "llama-stack-client";
|
|
||||||
import { ChatCompletion } from "@/lib/types";
|
import { ChatCompletion } from "@/lib/types";
|
||||||
import { ChatCompletionsTable } from "@/components/chat-completions/chat-completion-table";
|
import { ChatCompletionsTable } from "@/components/chat-completions/chat-completions-table";
|
||||||
|
import { client } from "@/lib/client";
|
||||||
|
|
||||||
export default function ChatCompletionsPage() {
|
export default function ChatCompletionsPage() {
|
||||||
const [completions, setCompletions] = useState<ChatCompletion[]>([]);
|
const [completions, setCompletions] = useState<ChatCompletion[]>([]);
|
||||||
|
@ -11,9 +11,6 @@ export default function ChatCompletionsPage() {
|
||||||
const [error, setError] = useState<Error | null>(null);
|
const [error, setError] = useState<Error | null>(null);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const client = new LlamaStackClient({
|
|
||||||
baseURL: process.env.NEXT_PUBLIC_LLAMA_STACK_BASE_URL,
|
|
||||||
});
|
|
||||||
const fetchCompletions = async () => {
|
const fetchCompletions = async () => {
|
||||||
setIsLoading(true);
|
setIsLoading(true);
|
||||||
setError(null);
|
setError(null);
|
||||||
|
@ -21,7 +18,7 @@ export default function ChatCompletionsPage() {
|
||||||
const response = await client.chat.completions.list();
|
const response = await client.chat.completions.list();
|
||||||
const data = Array.isArray(response)
|
const data = Array.isArray(response)
|
||||||
? response
|
? response
|
||||||
: (response as any).data;
|
: (response as { data: ChatCompletion[] }).data;
|
||||||
|
|
||||||
if (Array.isArray(data)) {
|
if (Array.isArray(data)) {
|
||||||
setCompletions(data);
|
setCompletions(data);
|
||||||
|
@ -46,7 +43,7 @@ export default function ChatCompletionsPage() {
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ChatCompletionsTable
|
<ChatCompletionsTable
|
||||||
completions={completions}
|
data={completions}
|
||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
error={error}
|
error={error}
|
||||||
/>
|
/>
|
||||||
|
|
125
llama_stack/ui/app/logs/responses/[id]/page.tsx
Normal file
125
llama_stack/ui/app/logs/responses/[id]/page.tsx
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
|
import { useParams } from "next/navigation";
|
||||||
|
import type { ResponseObject } from "llama-stack-client/resources/responses/responses";
|
||||||
|
import { OpenAIResponse, InputItemListResponse } from "@/lib/types";
|
||||||
|
import { ResponseDetailView } from "@/components/responses/responses-detail";
|
||||||
|
import { client } from "@/lib/client";
|
||||||
|
|
||||||
|
export default function ResponseDetailPage() {
|
||||||
|
const params = useParams();
|
||||||
|
const id = params.id as string;
|
||||||
|
|
||||||
|
const [responseDetail, setResponseDetail] = useState<OpenAIResponse | null>(
|
||||||
|
null,
|
||||||
|
);
|
||||||
|
const [inputItems, setInputItems] = useState<InputItemListResponse | null>(
|
||||||
|
null,
|
||||||
|
);
|
||||||
|
const [isLoading, setIsLoading] = useState<boolean>(true);
|
||||||
|
const [isLoadingInputItems, setIsLoadingInputItems] = useState<boolean>(true);
|
||||||
|
const [error, setError] = useState<Error | null>(null);
|
||||||
|
const [inputItemsError, setInputItemsError] = useState<Error | null>(null);
|
||||||
|
|
||||||
|
// Helper function to convert ResponseObject to OpenAIResponse
|
||||||
|
const convertResponseObject = (
|
||||||
|
responseData: ResponseObject,
|
||||||
|
): OpenAIResponse => {
|
||||||
|
return {
|
||||||
|
id: responseData.id,
|
||||||
|
created_at: responseData.created_at,
|
||||||
|
model: responseData.model,
|
||||||
|
object: responseData.object,
|
||||||
|
status: responseData.status,
|
||||||
|
output: responseData.output as OpenAIResponse["output"],
|
||||||
|
input: [], // ResponseObject doesn't include input; component uses inputItems prop instead
|
||||||
|
error: responseData.error,
|
||||||
|
parallel_tool_calls: responseData.parallel_tool_calls,
|
||||||
|
previous_response_id: responseData.previous_response_id,
|
||||||
|
temperature: responseData.temperature,
|
||||||
|
top_p: responseData.top_p,
|
||||||
|
truncation: responseData.truncation,
|
||||||
|
user: responseData.user,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!id) {
|
||||||
|
setError(new Error("Response ID is missing."));
|
||||||
|
setIsLoading(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const fetchResponseDetail = async () => {
|
||||||
|
setIsLoading(true);
|
||||||
|
setIsLoadingInputItems(true);
|
||||||
|
setError(null);
|
||||||
|
setInputItemsError(null);
|
||||||
|
setResponseDetail(null);
|
||||||
|
setInputItems(null);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const [responseResult, inputItemsResult] = await Promise.allSettled([
|
||||||
|
client.responses.retrieve(id),
|
||||||
|
client.responses.inputItems.list(id, { order: "asc" }),
|
||||||
|
]);
|
||||||
|
|
||||||
|
// Handle response detail result
|
||||||
|
if (responseResult.status === "fulfilled") {
|
||||||
|
const convertedResponse = convertResponseObject(responseResult.value);
|
||||||
|
setResponseDetail(convertedResponse);
|
||||||
|
} else {
|
||||||
|
console.error(
|
||||||
|
`Error fetching response detail for ID ${id}:`,
|
||||||
|
responseResult.reason,
|
||||||
|
);
|
||||||
|
setError(
|
||||||
|
responseResult.reason instanceof Error
|
||||||
|
? responseResult.reason
|
||||||
|
: new Error("Failed to fetch response detail"),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle input items result
|
||||||
|
if (inputItemsResult.status === "fulfilled") {
|
||||||
|
const inputItemsData =
|
||||||
|
inputItemsResult.value as unknown as InputItemListResponse;
|
||||||
|
setInputItems(inputItemsData);
|
||||||
|
} else {
|
||||||
|
console.error(
|
||||||
|
`Error fetching input items for response ID ${id}:`,
|
||||||
|
inputItemsResult.reason,
|
||||||
|
);
|
||||||
|
setInputItemsError(
|
||||||
|
inputItemsResult.reason instanceof Error
|
||||||
|
? inputItemsResult.reason
|
||||||
|
: new Error("Failed to fetch input items"),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
console.error(`Unexpected error fetching data for ID ${id}:`, err);
|
||||||
|
setError(
|
||||||
|
err instanceof Error ? err : new Error("Unexpected error occurred"),
|
||||||
|
);
|
||||||
|
} finally {
|
||||||
|
setIsLoading(false);
|
||||||
|
setIsLoadingInputItems(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
fetchResponseDetail();
|
||||||
|
}, [id]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ResponseDetailView
|
||||||
|
response={responseDetail}
|
||||||
|
inputItems={inputItems}
|
||||||
|
isLoading={isLoading}
|
||||||
|
isLoadingInputItems={isLoadingInputItems}
|
||||||
|
error={error}
|
||||||
|
inputItemsError={inputItemsError}
|
||||||
|
id={id}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
16
llama_stack/ui/app/logs/responses/layout.tsx
Normal file
16
llama_stack/ui/app/logs/responses/layout.tsx
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import React from "react";
|
||||||
|
import LogsLayout from "@/components/layout/logs-layout";
|
||||||
|
|
||||||
|
export default function ResponsesLayout({
|
||||||
|
children,
|
||||||
|
}: {
|
||||||
|
children: React.ReactNode;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<LogsLayout sectionLabel="Responses" basePath="/logs/responses">
|
||||||
|
{children}
|
||||||
|
</LogsLayout>
|
||||||
|
);
|
||||||
|
}
|
|
@ -1,7 +1,66 @@
|
||||||
export default function Responses() {
|
"use client";
|
||||||
|
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
|
import type { ResponseListResponse } from "llama-stack-client/resources/responses/responses";
|
||||||
|
import { OpenAIResponse } from "@/lib/types";
|
||||||
|
import { ResponsesTable } from "@/components/responses/responses-table";
|
||||||
|
import { client } from "@/lib/client";
|
||||||
|
|
||||||
|
export default function ResponsesPage() {
|
||||||
|
const [responses, setResponses] = useState<OpenAIResponse[]>([]);
|
||||||
|
const [isLoading, setIsLoading] = useState<boolean>(true);
|
||||||
|
const [error, setError] = useState<Error | null>(null);
|
||||||
|
|
||||||
|
// Helper function to convert ResponseListResponse.Data to OpenAIResponse
|
||||||
|
const convertResponseListData = (
|
||||||
|
responseData: ResponseListResponse.Data,
|
||||||
|
): OpenAIResponse => {
|
||||||
|
return {
|
||||||
|
id: responseData.id,
|
||||||
|
created_at: responseData.created_at,
|
||||||
|
model: responseData.model,
|
||||||
|
object: responseData.object,
|
||||||
|
status: responseData.status,
|
||||||
|
output: responseData.output as OpenAIResponse["output"],
|
||||||
|
input: responseData.input as OpenAIResponse["input"],
|
||||||
|
error: responseData.error,
|
||||||
|
parallel_tool_calls: responseData.parallel_tool_calls,
|
||||||
|
previous_response_id: responseData.previous_response_id,
|
||||||
|
temperature: responseData.temperature,
|
||||||
|
top_p: responseData.top_p,
|
||||||
|
truncation: responseData.truncation,
|
||||||
|
user: responseData.user,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const fetchResponses = async () => {
|
||||||
|
setIsLoading(true);
|
||||||
|
setError(null);
|
||||||
|
try {
|
||||||
|
const response = await client.responses.list();
|
||||||
|
const responseListData = response as ResponseListResponse;
|
||||||
|
|
||||||
|
const convertedResponses: OpenAIResponse[] = responseListData.data.map(
|
||||||
|
convertResponseListData,
|
||||||
|
);
|
||||||
|
|
||||||
|
setResponses(convertedResponses);
|
||||||
|
} catch (err) {
|
||||||
|
console.error("Error fetching responses:", err);
|
||||||
|
setError(
|
||||||
|
err instanceof Error ? err : new Error("Failed to fetch responses"),
|
||||||
|
);
|
||||||
|
setResponses([]);
|
||||||
|
} finally {
|
||||||
|
setIsLoading(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
fetchResponses();
|
||||||
|
}, []);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div>
|
<ResponsesTable data={responses} isLoading={isLoading} error={error} />
|
||||||
<h1>Under Construction</h1>
|
|
||||||
</div>
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -75,7 +75,7 @@ describe("ChatCompletionDetailView", () => {
|
||||||
/>,
|
/>,
|
||||||
);
|
);
|
||||||
expect(
|
expect(
|
||||||
screen.getByText("No details found for completion ID: notfound-id."),
|
screen.getByText("No details found for ID: notfound-id."),
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -3,45 +3,14 @@
|
||||||
import { ChatMessage, ChatCompletion } from "@/lib/types";
|
import { ChatMessage, ChatCompletion } from "@/lib/types";
|
||||||
import { ChatMessageItem } from "@/components/chat-completions/chat-messasge-item";
|
import { ChatMessageItem } from "@/components/chat-completions/chat-messasge-item";
|
||||||
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||||
import { Skeleton } from "@/components/ui/skeleton";
|
import {
|
||||||
|
DetailLoadingView,
|
||||||
function ChatCompletionDetailLoadingView() {
|
DetailErrorView,
|
||||||
return (
|
DetailNotFoundView,
|
||||||
<>
|
DetailLayout,
|
||||||
<Skeleton className="h-8 w-3/4 mb-6" /> {/* Title Skeleton */}
|
PropertiesCard,
|
||||||
<div className="flex flex-col md:flex-row gap-6">
|
PropertyItem,
|
||||||
<div className="flex-grow md:w-2/3 space-y-6">
|
} from "@/components/layout/detail-layout";
|
||||||
{[...Array(2)].map((_, i) => (
|
|
||||||
<Card key={`main-skeleton-card-${i}`}>
|
|
||||||
<CardHeader>
|
|
||||||
<CardTitle>
|
|
||||||
<Skeleton className="h-6 w-1/2" />
|
|
||||||
</CardTitle>
|
|
||||||
</CardHeader>
|
|
||||||
<CardContent className="space-y-2">
|
|
||||||
<Skeleton className="h-4 w-full" />
|
|
||||||
<Skeleton className="h-4 w-full" />
|
|
||||||
<Skeleton className="h-4 w-3/4" />
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
<div className="md:w-1/3">
|
|
||||||
<div className="p-4 border rounded-lg shadow-sm bg-white space-y-3">
|
|
||||||
<Skeleton className="h-6 w-1/3 mb-3" />{" "}
|
|
||||||
{/* Properties Title Skeleton */}
|
|
||||||
{[...Array(5)].map((_, i) => (
|
|
||||||
<div key={`prop-skeleton-${i}`} className="space-y-1">
|
|
||||||
<Skeleton className="h-4 w-1/4" />
|
|
||||||
<Skeleton className="h-4 w-1/2" />
|
|
||||||
</div>
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
interface ChatCompletionDetailViewProps {
|
interface ChatCompletionDetailViewProps {
|
||||||
completion: ChatCompletion | null;
|
completion: ChatCompletion | null;
|
||||||
|
@ -56,39 +25,23 @@ export function ChatCompletionDetailView({
|
||||||
error,
|
error,
|
||||||
id,
|
id,
|
||||||
}: ChatCompletionDetailViewProps) {
|
}: ChatCompletionDetailViewProps) {
|
||||||
|
const title = "Chat Completion Details";
|
||||||
|
|
||||||
if (error) {
|
if (error) {
|
||||||
return (
|
return <DetailErrorView title={title} id={id} error={error} />;
|
||||||
<>
|
|
||||||
{/* We still want a title for consistency on error pages */}
|
|
||||||
<h1 className="text-2xl font-bold mb-6">Chat Completion Details</h1>
|
|
||||||
<p>
|
|
||||||
Error loading details for ID {id}: {error.message}
|
|
||||||
</p>
|
|
||||||
</>
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isLoading) {
|
if (isLoading) {
|
||||||
return <ChatCompletionDetailLoadingView />;
|
return <DetailLoadingView title={title} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!completion) {
|
if (!completion) {
|
||||||
// This state means: not loading, no error, but no completion data
|
return <DetailNotFoundView title={title} id={id} />;
|
||||||
return (
|
|
||||||
<>
|
|
||||||
{/* We still want a title for consistency on not-found pages */}
|
|
||||||
<h1 className="text-2xl font-bold mb-6">Chat Completion Details</h1>
|
|
||||||
<p>No details found for completion ID: {id}.</p>
|
|
||||||
</>
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no error, not loading, and completion exists, render the details:
|
// Main content cards
|
||||||
return (
|
const mainContent = (
|
||||||
<>
|
<>
|
||||||
<h1 className="text-2xl font-bold mb-6">Chat Completion Details</h1>
|
|
||||||
<div className="flex flex-col md:flex-row gap-6">
|
|
||||||
<div className="flex-grow md:w-2/3 space-y-6">
|
|
||||||
<Card>
|
<Card>
|
||||||
<CardHeader>
|
<CardHeader>
|
||||||
<CardTitle>Input</CardTitle>
|
<CardTitle>Input</CardTitle>
|
||||||
|
@ -98,13 +51,15 @@ export function ChatCompletionDetailView({
|
||||||
<ChatMessageItem key={`input-msg-${index}`} message={msg} />
|
<ChatMessageItem key={`input-msg-${index}`} message={msg} />
|
||||||
))}
|
))}
|
||||||
{completion.choices?.[0]?.message?.tool_calls &&
|
{completion.choices?.[0]?.message?.tool_calls &&
|
||||||
|
Array.isArray(completion.choices[0].message.tool_calls) &&
|
||||||
!completion.input_messages?.some(
|
!completion.input_messages?.some(
|
||||||
(im) =>
|
(im) =>
|
||||||
im.role === "assistant" &&
|
im.role === "assistant" &&
|
||||||
im.tool_calls &&
|
im.tool_calls &&
|
||||||
|
Array.isArray(im.tool_calls) &&
|
||||||
im.tool_calls.length > 0,
|
im.tool_calls.length > 0,
|
||||||
) &&
|
)
|
||||||
completion.choices[0].message.tool_calls.map(
|
? completion.choices[0].message.tool_calls.map(
|
||||||
(toolCall: any, index: number) => {
|
(toolCall: any, index: number) => {
|
||||||
const assistantToolCallMessage: ChatMessage = {
|
const assistantToolCallMessage: ChatMessage = {
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
|
@ -118,7 +73,8 @@ export function ChatCompletionDetailView({
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
)}
|
)
|
||||||
|
: null}
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
|
|
||||||
|
@ -138,61 +94,52 @@ export function ChatCompletionDetailView({
|
||||||
)}
|
)}
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
</div>
|
</>
|
||||||
|
);
|
||||||
|
|
||||||
<div className="md:w-1/3">
|
// Properties sidebar
|
||||||
<Card>
|
const sidebar = (
|
||||||
<CardHeader>
|
<PropertiesCard>
|
||||||
<CardTitle>Properties</CardTitle>
|
<PropertyItem
|
||||||
</CardHeader>
|
label="Created"
|
||||||
<CardContent>
|
value={new Date(completion.created * 1000).toLocaleString()}
|
||||||
<ul className="space-y-2 text-sm text-gray-600">
|
/>
|
||||||
<li>
|
<PropertyItem label="ID" value={completion.id} />
|
||||||
<strong>Created:</strong>{" "}
|
<PropertyItem label="Model" value={completion.model} />
|
||||||
<span className="text-gray-900 font-medium">
|
<PropertyItem
|
||||||
{new Date(completion.created * 1000).toLocaleString()}
|
label="Finish Reason"
|
||||||
</span>
|
value={completion.choices?.[0]?.finish_reason || "N/A"}
|
||||||
</li>
|
hasBorder
|
||||||
<li>
|
/>
|
||||||
<strong>ID:</strong>{" "}
|
{(() => {
|
||||||
<span className="text-gray-900 font-medium">
|
const toolCalls = completion.choices?.[0]?.message?.tool_calls;
|
||||||
{completion.id}
|
if (toolCalls && Array.isArray(toolCalls) && toolCalls.length > 0) {
|
||||||
</span>
|
return (
|
||||||
</li>
|
<PropertyItem
|
||||||
<li>
|
label="Functions/Tools Called"
|
||||||
<strong>Model:</strong>{" "}
|
value={
|
||||||
<span className="text-gray-900 font-medium">
|
<div>
|
||||||
{completion.model}
|
|
||||||
</span>
|
|
||||||
</li>
|
|
||||||
<li className="pt-1 mt-1 border-t border-gray-200">
|
|
||||||
<strong>Finish Reason:</strong>{" "}
|
|
||||||
<span className="text-gray-900 font-medium">
|
|
||||||
{completion.choices?.[0]?.finish_reason || "N/A"}
|
|
||||||
</span>
|
|
||||||
</li>
|
|
||||||
{completion.choices?.[0]?.message?.tool_calls &&
|
|
||||||
completion.choices[0].message.tool_calls.length > 0 && (
|
|
||||||
<li className="pt-1 mt-1 border-t border-gray-200">
|
|
||||||
<strong>Functions/Tools Called:</strong>
|
|
||||||
<ul className="list-disc list-inside pl-4 mt-1">
|
<ul className="list-disc list-inside pl-4 mt-1">
|
||||||
{completion.choices[0].message.tool_calls.map(
|
{toolCalls.map((toolCall: any, index: number) => (
|
||||||
(toolCall: any, index: number) => (
|
|
||||||
<li key={index}>
|
<li key={index}>
|
||||||
<span className="text-gray-900 font-medium">
|
<span className="text-gray-900 font-medium">
|
||||||
{toolCall.function?.name || "N/A"}
|
{toolCall.function?.name || "N/A"}
|
||||||
</span>
|
</span>
|
||||||
</li>
|
</li>
|
||||||
),
|
))}
|
||||||
)}
|
|
||||||
</ul>
|
</ul>
|
||||||
</li>
|
|
||||||
)}
|
|
||||||
</ul>
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
}
|
||||||
</>
|
hasBorder
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
})()}
|
||||||
|
</PropertiesCard>
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<DetailLayout title={title} mainContent={mainContent} sidebar={sidebar} />
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
import React from "react";
|
import React from "react";
|
||||||
import { render, screen, fireEvent } from "@testing-library/react";
|
import { render, screen, fireEvent } from "@testing-library/react";
|
||||||
import "@testing-library/jest-dom";
|
import "@testing-library/jest-dom";
|
||||||
import { ChatCompletionsTable } from "./chat-completion-table";
|
import { ChatCompletionsTable } from "./chat-completions-table";
|
||||||
import { ChatCompletion } from "@/lib/types"; // Assuming this path is correct
|
import { ChatCompletion } from "@/lib/types";
|
||||||
|
|
||||||
// Mock next/navigation
|
// Mock next/navigation
|
||||||
const mockPush = jest.fn();
|
const mockPush = jest.fn();
|
||||||
|
@ -13,21 +13,25 @@ jest.mock("next/navigation", () => ({
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Mock helper functions
|
// Mock helper functions
|
||||||
// These are hoisted, so their mocks are available throughout the file
|
|
||||||
jest.mock("@/lib/truncate-text");
|
jest.mock("@/lib/truncate-text");
|
||||||
jest.mock("@/lib/format-tool-call");
|
jest.mock("@/lib/format-message-content");
|
||||||
|
|
||||||
// Import the mocked functions to set up default or specific implementations
|
// Import the mocked functions to set up default or specific implementations
|
||||||
import { truncateText as originalTruncateText } from "@/lib/truncate-text";
|
import { truncateText as originalTruncateText } from "@/lib/truncate-text";
|
||||||
import { formatToolCallToString as originalFormatToolCallToString } from "@/lib/format-tool-call";
|
import {
|
||||||
|
extractTextFromContentPart as originalExtractTextFromContentPart,
|
||||||
|
extractDisplayableText as originalExtractDisplayableText,
|
||||||
|
} from "@/lib/format-message-content";
|
||||||
|
|
||||||
// Cast to jest.Mock for typings
|
// Cast to jest.Mock for typings
|
||||||
const truncateText = originalTruncateText as jest.Mock;
|
const truncateText = originalTruncateText as jest.Mock;
|
||||||
const formatToolCallToString = originalFormatToolCallToString as jest.Mock;
|
const extractTextFromContentPart =
|
||||||
|
originalExtractTextFromContentPart as jest.Mock;
|
||||||
|
const extractDisplayableText = originalExtractDisplayableText as jest.Mock;
|
||||||
|
|
||||||
describe("ChatCompletionsTable", () => {
|
describe("ChatCompletionsTable", () => {
|
||||||
const defaultProps = {
|
const defaultProps = {
|
||||||
completions: [] as ChatCompletion[],
|
data: [] as ChatCompletion[],
|
||||||
isLoading: false,
|
isLoading: false,
|
||||||
error: null,
|
error: null,
|
||||||
};
|
};
|
||||||
|
@ -36,28 +40,26 @@ describe("ChatCompletionsTable", () => {
|
||||||
// Reset all mocks before each test
|
// Reset all mocks before each test
|
||||||
mockPush.mockClear();
|
mockPush.mockClear();
|
||||||
truncateText.mockClear();
|
truncateText.mockClear();
|
||||||
formatToolCallToString.mockClear();
|
extractTextFromContentPart.mockClear();
|
||||||
|
extractDisplayableText.mockClear();
|
||||||
|
|
||||||
// Default pass-through implementation for tests not focusing on truncation/formatting
|
// Default pass-through implementations
|
||||||
truncateText.mockImplementation((text: string | undefined) => text);
|
truncateText.mockImplementation((text: string | undefined) => text);
|
||||||
formatToolCallToString.mockImplementation((toolCall: any) =>
|
extractTextFromContentPart.mockImplementation((content: unknown) =>
|
||||||
toolCall && typeof toolCall === "object" && toolCall.name
|
typeof content === "string" ? content : "extracted text",
|
||||||
? `[DefaultToolCall:${toolCall.name}]`
|
);
|
||||||
: "[InvalidToolCall]",
|
extractDisplayableText.mockImplementation(
|
||||||
|
(message: unknown) =>
|
||||||
|
(message as { content?: string })?.content || "extracted output",
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
test("renders without crashing with default props", () => {
|
test("renders without crashing with default props", () => {
|
||||||
render(<ChatCompletionsTable {...defaultProps} />);
|
render(<ChatCompletionsTable {...defaultProps} />);
|
||||||
// Check for a unique element that should be present in the non-empty, non-loading, non-error state
|
|
||||||
// For now, as per Task 1, we will test the empty state message
|
|
||||||
expect(screen.getByText("No chat completions found.")).toBeInTheDocument();
|
expect(screen.getByText("No chat completions found.")).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
|
|
||||||
test("click on a row navigates to the correct URL", () => {
|
test("click on a row navigates to the correct URL", () => {
|
||||||
const { rerender } = render(<ChatCompletionsTable {...defaultProps} />);
|
|
||||||
|
|
||||||
// Simulate a scenario where a completion exists and is clicked
|
|
||||||
const mockCompletion: ChatCompletion = {
|
const mockCompletion: ChatCompletion = {
|
||||||
id: "comp_123",
|
id: "comp_123",
|
||||||
object: "chat.completion",
|
object: "chat.completion",
|
||||||
|
@ -73,9 +75,12 @@ describe("ChatCompletionsTable", () => {
|
||||||
input_messages: [{ role: "user", content: "Test input" }],
|
input_messages: [{ role: "user", content: "Test input" }],
|
||||||
};
|
};
|
||||||
|
|
||||||
rerender(
|
// Set up mocks to return expected values
|
||||||
<ChatCompletionsTable {...defaultProps} completions={[mockCompletion]} />,
|
extractTextFromContentPart.mockReturnValue("Test input");
|
||||||
);
|
extractDisplayableText.mockReturnValue("Test output");
|
||||||
|
|
||||||
|
render(<ChatCompletionsTable {...defaultProps} data={[mockCompletion]} />);
|
||||||
|
|
||||||
const row = screen.getByText("Test input").closest("tr");
|
const row = screen.getByText("Test input").closest("tr");
|
||||||
if (row) {
|
if (row) {
|
||||||
fireEvent.click(row);
|
fireEvent.click(row);
|
||||||
|
@ -91,14 +96,13 @@ describe("ChatCompletionsTable", () => {
|
||||||
<ChatCompletionsTable {...defaultProps} isLoading={true} />,
|
<ChatCompletionsTable {...defaultProps} isLoading={true} />,
|
||||||
);
|
);
|
||||||
|
|
||||||
// The Skeleton component uses data-slot="skeleton"
|
|
||||||
const skeletonSelector = '[data-slot="skeleton"]';
|
|
||||||
|
|
||||||
// Check for skeleton in the table caption
|
// Check for skeleton in the table caption
|
||||||
const tableCaption = container.querySelector("caption");
|
const tableCaption = container.querySelector("caption");
|
||||||
expect(tableCaption).toBeInTheDocument();
|
expect(tableCaption).toBeInTheDocument();
|
||||||
if (tableCaption) {
|
if (tableCaption) {
|
||||||
const captionSkeleton = tableCaption.querySelector(skeletonSelector);
|
const captionSkeleton = tableCaption.querySelector(
|
||||||
|
'[data-slot="skeleton"]',
|
||||||
|
);
|
||||||
expect(captionSkeleton).toBeInTheDocument();
|
expect(captionSkeleton).toBeInTheDocument();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -107,16 +111,10 @@ describe("ChatCompletionsTable", () => {
|
||||||
expect(tableBody).toBeInTheDocument();
|
expect(tableBody).toBeInTheDocument();
|
||||||
if (tableBody) {
|
if (tableBody) {
|
||||||
const bodySkeletons = tableBody.querySelectorAll(
|
const bodySkeletons = tableBody.querySelectorAll(
|
||||||
`td ${skeletonSelector}`,
|
'[data-slot="skeleton"]',
|
||||||
);
|
);
|
||||||
expect(bodySkeletons.length).toBeGreaterThan(0); // Ensure at least one skeleton cell exists
|
expect(bodySkeletons.length).toBeGreaterThan(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// General check: ensure multiple skeleton elements are present in the table overall
|
|
||||||
const allSkeletonsInTable = container.querySelectorAll(
|
|
||||||
`table ${skeletonSelector}`,
|
|
||||||
);
|
|
||||||
expect(allSkeletonsInTable.length).toBeGreaterThan(3); // e.g., caption + at least one row of 3 cells, or just a few
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -140,14 +138,14 @@ describe("ChatCompletionsTable", () => {
|
||||||
{...defaultProps}
|
{...defaultProps}
|
||||||
error={{ name: "Error", message: "" }}
|
error={{ name: "Error", message: "" }}
|
||||||
/>,
|
/>,
|
||||||
); // Error with empty message
|
);
|
||||||
expect(
|
expect(
|
||||||
screen.getByText("Error fetching data: An unknown error occurred"),
|
screen.getByText("Error fetching data: An unknown error occurred"),
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
|
|
||||||
test("renders default error message when error prop is an object without message", () => {
|
test("renders default error message when error prop is an object without message", () => {
|
||||||
render(<ChatCompletionsTable {...defaultProps} error={{} as Error} />); // Empty error object
|
render(<ChatCompletionsTable {...defaultProps} error={{} as Error} />);
|
||||||
expect(
|
expect(
|
||||||
screen.getByText("Error fetching data: An unknown error occurred"),
|
screen.getByText("Error fetching data: An unknown error occurred"),
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
|
@ -155,14 +153,8 @@ describe("ChatCompletionsTable", () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
describe("Empty State", () => {
|
describe("Empty State", () => {
|
||||||
test('renders "No chat completions found." and no table when completions array is empty', () => {
|
test('renders "No chat completions found." and no table when data array is empty', () => {
|
||||||
render(
|
render(<ChatCompletionsTable data={[]} isLoading={false} error={null} />);
|
||||||
<ChatCompletionsTable
|
|
||||||
completions={[]}
|
|
||||||
isLoading={false}
|
|
||||||
error={null}
|
|
||||||
/>,
|
|
||||||
);
|
|
||||||
expect(
|
expect(
|
||||||
screen.getByText("No chat completions found."),
|
screen.getByText("No chat completions found."),
|
||||||
).toBeInTheDocument();
|
).toBeInTheDocument();
|
||||||
|
@ -179,7 +171,7 @@ describe("ChatCompletionsTable", () => {
|
||||||
{
|
{
|
||||||
id: "comp_1",
|
id: "comp_1",
|
||||||
object: "chat.completion",
|
object: "chat.completion",
|
||||||
created: 1710000000, // Fixed timestamp for test
|
created: 1710000000,
|
||||||
model: "llama-test-model",
|
model: "llama-test-model",
|
||||||
choices: [
|
choices: [
|
||||||
{
|
{
|
||||||
|
@ -206,9 +198,22 @@ describe("ChatCompletionsTable", () => {
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
|
// Set up mocks to return expected values
|
||||||
|
extractTextFromContentPart.mockImplementation((content: unknown) => {
|
||||||
|
if (content === "Test input") return "Test input";
|
||||||
|
if (content === "Another input") return "Another input";
|
||||||
|
return "extracted text";
|
||||||
|
});
|
||||||
|
extractDisplayableText.mockImplementation((message: unknown) => {
|
||||||
|
const msg = message as { content?: string };
|
||||||
|
if (msg?.content === "Test output") return "Test output";
|
||||||
|
if (msg?.content === "Another output") return "Another output";
|
||||||
|
return "extracted output";
|
||||||
|
});
|
||||||
|
|
||||||
render(
|
render(
|
||||||
<ChatCompletionsTable
|
<ChatCompletionsTable
|
||||||
completions={mockCompletions}
|
data={mockCompletions}
|
||||||
isLoading={false}
|
isLoading={false}
|
||||||
error={null}
|
error={null}
|
||||||
/>,
|
/>,
|
||||||
|
@ -242,7 +247,7 @@ describe("ChatCompletionsTable", () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe("Text Truncation and Tool Call Formatting", () => {
|
describe("Text Truncation and Content Extraction", () => {
|
||||||
test("truncates long input and output text", () => {
|
test("truncates long input and output text", () => {
|
||||||
// Specific mock implementation for this test
|
// Specific mock implementation for this test
|
||||||
truncateText.mockImplementation(
|
truncateText.mockImplementation(
|
||||||
|
@ -259,6 +264,10 @@ describe("ChatCompletionsTable", () => {
|
||||||
"This is a very long input message that should be truncated.";
|
"This is a very long input message that should be truncated.";
|
||||||
const longOutput =
|
const longOutput =
|
||||||
"This is a very long output message that should also be truncated.";
|
"This is a very long output message that should also be truncated.";
|
||||||
|
|
||||||
|
extractTextFromContentPart.mockReturnValue(longInput);
|
||||||
|
extractDisplayableText.mockReturnValue(longOutput);
|
||||||
|
|
||||||
const mockCompletions = [
|
const mockCompletions = [
|
||||||
{
|
{
|
||||||
id: "comp_trunc",
|
id: "comp_trunc",
|
||||||
|
@ -278,7 +287,7 @@ describe("ChatCompletionsTable", () => {
|
||||||
|
|
||||||
render(
|
render(
|
||||||
<ChatCompletionsTable
|
<ChatCompletionsTable
|
||||||
completions={mockCompletions}
|
data={mockCompletions}
|
||||||
isLoading={false}
|
isLoading={false}
|
||||||
error={null}
|
error={null}
|
||||||
/>,
|
/>,
|
||||||
|
@ -289,52 +298,50 @@ describe("ChatCompletionsTable", () => {
|
||||||
longInput.slice(0, 10) + "...",
|
longInput.slice(0, 10) + "...",
|
||||||
);
|
);
|
||||||
expect(truncatedTexts.length).toBe(2); // one for input, one for output
|
expect(truncatedTexts.length).toBe(2); // one for input, one for output
|
||||||
// Optionally, verify each one is in the document if getAllByText doesn't throw on not found
|
|
||||||
truncatedTexts.forEach((textElement) =>
|
truncatedTexts.forEach((textElement) =>
|
||||||
expect(textElement).toBeInTheDocument(),
|
expect(textElement).toBeInTheDocument(),
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
test("formats tool call output using formatToolCallToString", () => {
|
test("uses content extraction functions correctly", () => {
|
||||||
// Specific mock implementation for this test
|
const mockCompletion = {
|
||||||
formatToolCallToString.mockImplementation(
|
id: "comp_extract",
|
||||||
(toolCall: any) => `[TOOL:${toolCall.name}]`,
|
|
||||||
);
|
|
||||||
// Ensure no truncation interferes for this specific test for clarity of tool call format
|
|
||||||
truncateText.mockImplementation((text: string | undefined) => text);
|
|
||||||
|
|
||||||
const toolCall = { name: "search", args: { query: "llama" } };
|
|
||||||
const mockCompletions = [
|
|
||||||
{
|
|
||||||
id: "comp_tool",
|
|
||||||
object: "chat.completion",
|
object: "chat.completion",
|
||||||
created: 1710003000,
|
created: 1710003000,
|
||||||
model: "llama-tool-model",
|
model: "llama-extract-model",
|
||||||
choices: [
|
choices: [
|
||||||
{
|
{
|
||||||
index: 0,
|
index: 0,
|
||||||
message: {
|
message: { role: "assistant", content: "Extracted output" },
|
||||||
role: "assistant",
|
|
||||||
content: "Tool output", // Content that will be prepended
|
|
||||||
tool_calls: [toolCall],
|
|
||||||
},
|
|
||||||
finish_reason: "stop",
|
finish_reason: "stop",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
input_messages: [{ role: "user", content: "Tool input" }],
|
input_messages: [{ role: "user", content: "Extracted input" }],
|
||||||
},
|
};
|
||||||
];
|
|
||||||
|
extractTextFromContentPart.mockReturnValue("Extracted input");
|
||||||
|
extractDisplayableText.mockReturnValue("Extracted output");
|
||||||
|
|
||||||
render(
|
render(
|
||||||
<ChatCompletionsTable
|
<ChatCompletionsTable
|
||||||
completions={mockCompletions}
|
data={[mockCompletion]}
|
||||||
isLoading={false}
|
isLoading={false}
|
||||||
error={null}
|
error={null}
|
||||||
/>,
|
/>,
|
||||||
);
|
);
|
||||||
|
|
||||||
// The component concatenates message.content and the formatted tool call
|
// Verify the extraction functions were called
|
||||||
expect(screen.getByText("Tool output [TOOL:search]")).toBeInTheDocument();
|
expect(extractTextFromContentPart).toHaveBeenCalledWith(
|
||||||
|
"Extracted input",
|
||||||
|
);
|
||||||
|
expect(extractDisplayableText).toHaveBeenCalledWith({
|
||||||
|
role: "assistant",
|
||||||
|
content: "Extracted output",
|
||||||
|
});
|
||||||
|
|
||||||
|
// Verify the extracted content is displayed
|
||||||
|
expect(screen.getByText("Extracted input")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Extracted output")).toBeInTheDocument();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { ChatCompletion } from "@/lib/types";
|
||||||
|
import { LogsTable, LogTableRow } from "@/components/logs/logs-table";
|
||||||
|
import {
|
||||||
|
extractTextFromContentPart,
|
||||||
|
extractDisplayableText,
|
||||||
|
} from "@/lib/format-message-content";
|
||||||
|
|
||||||
|
interface ChatCompletionsTableProps {
|
||||||
|
data: ChatCompletion[];
|
||||||
|
isLoading: boolean;
|
||||||
|
error: Error | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatChatCompletionToRow(completion: ChatCompletion): LogTableRow {
|
||||||
|
return {
|
||||||
|
id: completion.id,
|
||||||
|
input: extractTextFromContentPart(completion.input_messages?.[0]?.content),
|
||||||
|
output: extractDisplayableText(completion.choices?.[0]?.message),
|
||||||
|
model: completion.model,
|
||||||
|
createdTime: new Date(completion.created * 1000).toLocaleString(),
|
||||||
|
detailPath: `/logs/chat-completions/${completion.id}`,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ChatCompletionsTable({
|
||||||
|
data,
|
||||||
|
isLoading,
|
||||||
|
error,
|
||||||
|
}: ChatCompletionsTableProps) {
|
||||||
|
const formattedData = data.map(formatChatCompletionToRow);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<LogsTable
|
||||||
|
data={formattedData}
|
||||||
|
isLoading={isLoading}
|
||||||
|
error={error}
|
||||||
|
caption="A list of your recent chat completions."
|
||||||
|
emptyMessage="No chat completions found."
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
|
@ -4,45 +4,10 @@ import { ChatMessage } from "@/lib/types";
|
||||||
import React from "react";
|
import React from "react";
|
||||||
import { formatToolCallToString } from "@/lib/format-tool-call";
|
import { formatToolCallToString } from "@/lib/format-tool-call";
|
||||||
import { extractTextFromContentPart } from "@/lib/format-message-content";
|
import { extractTextFromContentPart } from "@/lib/format-message-content";
|
||||||
|
import {
|
||||||
// Sub-component or helper for the common label + content structure
|
MessageBlock,
|
||||||
const MessageBlock: React.FC<{
|
ToolCallBlock,
|
||||||
label: string;
|
} from "@/components/ui/message-components";
|
||||||
labelDetail?: string;
|
|
||||||
content: React.ReactNode;
|
|
||||||
}> = ({ label, labelDetail, content }) => {
|
|
||||||
return (
|
|
||||||
<div>
|
|
||||||
<p className="py-1 font-semibold text-gray-800 mb-1">
|
|
||||||
{label}
|
|
||||||
{labelDetail && (
|
|
||||||
<span className="text-xs text-gray-500 font-normal ml-1">
|
|
||||||
{labelDetail}
|
|
||||||
</span>
|
|
||||||
)}
|
|
||||||
</p>
|
|
||||||
<div className="py-1">{content}</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
interface ToolCallBlockProps {
|
|
||||||
children: React.ReactNode;
|
|
||||||
className?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
const ToolCallBlock = ({ children, className }: ToolCallBlockProps) => {
|
|
||||||
// Common styling for both function call arguments and tool output blocks
|
|
||||||
// Let's use slate-50 background as it's good for code-like content.
|
|
||||||
const baseClassName =
|
|
||||||
"p-3 bg-slate-50 border border-slate-200 rounded-md text-sm";
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className={`${baseClassName} ${className || ""}`}>
|
|
||||||
<pre className="whitespace-pre-wrap text-xs">{children}</pre>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
interface ChatMessageItemProps {
|
interface ChatMessageItemProps {
|
||||||
message: ChatMessage;
|
message: ChatMessage;
|
||||||
|
@ -65,7 +30,11 @@ export function ChatMessageItem({ message }: ChatMessageItemProps) {
|
||||||
);
|
);
|
||||||
|
|
||||||
case "assistant":
|
case "assistant":
|
||||||
if (message.tool_calls && message.tool_calls.length > 0) {
|
if (
|
||||||
|
message.tool_calls &&
|
||||||
|
Array.isArray(message.tool_calls) &&
|
||||||
|
message.tool_calls.length > 0
|
||||||
|
) {
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
{message.tool_calls.map((toolCall: any, index: number) => {
|
{message.tool_calls.map((toolCall: any, index: number) => {
|
||||||
|
|
141
llama_stack/ui/components/layout/detail-layout.tsx
Normal file
141
llama_stack/ui/components/layout/detail-layout.tsx
Normal file
|
@ -0,0 +1,141 @@
|
||||||
|
import React from "react";
|
||||||
|
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||||
|
import { Skeleton } from "@/components/ui/skeleton";
|
||||||
|
|
||||||
|
export function DetailLoadingView({ title }: { title: string }) {
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Skeleton className="h-8 w-3/4 mb-6" /> {/* Title Skeleton */}
|
||||||
|
<div className="flex flex-col md:flex-row gap-6">
|
||||||
|
<div className="flex-grow md:w-2/3 space-y-6">
|
||||||
|
{[...Array(2)].map((_, i) => (
|
||||||
|
<Card key={`main-skeleton-card-${i}`}>
|
||||||
|
<CardHeader>
|
||||||
|
<CardTitle>
|
||||||
|
<Skeleton className="h-6 w-1/2" />
|
||||||
|
</CardTitle>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent className="space-y-2">
|
||||||
|
<Skeleton className="h-4 w-full" />
|
||||||
|
<Skeleton className="h-4 w-full" />
|
||||||
|
<Skeleton className="h-4 w-3/4" />
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
<div className="md:w-1/3">
|
||||||
|
<div className="p-4 border rounded-lg shadow-sm bg-white space-y-3">
|
||||||
|
<Skeleton className="h-6 w-1/3 mb-3" />{" "}
|
||||||
|
{/* Properties Title Skeleton */}
|
||||||
|
{[...Array(5)].map((_, i) => (
|
||||||
|
<div key={`prop-skeleton-${i}`} className="space-y-1">
|
||||||
|
<Skeleton className="h-4 w-1/4" />
|
||||||
|
<Skeleton className="h-4 w-1/2" />
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function DetailErrorView({
|
||||||
|
title,
|
||||||
|
id,
|
||||||
|
error,
|
||||||
|
}: {
|
||||||
|
title: string;
|
||||||
|
id: string;
|
||||||
|
error: Error;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<h1 className="text-2xl font-bold mb-6">{title}</h1>
|
||||||
|
<p>
|
||||||
|
Error loading details for ID {id}: {error.message}
|
||||||
|
</p>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function DetailNotFoundView({
|
||||||
|
title,
|
||||||
|
id,
|
||||||
|
}: {
|
||||||
|
title: string;
|
||||||
|
id: string;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<h1 className="text-2xl font-bold mb-6">{title}</h1>
|
||||||
|
<p>No details found for ID: {id}.</p>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PropertyItemProps {
|
||||||
|
label: string;
|
||||||
|
value: React.ReactNode;
|
||||||
|
className?: string;
|
||||||
|
hasBorder?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function PropertyItem({
|
||||||
|
label,
|
||||||
|
value,
|
||||||
|
className = "",
|
||||||
|
hasBorder = false,
|
||||||
|
}: PropertyItemProps) {
|
||||||
|
return (
|
||||||
|
<li
|
||||||
|
className={`${hasBorder ? "pt-1 mt-1 border-t border-gray-200" : ""} ${className}`}
|
||||||
|
>
|
||||||
|
<strong>{label}:</strong>{" "}
|
||||||
|
{typeof value === "string" || typeof value === "number" ? (
|
||||||
|
<span className="text-gray-900 font-medium">{value}</span>
|
||||||
|
) : (
|
||||||
|
value
|
||||||
|
)}
|
||||||
|
</li>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PropertiesCardProps {
|
||||||
|
children: React.ReactNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function PropertiesCard({ children }: PropertiesCardProps) {
|
||||||
|
return (
|
||||||
|
<Card>
|
||||||
|
<CardHeader>
|
||||||
|
<CardTitle>Properties</CardTitle>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent>
|
||||||
|
<ul className="space-y-2 text-sm text-gray-600">{children}</ul>
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface DetailLayoutProps {
|
||||||
|
title: string;
|
||||||
|
mainContent: React.ReactNode;
|
||||||
|
sidebar: React.ReactNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function DetailLayout({
|
||||||
|
title,
|
||||||
|
mainContent,
|
||||||
|
sidebar,
|
||||||
|
}: DetailLayoutProps) {
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<h1 className="text-2xl font-bold mb-6">{title}</h1>
|
||||||
|
<div className="flex flex-col md:flex-row gap-6">
|
||||||
|
<div className="flex-grow md:w-2/3 space-y-6">{mainContent}</div>
|
||||||
|
<div className="md:w-1/3">{sidebar}</div>
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
49
llama_stack/ui/components/layout/logs-layout.tsx
Normal file
49
llama_stack/ui/components/layout/logs-layout.tsx
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import React from "react";
|
||||||
|
import { usePathname, useParams } from "next/navigation";
|
||||||
|
import {
|
||||||
|
PageBreadcrumb,
|
||||||
|
BreadcrumbSegment,
|
||||||
|
} from "@/components/layout/page-breadcrumb";
|
||||||
|
import { truncateText } from "@/lib/truncate-text";
|
||||||
|
|
||||||
|
interface LogsLayoutProps {
|
||||||
|
children: React.ReactNode;
|
||||||
|
sectionLabel: string;
|
||||||
|
basePath: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function LogsLayout({
|
||||||
|
children,
|
||||||
|
sectionLabel,
|
||||||
|
basePath,
|
||||||
|
}: LogsLayoutProps) {
|
||||||
|
const pathname = usePathname();
|
||||||
|
const params = useParams();
|
||||||
|
|
||||||
|
let segments: BreadcrumbSegment[] = [];
|
||||||
|
|
||||||
|
if (pathname === basePath) {
|
||||||
|
segments = [{ label: sectionLabel }];
|
||||||
|
}
|
||||||
|
|
||||||
|
const idParam = params?.id;
|
||||||
|
if (idParam && typeof idParam === "string") {
|
||||||
|
segments = [
|
||||||
|
{ label: sectionLabel, href: basePath },
|
||||||
|
{ label: `Details (${truncateText(idParam, 20)})` },
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="container mx-auto p-4">
|
||||||
|
<>
|
||||||
|
{segments.length > 0 && (
|
||||||
|
<PageBreadcrumb segments={segments} className="mb-4" />
|
||||||
|
)}
|
||||||
|
{children}
|
||||||
|
</>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
350
llama_stack/ui/components/logs/logs-table.test.tsx
Normal file
350
llama_stack/ui/components/logs/logs-table.test.tsx
Normal file
|
@ -0,0 +1,350 @@
|
||||||
|
import React from "react";
|
||||||
|
import { render, screen, fireEvent } from "@testing-library/react";
|
||||||
|
import "@testing-library/jest-dom";
|
||||||
|
import { LogsTable, LogTableRow } from "./logs-table";
|
||||||
|
|
||||||
|
// Mock next/navigation
|
||||||
|
const mockPush = jest.fn();
|
||||||
|
jest.mock("next/navigation", () => ({
|
||||||
|
useRouter: () => ({
|
||||||
|
push: mockPush,
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Mock helper functions
|
||||||
|
jest.mock("@/lib/truncate-text");
|
||||||
|
|
||||||
|
// Import the mocked functions
|
||||||
|
import { truncateText as originalTruncateText } from "@/lib/truncate-text";
|
||||||
|
|
||||||
|
// Cast to jest.Mock for typings
|
||||||
|
const truncateText = originalTruncateText as jest.Mock;
|
||||||
|
|
||||||
|
describe("LogsTable", () => {
|
||||||
|
const defaultProps = {
|
||||||
|
data: [] as LogTableRow[],
|
||||||
|
isLoading: false,
|
||||||
|
error: null,
|
||||||
|
caption: "Test table caption",
|
||||||
|
emptyMessage: "No data found",
|
||||||
|
};
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
// Reset all mocks before each test
|
||||||
|
mockPush.mockClear();
|
||||||
|
truncateText.mockClear();
|
||||||
|
|
||||||
|
// Default pass-through implementation
|
||||||
|
truncateText.mockImplementation((text: string | undefined) => text);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders without crashing with default props", () => {
|
||||||
|
render(<LogsTable {...defaultProps} />);
|
||||||
|
expect(screen.getByText("No data found")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("click on a row navigates to the correct URL", () => {
|
||||||
|
const mockData: LogTableRow[] = [
|
||||||
|
{
|
||||||
|
id: "row_123",
|
||||||
|
input: "Test input",
|
||||||
|
output: "Test output",
|
||||||
|
model: "test-model",
|
||||||
|
createdTime: "2024-01-01 12:00:00",
|
||||||
|
detailPath: "/test/path/row_123",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
render(<LogsTable {...defaultProps} data={mockData} />);
|
||||||
|
|
||||||
|
const row = screen.getByText("Test input").closest("tr");
|
||||||
|
if (row) {
|
||||||
|
fireEvent.click(row);
|
||||||
|
expect(mockPush).toHaveBeenCalledWith("/test/path/row_123");
|
||||||
|
} else {
|
||||||
|
throw new Error('Row with "Test input" not found for router mock test.');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Loading State", () => {
|
||||||
|
test("renders skeleton UI when isLoading is true", () => {
|
||||||
|
const { container } = render(
|
||||||
|
<LogsTable {...defaultProps} isLoading={true} />,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Check for skeleton in the table caption
|
||||||
|
const tableCaption = container.querySelector("caption");
|
||||||
|
expect(tableCaption).toBeInTheDocument();
|
||||||
|
if (tableCaption) {
|
||||||
|
const captionSkeleton = tableCaption.querySelector(
|
||||||
|
'[data-slot="skeleton"]',
|
||||||
|
);
|
||||||
|
expect(captionSkeleton).toBeInTheDocument();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for skeletons in the table body cells
|
||||||
|
const tableBody = container.querySelector("tbody");
|
||||||
|
expect(tableBody).toBeInTheDocument();
|
||||||
|
if (tableBody) {
|
||||||
|
const bodySkeletons = tableBody.querySelectorAll(
|
||||||
|
'[data-slot="skeleton"]',
|
||||||
|
);
|
||||||
|
expect(bodySkeletons.length).toBeGreaterThan(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that table headers are still rendered
|
||||||
|
expect(screen.getByText("Input")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Output")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Model")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Created")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders correct number of skeleton rows", () => {
|
||||||
|
const { container } = render(
|
||||||
|
<LogsTable {...defaultProps} isLoading={true} />,
|
||||||
|
);
|
||||||
|
|
||||||
|
const skeletonRows = container.querySelectorAll("tbody tr");
|
||||||
|
expect(skeletonRows.length).toBe(3); // Should render 3 skeleton rows
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Error State", () => {
|
||||||
|
test("renders error message when error prop is provided", () => {
|
||||||
|
const errorMessage = "Network Error";
|
||||||
|
render(
|
||||||
|
<LogsTable
|
||||||
|
{...defaultProps}
|
||||||
|
error={{ name: "Error", message: errorMessage }}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
expect(
|
||||||
|
screen.getByText(`Error fetching data: ${errorMessage}`),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders default error message when error.message is not available", () => {
|
||||||
|
render(
|
||||||
|
<LogsTable {...defaultProps} error={{ name: "Error", message: "" }} />,
|
||||||
|
);
|
||||||
|
expect(
|
||||||
|
screen.getByText("Error fetching data: An unknown error occurred"),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders default error message when error prop is an object without message", () => {
|
||||||
|
render(<LogsTable {...defaultProps} error={{} as Error} />);
|
||||||
|
expect(
|
||||||
|
screen.getByText("Error fetching data: An unknown error occurred"),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("does not render table when in error state", () => {
|
||||||
|
render(
|
||||||
|
<LogsTable
|
||||||
|
{...defaultProps}
|
||||||
|
error={{ name: "Error", message: "Test error" }}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
const table = screen.queryByRole("table");
|
||||||
|
expect(table).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Empty State", () => {
|
||||||
|
test("renders custom empty message when data array is empty", () => {
|
||||||
|
render(
|
||||||
|
<LogsTable
|
||||||
|
{...defaultProps}
|
||||||
|
data={[]}
|
||||||
|
emptyMessage="Custom empty message"
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
expect(screen.getByText("Custom empty message")).toBeInTheDocument();
|
||||||
|
|
||||||
|
// Ensure that the table structure is NOT rendered in the empty state
|
||||||
|
const table = screen.queryByRole("table");
|
||||||
|
expect(table).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Data Rendering", () => {
|
||||||
|
test("renders table caption, headers, and data correctly", () => {
|
||||||
|
const mockData: LogTableRow[] = [
|
||||||
|
{
|
||||||
|
id: "row_1",
|
||||||
|
input: "First input",
|
||||||
|
output: "First output",
|
||||||
|
model: "model-1",
|
||||||
|
createdTime: "2024-01-01 12:00:00",
|
||||||
|
detailPath: "/path/1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "row_2",
|
||||||
|
input: "Second input",
|
||||||
|
output: "Second output",
|
||||||
|
model: "model-2",
|
||||||
|
createdTime: "2024-01-02 13:00:00",
|
||||||
|
detailPath: "/path/2",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
render(
|
||||||
|
<LogsTable
|
||||||
|
{...defaultProps}
|
||||||
|
data={mockData}
|
||||||
|
caption="Custom table caption"
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Table caption
|
||||||
|
expect(screen.getByText("Custom table caption")).toBeInTheDocument();
|
||||||
|
|
||||||
|
// Table headers
|
||||||
|
expect(screen.getByText("Input")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Output")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Model")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Created")).toBeInTheDocument();
|
||||||
|
|
||||||
|
// Data rows
|
||||||
|
expect(screen.getByText("First input")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("First output")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("model-1")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("2024-01-01 12:00:00")).toBeInTheDocument();
|
||||||
|
|
||||||
|
expect(screen.getByText("Second input")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Second output")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("model-2")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("2024-01-02 13:00:00")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("applies correct CSS classes to table rows", () => {
|
||||||
|
const mockData: LogTableRow[] = [
|
||||||
|
{
|
||||||
|
id: "row_1",
|
||||||
|
input: "Test input",
|
||||||
|
output: "Test output",
|
||||||
|
model: "test-model",
|
||||||
|
createdTime: "2024-01-01 12:00:00",
|
||||||
|
detailPath: "/test/path",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
render(<LogsTable {...defaultProps} data={mockData} />);
|
||||||
|
|
||||||
|
const row = screen.getByText("Test input").closest("tr");
|
||||||
|
expect(row).toHaveClass("cursor-pointer");
|
||||||
|
expect(row).toHaveClass("hover:bg-muted/50");
|
||||||
|
});
|
||||||
|
|
||||||
|
test("applies correct alignment to Created column", () => {
|
||||||
|
const mockData: LogTableRow[] = [
|
||||||
|
{
|
||||||
|
id: "row_1",
|
||||||
|
input: "Test input",
|
||||||
|
output: "Test output",
|
||||||
|
model: "test-model",
|
||||||
|
createdTime: "2024-01-01 12:00:00",
|
||||||
|
detailPath: "/test/path",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
render(<LogsTable {...defaultProps} data={mockData} />);
|
||||||
|
|
||||||
|
const createdCell = screen.getByText("2024-01-01 12:00:00").closest("td");
|
||||||
|
expect(createdCell).toHaveClass("text-right");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Text Truncation", () => {
|
||||||
|
test("truncates input and output text using truncateText function", () => {
|
||||||
|
// Mock truncateText to return truncated versions
|
||||||
|
truncateText.mockImplementation((text: string | undefined) => {
|
||||||
|
if (typeof text === "string" && text.length > 10) {
|
||||||
|
return text.slice(0, 10) + "...";
|
||||||
|
}
|
||||||
|
return text;
|
||||||
|
});
|
||||||
|
|
||||||
|
const longInput =
|
||||||
|
"This is a very long input text that should be truncated";
|
||||||
|
const longOutput =
|
||||||
|
"This is a very long output text that should be truncated";
|
||||||
|
|
||||||
|
const mockData: LogTableRow[] = [
|
||||||
|
{
|
||||||
|
id: "row_1",
|
||||||
|
input: longInput,
|
||||||
|
output: longOutput,
|
||||||
|
model: "test-model",
|
||||||
|
createdTime: "2024-01-01 12:00:00",
|
||||||
|
detailPath: "/test/path",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
render(<LogsTable {...defaultProps} data={mockData} />);
|
||||||
|
|
||||||
|
// Verify truncateText was called
|
||||||
|
expect(truncateText).toHaveBeenCalledWith(longInput);
|
||||||
|
expect(truncateText).toHaveBeenCalledWith(longOutput);
|
||||||
|
|
||||||
|
// Verify truncated text is displayed
|
||||||
|
const truncatedTexts = screen.getAllByText("This is a ...");
|
||||||
|
expect(truncatedTexts).toHaveLength(2); // one for input, one for output
|
||||||
|
truncatedTexts.forEach((textElement) =>
|
||||||
|
expect(textElement).toBeInTheDocument(),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("does not truncate model names", () => {
|
||||||
|
const mockData: LogTableRow[] = [
|
||||||
|
{
|
||||||
|
id: "row_1",
|
||||||
|
input: "Test input",
|
||||||
|
output: "Test output",
|
||||||
|
model: "very-long-model-name-that-should-not-be-truncated",
|
||||||
|
createdTime: "2024-01-01 12:00:00",
|
||||||
|
detailPath: "/test/path",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
render(<LogsTable {...defaultProps} data={mockData} />);
|
||||||
|
|
||||||
|
// Model name should not be passed to truncateText
|
||||||
|
expect(truncateText).not.toHaveBeenCalledWith(
|
||||||
|
"very-long-model-name-that-should-not-be-truncated",
|
||||||
|
);
|
||||||
|
|
||||||
|
// Full model name should be displayed
|
||||||
|
expect(
|
||||||
|
screen.getByText("very-long-model-name-that-should-not-be-truncated"),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Accessibility", () => {
|
||||||
|
test("table has proper role and structure", () => {
|
||||||
|
const mockData: LogTableRow[] = [
|
||||||
|
{
|
||||||
|
id: "row_1",
|
||||||
|
input: "Test input",
|
||||||
|
output: "Test output",
|
||||||
|
model: "test-model",
|
||||||
|
createdTime: "2024-01-01 12:00:00",
|
||||||
|
detailPath: "/test/path",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
render(<LogsTable {...defaultProps} data={mockData} />);
|
||||||
|
|
||||||
|
const table = screen.getByRole("table");
|
||||||
|
expect(table).toBeInTheDocument();
|
||||||
|
|
||||||
|
const columnHeaders = screen.getAllByRole("columnheader");
|
||||||
|
expect(columnHeaders).toHaveLength(4);
|
||||||
|
|
||||||
|
const rows = screen.getAllByRole("row");
|
||||||
|
expect(rows).toHaveLength(2); // 1 header row + 1 data row
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
|
@ -1,12 +1,7 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { useRouter } from "next/navigation";
|
import { useRouter } from "next/navigation";
|
||||||
import { ChatCompletion } from "@/lib/types";
|
|
||||||
import { truncateText } from "@/lib/truncate-text";
|
import { truncateText } from "@/lib/truncate-text";
|
||||||
import {
|
|
||||||
extractTextFromContentPart,
|
|
||||||
extractDisplayableText,
|
|
||||||
} from "@/lib/format-message-content";
|
|
||||||
import {
|
import {
|
||||||
Table,
|
Table,
|
||||||
TableBody,
|
TableBody,
|
||||||
|
@ -18,17 +13,31 @@ import {
|
||||||
} from "@/components/ui/table";
|
} from "@/components/ui/table";
|
||||||
import { Skeleton } from "@/components/ui/skeleton";
|
import { Skeleton } from "@/components/ui/skeleton";
|
||||||
|
|
||||||
interface ChatCompletionsTableProps {
|
// Generic table row data interface
|
||||||
completions: ChatCompletion[];
|
export interface LogTableRow {
|
||||||
isLoading: boolean;
|
id: string;
|
||||||
error: Error | null;
|
input: string;
|
||||||
|
output: string;
|
||||||
|
model: string;
|
||||||
|
createdTime: string;
|
||||||
|
detailPath: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ChatCompletionsTable({
|
interface LogsTableProps {
|
||||||
completions,
|
data: LogTableRow[];
|
||||||
|
isLoading: boolean;
|
||||||
|
error: Error | null;
|
||||||
|
caption: string;
|
||||||
|
emptyMessage: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function LogsTable({
|
||||||
|
data,
|
||||||
isLoading,
|
isLoading,
|
||||||
error,
|
error,
|
||||||
}: ChatCompletionsTableProps) {
|
caption,
|
||||||
|
emptyMessage,
|
||||||
|
}: LogsTableProps) {
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
|
|
||||||
const tableHeader = (
|
const tableHeader = (
|
||||||
|
@ -77,41 +86,25 @@ export function ChatCompletionsTable({
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (completions.length === 0) {
|
if (data.length === 0) {
|
||||||
return <p>No chat completions found.</p>;
|
return <p>{emptyMessage}</p>;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Table>
|
<Table>
|
||||||
<TableCaption>A list of your recent chat completions.</TableCaption>
|
<TableCaption>{caption}</TableCaption>
|
||||||
{tableHeader}
|
{tableHeader}
|
||||||
<TableBody>
|
<TableBody>
|
||||||
{completions.map((completion) => (
|
{data.map((row) => (
|
||||||
<TableRow
|
<TableRow
|
||||||
key={completion.id}
|
key={row.id}
|
||||||
onClick={() =>
|
onClick={() => router.push(row.detailPath)}
|
||||||
router.push(`/logs/chat-completions/${completion.id}`)
|
|
||||||
}
|
|
||||||
className="cursor-pointer hover:bg-muted/50"
|
className="cursor-pointer hover:bg-muted/50"
|
||||||
>
|
>
|
||||||
<TableCell>
|
<TableCell>{truncateText(row.input)}</TableCell>
|
||||||
{truncateText(
|
<TableCell>{truncateText(row.output)}</TableCell>
|
||||||
extractTextFromContentPart(
|
<TableCell>{row.model}</TableCell>
|
||||||
completion.input_messages?.[0]?.content,
|
<TableCell className="text-right">{row.createdTime}</TableCell>
|
||||||
),
|
|
||||||
)}
|
|
||||||
</TableCell>
|
|
||||||
<TableCell>
|
|
||||||
{(() => {
|
|
||||||
const message = completion.choices?.[0]?.message;
|
|
||||||
const outputText = extractDisplayableText(message);
|
|
||||||
return truncateText(outputText);
|
|
||||||
})()}
|
|
||||||
</TableCell>
|
|
||||||
<TableCell>{completion.model}</TableCell>
|
|
||||||
<TableCell className="text-right">
|
|
||||||
{new Date(completion.created * 1000).toLocaleString()}
|
|
||||||
</TableCell>
|
|
||||||
</TableRow>
|
</TableRow>
|
||||||
))}
|
))}
|
||||||
</TableBody>
|
</TableBody>
|
|
@ -0,0 +1,56 @@
|
||||||
|
import { useFunctionCallGrouping } from "../hooks/function-call-grouping";
|
||||||
|
import { ItemRenderer } from "../items/item-renderer";
|
||||||
|
import { GroupedFunctionCallItemComponent } from "../items/grouped-function-call-item";
|
||||||
|
import {
|
||||||
|
isFunctionCallItem,
|
||||||
|
isFunctionCallOutputItem,
|
||||||
|
AnyResponseItem,
|
||||||
|
} from "../utils/item-types";
|
||||||
|
|
||||||
|
interface GroupedItemsDisplayProps {
|
||||||
|
items: AnyResponseItem[];
|
||||||
|
keyPrefix: string;
|
||||||
|
defaultRole?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function GroupedItemsDisplay({
|
||||||
|
items,
|
||||||
|
keyPrefix,
|
||||||
|
defaultRole = "unknown",
|
||||||
|
}: GroupedItemsDisplayProps) {
|
||||||
|
const groupedItems = useFunctionCallGrouping(items);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{groupedItems.map((groupedItem) => {
|
||||||
|
// If this is a function call with an output, render the grouped component
|
||||||
|
if (
|
||||||
|
groupedItem.outputItem &&
|
||||||
|
isFunctionCallItem(groupedItem.item) &&
|
||||||
|
isFunctionCallOutputItem(groupedItem.outputItem)
|
||||||
|
) {
|
||||||
|
return (
|
||||||
|
<GroupedFunctionCallItemComponent
|
||||||
|
key={`${keyPrefix}-${groupedItem.index}`}
|
||||||
|
functionCall={groupedItem.item}
|
||||||
|
output={groupedItem.outputItem}
|
||||||
|
index={groupedItem.index}
|
||||||
|
keyPrefix={keyPrefix}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, render the individual item
|
||||||
|
return (
|
||||||
|
<ItemRenderer
|
||||||
|
key={`${keyPrefix}-${groupedItem.index}`}
|
||||||
|
item={groupedItem.item}
|
||||||
|
index={groupedItem.index}
|
||||||
|
keyPrefix={keyPrefix}
|
||||||
|
defaultRole={defaultRole}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
|
@ -0,0 +1,92 @@
|
||||||
|
import { useMemo } from "react";
|
||||||
|
import {
|
||||||
|
isFunctionCallOutputItem,
|
||||||
|
AnyResponseItem,
|
||||||
|
FunctionCallOutputItem,
|
||||||
|
} from "../utils/item-types";
|
||||||
|
|
||||||
|
export interface GroupedItem {
|
||||||
|
item: AnyResponseItem;
|
||||||
|
index: number;
|
||||||
|
outputItem?: AnyResponseItem;
|
||||||
|
outputIndex?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Hook to group function calls with their corresponding outputs
|
||||||
|
* @param items Array of items to group
|
||||||
|
* @returns Array of grouped items with their outputs
|
||||||
|
*/
|
||||||
|
export function useFunctionCallGrouping(
|
||||||
|
items: AnyResponseItem[],
|
||||||
|
): GroupedItem[] {
|
||||||
|
return useMemo(() => {
|
||||||
|
const groupedItems: GroupedItem[] = [];
|
||||||
|
const processedIndices = new Set<number>();
|
||||||
|
|
||||||
|
// Build a map of call_id to indices for function_call_output items
|
||||||
|
const callIdToIndices = new Map<string, number[]>();
|
||||||
|
|
||||||
|
for (let i = 0; i < items.length; i++) {
|
||||||
|
const item = items[i];
|
||||||
|
if (isFunctionCallOutputItem(item)) {
|
||||||
|
if (!callIdToIndices.has(item.call_id)) {
|
||||||
|
callIdToIndices.set(item.call_id, []);
|
||||||
|
}
|
||||||
|
callIdToIndices.get(item.call_id)!.push(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process items and group function calls with their outputs
|
||||||
|
for (let i = 0; i < items.length; i++) {
|
||||||
|
if (processedIndices.has(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const currentItem = items[i];
|
||||||
|
|
||||||
|
if (
|
||||||
|
currentItem.type === "function_call" &&
|
||||||
|
"name" in currentItem &&
|
||||||
|
"call_id" in currentItem
|
||||||
|
) {
|
||||||
|
const functionCallId = currentItem.call_id as string;
|
||||||
|
let outputIndex = -1;
|
||||||
|
let outputItem: FunctionCallOutputItem | null = null;
|
||||||
|
|
||||||
|
const relatedIndices = callIdToIndices.get(functionCallId) || [];
|
||||||
|
for (const idx of relatedIndices) {
|
||||||
|
const potentialOutput = items[idx];
|
||||||
|
outputIndex = idx;
|
||||||
|
outputItem = potentialOutput as FunctionCallOutputItem;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (outputItem && outputIndex !== -1) {
|
||||||
|
// Group function call with its function_call_output
|
||||||
|
groupedItems.push({
|
||||||
|
item: currentItem,
|
||||||
|
index: i,
|
||||||
|
outputItem,
|
||||||
|
outputIndex,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mark both items as processed
|
||||||
|
processedIndices.add(i);
|
||||||
|
processedIndices.add(outputIndex);
|
||||||
|
|
||||||
|
// Matching function call and output found, skip to next item
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// render normally
|
||||||
|
groupedItems.push({
|
||||||
|
item: currentItem,
|
||||||
|
index: i,
|
||||||
|
});
|
||||||
|
processedIndices.add(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
return groupedItems;
|
||||||
|
}, [items]);
|
||||||
|
}
|
|
@ -0,0 +1,29 @@
|
||||||
|
import {
|
||||||
|
MessageBlock,
|
||||||
|
ToolCallBlock,
|
||||||
|
} from "@/components/ui/message-components";
|
||||||
|
import { FunctionCallItem } from "../utils/item-types";
|
||||||
|
|
||||||
|
interface FunctionCallItemProps {
|
||||||
|
item: FunctionCallItem;
|
||||||
|
index: number;
|
||||||
|
keyPrefix: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function FunctionCallItemComponent({
|
||||||
|
item,
|
||||||
|
index,
|
||||||
|
keyPrefix,
|
||||||
|
}: FunctionCallItemProps) {
|
||||||
|
const name = item.name || "unknown";
|
||||||
|
const args = item.arguments || "{}";
|
||||||
|
const formattedFunctionCall = `${name}(${args})`;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<MessageBlock
|
||||||
|
key={`${keyPrefix}-${index}`}
|
||||||
|
label="Function Call"
|
||||||
|
content={<ToolCallBlock>{formattedFunctionCall}</ToolCallBlock>}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
37
llama_stack/ui/components/responses/items/generic-item.tsx
Normal file
37
llama_stack/ui/components/responses/items/generic-item.tsx
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
import {
|
||||||
|
MessageBlock,
|
||||||
|
ToolCallBlock,
|
||||||
|
} from "@/components/ui/message-components";
|
||||||
|
import { BaseItem } from "../utils/item-types";
|
||||||
|
|
||||||
|
interface GenericItemProps {
|
||||||
|
item: BaseItem;
|
||||||
|
index: number;
|
||||||
|
keyPrefix: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function GenericItemComponent({
|
||||||
|
item,
|
||||||
|
index,
|
||||||
|
keyPrefix,
|
||||||
|
}: GenericItemProps) {
|
||||||
|
// Handle other types like function calls, tool outputs, etc.
|
||||||
|
const itemData = item as Record<string, unknown>;
|
||||||
|
|
||||||
|
const content = itemData.content
|
||||||
|
? typeof itemData.content === "string"
|
||||||
|
? itemData.content
|
||||||
|
: JSON.stringify(itemData.content, null, 2)
|
||||||
|
: JSON.stringify(itemData, null, 2);
|
||||||
|
|
||||||
|
const label = keyPrefix === "input" ? "Input" : "Output";
|
||||||
|
|
||||||
|
return (
|
||||||
|
<MessageBlock
|
||||||
|
key={`${keyPrefix}-${index}`}
|
||||||
|
label={label}
|
||||||
|
labelDetail={`(${itemData.type})`}
|
||||||
|
content={<ToolCallBlock>{content}</ToolCallBlock>}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
|
@ -0,0 +1,54 @@
|
||||||
|
import {
|
||||||
|
MessageBlock,
|
||||||
|
ToolCallBlock,
|
||||||
|
} from "@/components/ui/message-components";
|
||||||
|
import { FunctionCallItem, FunctionCallOutputItem } from "../utils/item-types";
|
||||||
|
|
||||||
|
interface GroupedFunctionCallItemProps {
|
||||||
|
functionCall: FunctionCallItem;
|
||||||
|
output: FunctionCallOutputItem;
|
||||||
|
index: number;
|
||||||
|
keyPrefix: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function GroupedFunctionCallItemComponent({
|
||||||
|
functionCall,
|
||||||
|
output,
|
||||||
|
index,
|
||||||
|
keyPrefix,
|
||||||
|
}: GroupedFunctionCallItemProps) {
|
||||||
|
const name = functionCall.name || "unknown";
|
||||||
|
const args = functionCall.arguments || "{}";
|
||||||
|
|
||||||
|
// Extract the output content from function_call_output
|
||||||
|
let outputContent = "";
|
||||||
|
if (output.output) {
|
||||||
|
outputContent =
|
||||||
|
typeof output.output === "string"
|
||||||
|
? output.output
|
||||||
|
: JSON.stringify(output.output);
|
||||||
|
} else {
|
||||||
|
outputContent = JSON.stringify(output, null, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
const functionCallContent = (
|
||||||
|
<div>
|
||||||
|
<div className="mb-2">
|
||||||
|
<span className="text-sm text-gray-600">Arguments</span>
|
||||||
|
<ToolCallBlock>{`${name}(${args})`}</ToolCallBlock>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<span className="text-sm text-gray-600">Output</span>
|
||||||
|
<ToolCallBlock>{outputContent}</ToolCallBlock>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<MessageBlock
|
||||||
|
key={`${keyPrefix}-${index}`}
|
||||||
|
label="Function Call"
|
||||||
|
content={functionCallContent}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
6
llama_stack/ui/components/responses/items/index.ts
Normal file
6
llama_stack/ui/components/responses/items/index.ts
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
export { MessageItemComponent } from "./message-item";
|
||||||
|
export { FunctionCallItemComponent } from "./function-call-item";
|
||||||
|
export { WebSearchItemComponent } from "./web-search-item";
|
||||||
|
export { GenericItemComponent } from "./generic-item";
|
||||||
|
export { GroupedFunctionCallItemComponent } from "./grouped-function-call-item";
|
||||||
|
export { ItemRenderer } from "./item-renderer";
|
60
llama_stack/ui/components/responses/items/item-renderer.tsx
Normal file
60
llama_stack/ui/components/responses/items/item-renderer.tsx
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
import {
|
||||||
|
isMessageItem,
|
||||||
|
isFunctionCallItem,
|
||||||
|
isWebSearchCallItem,
|
||||||
|
AnyResponseItem,
|
||||||
|
} from "../utils/item-types";
|
||||||
|
import { MessageItemComponent } from "./message-item";
|
||||||
|
import { FunctionCallItemComponent } from "./function-call-item";
|
||||||
|
import { WebSearchItemComponent } from "./web-search-item";
|
||||||
|
import { GenericItemComponent } from "./generic-item";
|
||||||
|
|
||||||
|
interface ItemRendererProps {
|
||||||
|
item: AnyResponseItem;
|
||||||
|
index: number;
|
||||||
|
keyPrefix: string;
|
||||||
|
defaultRole?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ItemRenderer({
|
||||||
|
item,
|
||||||
|
index,
|
||||||
|
keyPrefix,
|
||||||
|
defaultRole = "unknown",
|
||||||
|
}: ItemRendererProps) {
|
||||||
|
if (isMessageItem(item)) {
|
||||||
|
return (
|
||||||
|
<MessageItemComponent
|
||||||
|
item={item}
|
||||||
|
index={index}
|
||||||
|
keyPrefix={keyPrefix}
|
||||||
|
defaultRole={defaultRole}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isFunctionCallItem(item)) {
|
||||||
|
return (
|
||||||
|
<FunctionCallItemComponent
|
||||||
|
item={item}
|
||||||
|
index={index}
|
||||||
|
keyPrefix={keyPrefix}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isWebSearchCallItem(item)) {
|
||||||
|
return (
|
||||||
|
<WebSearchItemComponent item={item} index={index} keyPrefix={keyPrefix} />
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to generic item for unknown types
|
||||||
|
return (
|
||||||
|
<GenericItemComponent
|
||||||
|
item={item as any}
|
||||||
|
index={index}
|
||||||
|
keyPrefix={keyPrefix}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
41
llama_stack/ui/components/responses/items/message-item.tsx
Normal file
41
llama_stack/ui/components/responses/items/message-item.tsx
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
import { MessageBlock } from "@/components/ui/message-components";
|
||||||
|
import { MessageItem } from "../utils/item-types";
|
||||||
|
|
||||||
|
interface MessageItemProps {
|
||||||
|
item: MessageItem;
|
||||||
|
index: number;
|
||||||
|
keyPrefix: string;
|
||||||
|
defaultRole?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function MessageItemComponent({
|
||||||
|
item,
|
||||||
|
index,
|
||||||
|
keyPrefix,
|
||||||
|
defaultRole = "unknown",
|
||||||
|
}: MessageItemProps) {
|
||||||
|
let content = "";
|
||||||
|
|
||||||
|
if (typeof item.content === "string") {
|
||||||
|
content = item.content;
|
||||||
|
} else if (Array.isArray(item.content)) {
|
||||||
|
content = item.content
|
||||||
|
.map((c) => {
|
||||||
|
return c.type === "input_text" || c.type === "output_text"
|
||||||
|
? c.text
|
||||||
|
: JSON.stringify(c);
|
||||||
|
})
|
||||||
|
.join(" ");
|
||||||
|
}
|
||||||
|
|
||||||
|
const role = item.role || defaultRole;
|
||||||
|
const label = role.charAt(0).toUpperCase() + role.slice(1);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<MessageBlock
|
||||||
|
key={`${keyPrefix}-${index}`}
|
||||||
|
label={label}
|
||||||
|
content={content}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
|
@ -0,0 +1,28 @@
|
||||||
|
import {
|
||||||
|
MessageBlock,
|
||||||
|
ToolCallBlock,
|
||||||
|
} from "@/components/ui/message-components";
|
||||||
|
import { WebSearchCallItem } from "../utils/item-types";
|
||||||
|
|
||||||
|
interface WebSearchItemProps {
|
||||||
|
item: WebSearchCallItem;
|
||||||
|
index: number;
|
||||||
|
keyPrefix: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function WebSearchItemComponent({
|
||||||
|
item,
|
||||||
|
index,
|
||||||
|
keyPrefix,
|
||||||
|
}: WebSearchItemProps) {
|
||||||
|
const formattedWebSearch = `web_search_call(status: ${item.status})`;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<MessageBlock
|
||||||
|
key={`${keyPrefix}-${index}`}
|
||||||
|
label="Function Call"
|
||||||
|
labelDetail="(Web Search)"
|
||||||
|
content={<ToolCallBlock>{formattedWebSearch}</ToolCallBlock>}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
777
llama_stack/ui/components/responses/responses-detail.test.tsx
Normal file
777
llama_stack/ui/components/responses/responses-detail.test.tsx
Normal file
|
@ -0,0 +1,777 @@
|
||||||
|
import React from "react";
|
||||||
|
import { render, screen } from "@testing-library/react";
|
||||||
|
import "@testing-library/jest-dom";
|
||||||
|
import { ResponseDetailView } from "./responses-detail";
|
||||||
|
import { OpenAIResponse, InputItemListResponse } from "@/lib/types";
|
||||||
|
|
||||||
|
describe("ResponseDetailView", () => {
|
||||||
|
const defaultProps = {
|
||||||
|
response: null,
|
||||||
|
inputItems: null,
|
||||||
|
isLoading: false,
|
||||||
|
isLoadingInputItems: false,
|
||||||
|
error: null,
|
||||||
|
inputItemsError: null,
|
||||||
|
id: "test_id",
|
||||||
|
};
|
||||||
|
|
||||||
|
describe("Loading State", () => {
|
||||||
|
test("renders loading skeleton when isLoading is true", () => {
|
||||||
|
const { container } = render(
|
||||||
|
<ResponseDetailView {...defaultProps} isLoading={true} />,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Check for skeleton elements
|
||||||
|
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
|
||||||
|
expect(skeletons.length).toBeGreaterThan(0);
|
||||||
|
|
||||||
|
// The title is replaced by a skeleton when loading, so we shouldn't expect the text
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Error State", () => {
|
||||||
|
test("renders error message when error prop is provided", () => {
|
||||||
|
const errorMessage = "Network Error";
|
||||||
|
render(
|
||||||
|
<ResponseDetailView
|
||||||
|
{...defaultProps}
|
||||||
|
error={{ name: "Error", message: errorMessage }}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(screen.getByText("Responses Details")).toBeInTheDocument();
|
||||||
|
// The error message is split across elements, so we check for parts
|
||||||
|
expect(
|
||||||
|
screen.getByText(/Error loading details for ID/),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
expect(screen.getByText(/test_id/)).toBeInTheDocument();
|
||||||
|
expect(screen.getByText(/Network Error/)).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders default error message when error.message is not available", () => {
|
||||||
|
render(
|
||||||
|
<ResponseDetailView
|
||||||
|
{...defaultProps}
|
||||||
|
error={{ name: "Error", message: "" }}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(
|
||||||
|
screen.getByText(/Error loading details for ID/),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
expect(screen.getByText(/test_id/)).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Not Found State", () => {
|
||||||
|
test("renders not found message when response is null and not loading/error", () => {
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={null} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Responses Details")).toBeInTheDocument();
|
||||||
|
// The message is split across elements
|
||||||
|
expect(screen.getByText(/No details found for ID:/)).toBeInTheDocument();
|
||||||
|
expect(screen.getByText(/test_id/)).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Response Data Rendering", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "llama-test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "assistant",
|
||||||
|
content: "Test response output",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
input: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "user",
|
||||||
|
content: "Test input message",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature: 0.7,
|
||||||
|
top_p: 0.9,
|
||||||
|
parallel_tool_calls: true,
|
||||||
|
previous_response_id: "prev_resp_456",
|
||||||
|
};
|
||||||
|
|
||||||
|
test("renders response data with input and output sections", () => {
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
// Check main sections
|
||||||
|
expect(screen.getByText("Responses Details")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Input")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Output")).toBeInTheDocument();
|
||||||
|
|
||||||
|
// Check input content
|
||||||
|
expect(screen.getByText("Test input message")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("User")).toBeInTheDocument();
|
||||||
|
|
||||||
|
// Check output content
|
||||||
|
expect(screen.getByText("Test response output")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Assistant")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders properties sidebar with all response metadata", () => {
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
// Check properties - use regex to handle text split across elements
|
||||||
|
expect(screen.getByText(/Created/)).toBeInTheDocument();
|
||||||
|
expect(
|
||||||
|
screen.getByText(new Date(1710000000 * 1000).toLocaleString()),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
|
||||||
|
// Check for the specific ID label (not Previous Response ID)
|
||||||
|
expect(
|
||||||
|
screen.getByText((content, element) => {
|
||||||
|
return element?.tagName === "STRONG" && content === "ID:";
|
||||||
|
}),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("resp_123")).toBeInTheDocument();
|
||||||
|
|
||||||
|
expect(screen.getByText(/Model/)).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("llama-test-model")).toBeInTheDocument();
|
||||||
|
|
||||||
|
expect(screen.getByText(/Status/)).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("completed")).toBeInTheDocument();
|
||||||
|
|
||||||
|
expect(screen.getByText(/Temperature/)).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("0.7")).toBeInTheDocument();
|
||||||
|
|
||||||
|
expect(screen.getByText(/Top P/)).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("0.9")).toBeInTheDocument();
|
||||||
|
|
||||||
|
expect(screen.getByText(/Parallel Tool Calls/)).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Yes")).toBeInTheDocument();
|
||||||
|
|
||||||
|
expect(screen.getByText(/Previous Response ID/)).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("prev_resp_456")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("handles optional properties correctly", () => {
|
||||||
|
const minimalResponse: OpenAIResponse = {
|
||||||
|
id: "resp_minimal",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [],
|
||||||
|
input: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(
|
||||||
|
<ResponseDetailView {...defaultProps} response={minimalResponse} />,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Should show required properties
|
||||||
|
expect(screen.getByText("resp_minimal")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("test-model")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("completed")).toBeInTheDocument();
|
||||||
|
|
||||||
|
// Should not show optional properties
|
||||||
|
expect(screen.queryByText("Temperature")).not.toBeInTheDocument();
|
||||||
|
expect(screen.queryByText("Top P")).not.toBeInTheDocument();
|
||||||
|
expect(screen.queryByText("Parallel Tool Calls")).not.toBeInTheDocument();
|
||||||
|
expect(
|
||||||
|
screen.queryByText("Previous Response ID"),
|
||||||
|
).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders error information when response has error", () => {
|
||||||
|
const errorResponse: OpenAIResponse = {
|
||||||
|
...mockResponse,
|
||||||
|
error: {
|
||||||
|
code: "invalid_request",
|
||||||
|
message: "The request was invalid",
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={errorResponse} />);
|
||||||
|
|
||||||
|
// The error is shown in the properties sidebar, not as a separate "Error" label
|
||||||
|
expect(
|
||||||
|
screen.getByText("invalid_request: The request was invalid"),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Input Items Handling", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [{ type: "message", role: "assistant", content: "output" }],
|
||||||
|
input: [{ type: "message", role: "user", content: "fallback input" }],
|
||||||
|
};
|
||||||
|
|
||||||
|
test("shows loading state for input items", () => {
|
||||||
|
render(
|
||||||
|
<ResponseDetailView
|
||||||
|
{...defaultProps}
|
||||||
|
response={mockResponse}
|
||||||
|
isLoadingInputItems={true}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Check for skeleton loading in input items section
|
||||||
|
const { container } = render(
|
||||||
|
<ResponseDetailView
|
||||||
|
{...defaultProps}
|
||||||
|
response={mockResponse}
|
||||||
|
isLoadingInputItems={true}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
|
||||||
|
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
|
||||||
|
expect(skeletons.length).toBeGreaterThan(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("shows error message for input items with fallback", () => {
|
||||||
|
render(
|
||||||
|
<ResponseDetailView
|
||||||
|
{...defaultProps}
|
||||||
|
response={mockResponse}
|
||||||
|
inputItemsError={{
|
||||||
|
name: "Error",
|
||||||
|
message: "Failed to load input items",
|
||||||
|
}}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(
|
||||||
|
screen.getByText(
|
||||||
|
"Error loading input items: Failed to load input items",
|
||||||
|
),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
expect(
|
||||||
|
screen.getByText("Falling back to response input data."),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
|
||||||
|
// Should still show fallback input data
|
||||||
|
expect(screen.getByText("fallback input")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("uses input items data when available", () => {
|
||||||
|
const mockInputItems: InputItemListResponse = {
|
||||||
|
object: "list",
|
||||||
|
data: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "user",
|
||||||
|
content: "input from items API",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(
|
||||||
|
<ResponseDetailView
|
||||||
|
{...defaultProps}
|
||||||
|
response={mockResponse}
|
||||||
|
inputItems={mockInputItems}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Should show input items data, not response.input
|
||||||
|
expect(screen.getByText("input from items API")).toBeInTheDocument();
|
||||||
|
expect(screen.queryByText("fallback input")).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("falls back to response.input when input items is empty", () => {
|
||||||
|
const emptyInputItems: InputItemListResponse = {
|
||||||
|
object: "list",
|
||||||
|
data: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(
|
||||||
|
<ResponseDetailView
|
||||||
|
{...defaultProps}
|
||||||
|
response={mockResponse}
|
||||||
|
inputItems={emptyInputItems}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Should show fallback input data
|
||||||
|
expect(screen.getByText("fallback input")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("shows no input message when no data available", () => {
|
||||||
|
const responseWithoutInput: OpenAIResponse = {
|
||||||
|
...mockResponse,
|
||||||
|
input: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(
|
||||||
|
<ResponseDetailView
|
||||||
|
{...defaultProps}
|
||||||
|
response={responseWithoutInput}
|
||||||
|
inputItems={null}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(screen.getByText("No input data available.")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Input Display Components", () => {
|
||||||
|
test("renders string content input correctly", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [],
|
||||||
|
input: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "user",
|
||||||
|
content: "Simple string input",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Simple string input")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("User")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders array content input correctly", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [],
|
||||||
|
input: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "user",
|
||||||
|
content: [
|
||||||
|
{ type: "input_text", text: "First part" },
|
||||||
|
{ type: "output_text", text: "Second part" },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("First part Second part")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("User")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders non-message input types correctly", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [],
|
||||||
|
input: [
|
||||||
|
{
|
||||||
|
type: "function_call",
|
||||||
|
content: "function call content",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("function call content")).toBeInTheDocument();
|
||||||
|
// Use getAllByText to find the specific "Input" with the type detail
|
||||||
|
const inputElements = screen.getAllByText("Input");
|
||||||
|
expect(inputElements.length).toBeGreaterThan(0);
|
||||||
|
expect(screen.getByText("(function_call)")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("handles input with object content", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [],
|
||||||
|
input: [
|
||||||
|
{
|
||||||
|
type: "custom_type",
|
||||||
|
content: JSON.stringify({ key: "value", nested: { data: "test" } }),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
// Should show JSON stringified content (without quotes around keys in the rendered output)
|
||||||
|
expect(screen.getByText(/key.*value/)).toBeInTheDocument();
|
||||||
|
// Use getAllByText to find the specific "Input" with the type detail
|
||||||
|
const inputElements = screen.getAllByText("Input");
|
||||||
|
expect(inputElements.length).toBeGreaterThan(0);
|
||||||
|
expect(screen.getByText("(custom_type)")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders function call input correctly", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [],
|
||||||
|
input: [
|
||||||
|
{
|
||||||
|
type: "function_call",
|
||||||
|
id: "call_456",
|
||||||
|
status: "completed",
|
||||||
|
name: "input_function",
|
||||||
|
arguments: '{"param": "value"}',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
expect(
|
||||||
|
screen.getByText('input_function({"param": "value"})'),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Function Call")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders web search call input correctly", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [],
|
||||||
|
input: [
|
||||||
|
{
|
||||||
|
type: "web_search_call",
|
||||||
|
id: "search_789",
|
||||||
|
status: "completed",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
expect(
|
||||||
|
screen.getByText("web_search_call(status: completed)"),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Function Call")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("(Web Search)")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Output Display Components", () => {
|
||||||
|
test("renders message output with string content", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "assistant",
|
||||||
|
content: "Simple string output",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
input: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Simple string output")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Assistant")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders message output with array content", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "assistant",
|
||||||
|
content: [
|
||||||
|
{ type: "output_text", text: "First output" },
|
||||||
|
{ type: "input_text", text: "Second output" },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
input: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
expect(
|
||||||
|
screen.getByText("First output Second output"),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Assistant")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders function call output correctly", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "function_call",
|
||||||
|
id: "call_123",
|
||||||
|
status: "completed",
|
||||||
|
name: "search_function",
|
||||||
|
arguments: '{"query": "test"}',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
input: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
expect(
|
||||||
|
screen.getByText('search_function({"query": "test"})'),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Function Call")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders function call output without arguments", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "function_call",
|
||||||
|
id: "call_123",
|
||||||
|
status: "completed",
|
||||||
|
name: "simple_function",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
input: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("simple_function({})")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText(/Function Call/)).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders web search call output correctly", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "web_search_call",
|
||||||
|
id: "search_123",
|
||||||
|
status: "completed",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
input: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
expect(
|
||||||
|
screen.getByText("web_search_call(status: completed)"),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
expect(screen.getByText(/Function Call/)).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("(Web Search)")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders unknown output types with JSON fallback", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "unknown_type",
|
||||||
|
custom_field: "custom_value",
|
||||||
|
data: { nested: "object" },
|
||||||
|
} as any,
|
||||||
|
],
|
||||||
|
input: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
// Should show JSON stringified content
|
||||||
|
expect(
|
||||||
|
screen.getByText(/custom_field.*custom_value/),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("(unknown_type)")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("shows no output message when output array is empty", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [],
|
||||||
|
input: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("No output data available.")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("groups function call with its output correctly", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "function_call",
|
||||||
|
id: "call_123",
|
||||||
|
status: "completed",
|
||||||
|
name: "get_weather",
|
||||||
|
arguments: '{"city": "Tokyo"}',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "assistant",
|
||||||
|
call_id: "call_123",
|
||||||
|
content: "sunny and warm",
|
||||||
|
} as any, // Using any to bypass the type restriction for this test
|
||||||
|
],
|
||||||
|
input: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
// Should show the function call and message as separate items (not grouped)
|
||||||
|
expect(screen.getByText("Function Call")).toBeInTheDocument();
|
||||||
|
expect(
|
||||||
|
screen.getByText('get_weather({"city": "Tokyo"})'),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Assistant")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("sunny and warm")).toBeInTheDocument();
|
||||||
|
|
||||||
|
// Should NOT have the grouped "Arguments" and "Output" labels
|
||||||
|
expect(screen.queryByText("Arguments")).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("groups function call with function_call_output correctly", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "function_call",
|
||||||
|
call_id: "call_123",
|
||||||
|
status: "completed",
|
||||||
|
name: "get_weather",
|
||||||
|
arguments: '{"city": "Tokyo"}',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: "function_call_output",
|
||||||
|
id: "fc_68364957013081...",
|
||||||
|
status: "completed",
|
||||||
|
call_id: "call_123",
|
||||||
|
output: "sunny and warm",
|
||||||
|
} as any, // Using any to bypass the type restriction for this test
|
||||||
|
],
|
||||||
|
input: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
// Should show the function call grouped with its clean output
|
||||||
|
expect(screen.getByText("Function Call")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Arguments")).toBeInTheDocument();
|
||||||
|
expect(
|
||||||
|
screen.getByText('get_weather({"city": "Tokyo"})'),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
// Use getAllByText since there are multiple "Output" elements (card title and output label)
|
||||||
|
const outputElements = screen.getAllByText("Output");
|
||||||
|
expect(outputElements.length).toBeGreaterThan(0);
|
||||||
|
expect(screen.getByText("sunny and warm")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Edge Cases and Error Handling", () => {
|
||||||
|
test("handles missing role in message input", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [],
|
||||||
|
input: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
content: "Message without role",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Message without role")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Unknown")).toBeInTheDocument(); // Default role
|
||||||
|
});
|
||||||
|
|
||||||
|
test("handles missing name in function call output", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "function_call",
|
||||||
|
id: "call_123",
|
||||||
|
status: "completed",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
input: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||||
|
|
||||||
|
// When name is missing, it falls back to JSON.stringify of the entire output
|
||||||
|
const functionCallElements = screen.getAllByText(/function_call/);
|
||||||
|
expect(functionCallElements.length).toBeGreaterThan(0);
|
||||||
|
expect(screen.getByText(/call_123/)).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
171
llama_stack/ui/components/responses/responses-detail.tsx
Normal file
171
llama_stack/ui/components/responses/responses-detail.tsx
Normal file
|
@ -0,0 +1,171 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { OpenAIResponse, InputItemListResponse } from "@/lib/types";
|
||||||
|
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||||
|
import { Skeleton } from "@/components/ui/skeleton";
|
||||||
|
import {
|
||||||
|
DetailLoadingView,
|
||||||
|
DetailErrorView,
|
||||||
|
DetailNotFoundView,
|
||||||
|
DetailLayout,
|
||||||
|
PropertiesCard,
|
||||||
|
PropertyItem,
|
||||||
|
} from "@/components/layout/detail-layout";
|
||||||
|
import { GroupedItemsDisplay } from "./grouping/grouped-items-display";
|
||||||
|
|
||||||
|
interface ResponseDetailViewProps {
|
||||||
|
response: OpenAIResponse | null;
|
||||||
|
inputItems: InputItemListResponse | null;
|
||||||
|
isLoading: boolean;
|
||||||
|
isLoadingInputItems: boolean;
|
||||||
|
error: Error | null;
|
||||||
|
inputItemsError: Error | null;
|
||||||
|
id: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ResponseDetailView({
|
||||||
|
response,
|
||||||
|
inputItems,
|
||||||
|
isLoading,
|
||||||
|
isLoadingInputItems,
|
||||||
|
error,
|
||||||
|
inputItemsError,
|
||||||
|
id,
|
||||||
|
}: ResponseDetailViewProps) {
|
||||||
|
const title = "Responses Details";
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
return <DetailErrorView title={title} id={id} error={error} />;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isLoading) {
|
||||||
|
return <DetailLoadingView title={title} />;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!response) {
|
||||||
|
return <DetailNotFoundView title={title} id={id} />;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main content cards
|
||||||
|
const mainContent = (
|
||||||
|
<>
|
||||||
|
<Card>
|
||||||
|
<CardHeader>
|
||||||
|
<CardTitle>Input</CardTitle>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent>
|
||||||
|
{/* Show loading state for input items */}
|
||||||
|
{isLoadingInputItems ? (
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Skeleton className="h-4 w-full" />
|
||||||
|
<Skeleton className="h-4 w-3/4" />
|
||||||
|
<Skeleton className="h-4 w-1/2" />
|
||||||
|
</div>
|
||||||
|
) : inputItemsError ? (
|
||||||
|
<div className="text-red-500 text-sm">
|
||||||
|
Error loading input items: {inputItemsError.message}
|
||||||
|
<br />
|
||||||
|
<span className="text-gray-500 text-xs">
|
||||||
|
Falling back to response input data.
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
) : null}
|
||||||
|
|
||||||
|
{/* Display input items if available, otherwise fall back to response.input */}
|
||||||
|
{(() => {
|
||||||
|
const dataToDisplay =
|
||||||
|
inputItems?.data && inputItems.data.length > 0
|
||||||
|
? inputItems.data
|
||||||
|
: response.input;
|
||||||
|
|
||||||
|
if (dataToDisplay && dataToDisplay.length > 0) {
|
||||||
|
return (
|
||||||
|
<GroupedItemsDisplay
|
||||||
|
items={dataToDisplay}
|
||||||
|
keyPrefix="input"
|
||||||
|
defaultRole="unknown"
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
return (
|
||||||
|
<p className="text-gray-500 italic text-sm">
|
||||||
|
No input data available.
|
||||||
|
</p>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
})()}
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
|
||||||
|
<Card>
|
||||||
|
<CardHeader>
|
||||||
|
<CardTitle>Output</CardTitle>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent>
|
||||||
|
{response.output?.length > 0 ? (
|
||||||
|
<GroupedItemsDisplay
|
||||||
|
items={response.output}
|
||||||
|
keyPrefix="output"
|
||||||
|
defaultRole="assistant"
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<p className="text-gray-500 italic text-sm">
|
||||||
|
No output data available.
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
|
||||||
|
// Properties sidebar
|
||||||
|
const sidebar = (
|
||||||
|
<PropertiesCard>
|
||||||
|
<PropertyItem
|
||||||
|
label="Created"
|
||||||
|
value={new Date(response.created_at * 1000).toLocaleString()}
|
||||||
|
/>
|
||||||
|
<PropertyItem label="ID" value={response.id} />
|
||||||
|
<PropertyItem label="Model" value={response.model} />
|
||||||
|
<PropertyItem label="Status" value={response.status} hasBorder />
|
||||||
|
{response.temperature && (
|
||||||
|
<PropertyItem
|
||||||
|
label="Temperature"
|
||||||
|
value={response.temperature}
|
||||||
|
hasBorder
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{response.top_p && <PropertyItem label="Top P" value={response.top_p} />}
|
||||||
|
{response.parallel_tool_calls && (
|
||||||
|
<PropertyItem
|
||||||
|
label="Parallel Tool Calls"
|
||||||
|
value={response.parallel_tool_calls ? "Yes" : "No"}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{response.previous_response_id && (
|
||||||
|
<PropertyItem
|
||||||
|
label="Previous Response ID"
|
||||||
|
value={
|
||||||
|
<span className="text-xs">{response.previous_response_id}</span>
|
||||||
|
}
|
||||||
|
hasBorder
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{response.error && (
|
||||||
|
<PropertyItem
|
||||||
|
label="Error"
|
||||||
|
value={
|
||||||
|
<span className="text-red-900 font-medium">
|
||||||
|
{response.error.code}: {response.error.message}
|
||||||
|
</span>
|
||||||
|
}
|
||||||
|
className="pt-1 mt-1 border-t border-red-200"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</PropertiesCard>
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<DetailLayout title={title} mainContent={mainContent} sidebar={sidebar} />
|
||||||
|
);
|
||||||
|
}
|
537
llama_stack/ui/components/responses/responses-table.test.tsx
Normal file
537
llama_stack/ui/components/responses/responses-table.test.tsx
Normal file
|
@ -0,0 +1,537 @@
|
||||||
|
import React from "react";
|
||||||
|
import { render, screen, fireEvent } from "@testing-library/react";
|
||||||
|
import "@testing-library/jest-dom";
|
||||||
|
import { ResponsesTable } from "./responses-table";
|
||||||
|
import { OpenAIResponse } from "@/lib/types";
|
||||||
|
|
||||||
|
// Mock next/navigation
|
||||||
|
const mockPush = jest.fn();
|
||||||
|
jest.mock("next/navigation", () => ({
|
||||||
|
useRouter: () => ({
|
||||||
|
push: mockPush,
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Mock helper functions
|
||||||
|
jest.mock("@/lib/truncate-text");
|
||||||
|
|
||||||
|
// Import the mocked functions
|
||||||
|
import { truncateText as originalTruncateText } from "@/lib/truncate-text";
|
||||||
|
|
||||||
|
// Cast to jest.Mock for typings
|
||||||
|
const truncateText = originalTruncateText as jest.Mock;
|
||||||
|
|
||||||
|
describe("ResponsesTable", () => {
|
||||||
|
const defaultProps = {
|
||||||
|
data: [] as OpenAIResponse[],
|
||||||
|
isLoading: false,
|
||||||
|
error: null,
|
||||||
|
};
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
// Reset all mocks before each test
|
||||||
|
mockPush.mockClear();
|
||||||
|
truncateText.mockClear();
|
||||||
|
|
||||||
|
// Default pass-through implementation
|
||||||
|
truncateText.mockImplementation((text: string | undefined) => text);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders without crashing with default props", () => {
|
||||||
|
render(<ResponsesTable {...defaultProps} />);
|
||||||
|
expect(screen.getByText("No responses found.")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("click on a row navigates to the correct URL", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_123",
|
||||||
|
object: "response",
|
||||||
|
created_at: Math.floor(Date.now() / 1000),
|
||||||
|
model: "llama-test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "assistant",
|
||||||
|
content: "Test output",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
input: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "user",
|
||||||
|
content: "Test input",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<ResponsesTable {...defaultProps} data={[mockResponse]} />);
|
||||||
|
|
||||||
|
const row = screen.getByText("Test input").closest("tr");
|
||||||
|
if (row) {
|
||||||
|
fireEvent.click(row);
|
||||||
|
expect(mockPush).toHaveBeenCalledWith("/logs/responses/resp_123");
|
||||||
|
} else {
|
||||||
|
throw new Error('Row with "Test input" not found for router mock test.');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Loading State", () => {
|
||||||
|
test("renders skeleton UI when isLoading is true", () => {
|
||||||
|
const { container } = render(
|
||||||
|
<ResponsesTable {...defaultProps} isLoading={true} />,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Check for skeleton in the table caption
|
||||||
|
const tableCaption = container.querySelector("caption");
|
||||||
|
expect(tableCaption).toBeInTheDocument();
|
||||||
|
if (tableCaption) {
|
||||||
|
const captionSkeleton = tableCaption.querySelector(
|
||||||
|
'[data-slot="skeleton"]',
|
||||||
|
);
|
||||||
|
expect(captionSkeleton).toBeInTheDocument();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for skeletons in the table body cells
|
||||||
|
const tableBody = container.querySelector("tbody");
|
||||||
|
expect(tableBody).toBeInTheDocument();
|
||||||
|
if (tableBody) {
|
||||||
|
const bodySkeletons = tableBody.querySelectorAll(
|
||||||
|
'[data-slot="skeleton"]',
|
||||||
|
);
|
||||||
|
expect(bodySkeletons.length).toBeGreaterThan(0);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Error State", () => {
|
||||||
|
test("renders error message when error prop is provided", () => {
|
||||||
|
const errorMessage = "Network Error";
|
||||||
|
render(
|
||||||
|
<ResponsesTable
|
||||||
|
{...defaultProps}
|
||||||
|
error={{ name: "Error", message: errorMessage }}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
expect(
|
||||||
|
screen.getByText(`Error fetching data: ${errorMessage}`),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders default error message when error.message is not available", () => {
|
||||||
|
render(
|
||||||
|
<ResponsesTable
|
||||||
|
{...defaultProps}
|
||||||
|
error={{ name: "Error", message: "" }}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
expect(
|
||||||
|
screen.getByText("Error fetching data: An unknown error occurred"),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders default error message when error prop is an object without message", () => {
|
||||||
|
render(<ResponsesTable {...defaultProps} error={{} as Error} />);
|
||||||
|
expect(
|
||||||
|
screen.getByText("Error fetching data: An unknown error occurred"),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Empty State", () => {
|
||||||
|
test('renders "No responses found." and no table when data array is empty', () => {
|
||||||
|
render(<ResponsesTable data={[]} isLoading={false} error={null} />);
|
||||||
|
expect(screen.getByText("No responses found.")).toBeInTheDocument();
|
||||||
|
|
||||||
|
// Ensure that the table structure is NOT rendered in the empty state
|
||||||
|
const table = screen.queryByRole("table");
|
||||||
|
expect(table).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Data Rendering", () => {
|
||||||
|
test("renders table caption, headers, and response data correctly", () => {
|
||||||
|
const mockResponses = [
|
||||||
|
{
|
||||||
|
id: "resp_1",
|
||||||
|
object: "response" as const,
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "llama-test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "message" as const,
|
||||||
|
role: "assistant" as const,
|
||||||
|
content: "Test output",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
input: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "user",
|
||||||
|
content: "Test input",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "resp_2",
|
||||||
|
object: "response" as const,
|
||||||
|
created_at: 1710001000,
|
||||||
|
model: "llama-another-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "message" as const,
|
||||||
|
role: "assistant" as const,
|
||||||
|
content: "Another output",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
input: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "user",
|
||||||
|
content: "Another input",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
render(
|
||||||
|
<ResponsesTable data={mockResponses} isLoading={false} error={null} />,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Table caption
|
||||||
|
expect(
|
||||||
|
screen.getByText("A list of your recent responses."),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
|
||||||
|
// Table headers
|
||||||
|
expect(screen.getByText("Input")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Output")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Model")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Created")).toBeInTheDocument();
|
||||||
|
|
||||||
|
// Data rows
|
||||||
|
expect(screen.getByText("Test input")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Test output")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("llama-test-model")).toBeInTheDocument();
|
||||||
|
expect(
|
||||||
|
screen.getByText(new Date(1710000000 * 1000).toLocaleString()),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
|
||||||
|
expect(screen.getByText("Another input")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Another output")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("llama-another-model")).toBeInTheDocument();
|
||||||
|
expect(
|
||||||
|
screen.getByText(new Date(1710001000 * 1000).toLocaleString()),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Input Text Extraction", () => {
|
||||||
|
test("extracts text from string content", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_string",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [{ type: "message", role: "assistant", content: "output" }],
|
||||||
|
input: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "user",
|
||||||
|
content: "Simple string input",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(
|
||||||
|
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
|
||||||
|
);
|
||||||
|
expect(screen.getByText("Simple string input")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("extracts text from array content with input_text type", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_array",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [{ type: "message", role: "assistant", content: "output" }],
|
||||||
|
input: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "user",
|
||||||
|
content: [
|
||||||
|
{ type: "input_text", text: "Array input text" },
|
||||||
|
{ type: "input_text", text: "Should not be used" },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(
|
||||||
|
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
|
||||||
|
);
|
||||||
|
expect(screen.getByText("Array input text")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("returns empty string when no message input found", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_no_input",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [{ type: "message", role: "assistant", content: "output" }],
|
||||||
|
input: [
|
||||||
|
{
|
||||||
|
type: "other_type",
|
||||||
|
content: "Not a message",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
const { container } = render(
|
||||||
|
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Find the input cell (first cell in the data row) and verify it's empty
|
||||||
|
const inputCell = container.querySelector("tbody tr td:first-child");
|
||||||
|
expect(inputCell).toBeInTheDocument();
|
||||||
|
expect(inputCell).toHaveTextContent("");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Output Text Extraction", () => {
|
||||||
|
test("extracts text from string message content", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_string_output",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "assistant",
|
||||||
|
content: "Simple string output",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
input: [{ type: "message", content: "input" }],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(
|
||||||
|
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
|
||||||
|
);
|
||||||
|
expect(screen.getByText("Simple string output")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("extracts text from array message content with output_text type", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_array_output",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "assistant",
|
||||||
|
content: [
|
||||||
|
{ type: "output_text", text: "Array output text" },
|
||||||
|
{ type: "output_text", text: "Should not be used" },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
input: [{ type: "message", content: "input" }],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(
|
||||||
|
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
|
||||||
|
);
|
||||||
|
expect(screen.getByText("Array output text")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("formats function call output", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_function_call",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "function_call",
|
||||||
|
id: "call_123",
|
||||||
|
status: "completed",
|
||||||
|
name: "search_function",
|
||||||
|
arguments: '{"query": "test"}',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
input: [{ type: "message", content: "input" }],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(
|
||||||
|
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
|
||||||
|
);
|
||||||
|
expect(
|
||||||
|
screen.getByText('search_function({"query": "test"})'),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("formats function call output without arguments", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_function_no_args",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "function_call",
|
||||||
|
id: "call_123",
|
||||||
|
status: "completed",
|
||||||
|
name: "simple_function",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
input: [{ type: "message", content: "input" }],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(
|
||||||
|
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
|
||||||
|
);
|
||||||
|
expect(screen.getByText("simple_function({})")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("formats web search call output", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_web_search",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "web_search_call",
|
||||||
|
id: "search_123",
|
||||||
|
status: "completed",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
input: [{ type: "message", content: "input" }],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(
|
||||||
|
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
|
||||||
|
);
|
||||||
|
expect(
|
||||||
|
screen.getByText("web_search_call(status: completed)"),
|
||||||
|
).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("falls back to JSON.stringify for unknown tool call types", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_unknown_tool",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "unknown_call",
|
||||||
|
id: "unknown_123",
|
||||||
|
status: "completed",
|
||||||
|
custom_field: "custom_value",
|
||||||
|
} as any,
|
||||||
|
],
|
||||||
|
input: [{ type: "message", content: "input" }],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(
|
||||||
|
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
|
||||||
|
);
|
||||||
|
// Should contain the JSON stringified version
|
||||||
|
expect(screen.getByText(/unknown_call/)).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("falls back to JSON.stringify for entire output when no message or tool call found", () => {
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_fallback",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710000000,
|
||||||
|
model: "test-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "unknown_type",
|
||||||
|
data: "some data",
|
||||||
|
} as any,
|
||||||
|
],
|
||||||
|
input: [{ type: "message", content: "input" }],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(
|
||||||
|
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
|
||||||
|
);
|
||||||
|
// Should contain the JSON stringified version of the output array
|
||||||
|
expect(screen.getByText(/unknown_type/)).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Text Truncation", () => {
|
||||||
|
test("truncates long input and output text", () => {
|
||||||
|
// Specific mock implementation for this test
|
||||||
|
truncateText.mockImplementation(
|
||||||
|
(text: string | undefined, maxLength?: number) => {
|
||||||
|
const defaultTestMaxLength = 10;
|
||||||
|
const effectiveMaxLength = maxLength ?? defaultTestMaxLength;
|
||||||
|
return typeof text === "string" && text.length > effectiveMaxLength
|
||||||
|
? text.slice(0, effectiveMaxLength) + "..."
|
||||||
|
: text;
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
const longInput =
|
||||||
|
"This is a very long input message that should be truncated.";
|
||||||
|
const longOutput =
|
||||||
|
"This is a very long output message that should also be truncated.";
|
||||||
|
|
||||||
|
const mockResponse: OpenAIResponse = {
|
||||||
|
id: "resp_trunc",
|
||||||
|
object: "response",
|
||||||
|
created_at: 1710002000,
|
||||||
|
model: "llama-trunc-model",
|
||||||
|
status: "completed",
|
||||||
|
output: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "assistant",
|
||||||
|
content: longOutput,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
input: [
|
||||||
|
{
|
||||||
|
type: "message",
|
||||||
|
role: "user",
|
||||||
|
content: longInput,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
render(
|
||||||
|
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
|
||||||
|
);
|
||||||
|
|
||||||
|
// The truncated text should be present for both input and output
|
||||||
|
const truncatedTexts = screen.getAllByText(
|
||||||
|
longInput.slice(0, 10) + "...",
|
||||||
|
);
|
||||||
|
expect(truncatedTexts.length).toBe(2); // one for input, one for output
|
||||||
|
truncatedTexts.forEach((textElement) =>
|
||||||
|
expect(textElement).toBeInTheDocument(),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
117
llama_stack/ui/components/responses/responses-table.tsx
Normal file
117
llama_stack/ui/components/responses/responses-table.tsx
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import {
|
||||||
|
OpenAIResponse,
|
||||||
|
ResponseInput,
|
||||||
|
ResponseInputMessageContent,
|
||||||
|
} from "@/lib/types";
|
||||||
|
import { LogsTable, LogTableRow } from "@/components/logs/logs-table";
|
||||||
|
import {
|
||||||
|
isMessageInput,
|
||||||
|
isMessageItem,
|
||||||
|
isFunctionCallItem,
|
||||||
|
isWebSearchCallItem,
|
||||||
|
MessageItem,
|
||||||
|
FunctionCallItem,
|
||||||
|
WebSearchCallItem,
|
||||||
|
} from "./utils/item-types";
|
||||||
|
|
||||||
|
interface ResponsesTableProps {
|
||||||
|
data: OpenAIResponse[];
|
||||||
|
isLoading: boolean;
|
||||||
|
error: Error | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
function getInputText(response: OpenAIResponse): string {
|
||||||
|
const firstInput = response.input.find(isMessageInput);
|
||||||
|
if (firstInput) {
|
||||||
|
return extractContentFromItem(firstInput);
|
||||||
|
}
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
function getOutputText(response: OpenAIResponse): string {
|
||||||
|
const firstMessage = response.output.find((item) =>
|
||||||
|
isMessageItem(item as any),
|
||||||
|
);
|
||||||
|
if (firstMessage) {
|
||||||
|
const content = extractContentFromItem(firstMessage as MessageItem);
|
||||||
|
if (content) {
|
||||||
|
return content;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const functionCall = response.output.find((item) =>
|
||||||
|
isFunctionCallItem(item as any),
|
||||||
|
);
|
||||||
|
if (functionCall) {
|
||||||
|
return formatFunctionCall(functionCall as FunctionCallItem);
|
||||||
|
}
|
||||||
|
|
||||||
|
const webSearchCall = response.output.find((item) =>
|
||||||
|
isWebSearchCallItem(item as any),
|
||||||
|
);
|
||||||
|
if (webSearchCall) {
|
||||||
|
return formatWebSearchCall(webSearchCall as WebSearchCallItem);
|
||||||
|
}
|
||||||
|
|
||||||
|
return JSON.stringify(response.output);
|
||||||
|
}
|
||||||
|
|
||||||
|
function extractContentFromItem(item: {
|
||||||
|
content?: string | ResponseInputMessageContent[];
|
||||||
|
}): string {
|
||||||
|
if (!item.content) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (typeof item.content === "string") {
|
||||||
|
return item.content;
|
||||||
|
} else if (Array.isArray(item.content)) {
|
||||||
|
const textContent = item.content.find(
|
||||||
|
(c: ResponseInputMessageContent) =>
|
||||||
|
c.type === "input_text" || c.type === "output_text",
|
||||||
|
);
|
||||||
|
return textContent?.text || "";
|
||||||
|
}
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatFunctionCall(functionCall: FunctionCallItem): string {
|
||||||
|
const args = functionCall.arguments || "{}";
|
||||||
|
const name = functionCall.name || "unknown";
|
||||||
|
return `${name}(${args})`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatWebSearchCall(webSearchCall: WebSearchCallItem): string {
|
||||||
|
return `web_search_call(status: ${webSearchCall.status})`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatResponseToRow(response: OpenAIResponse): LogTableRow {
|
||||||
|
return {
|
||||||
|
id: response.id,
|
||||||
|
input: getInputText(response),
|
||||||
|
output: getOutputText(response),
|
||||||
|
model: response.model,
|
||||||
|
createdTime: new Date(response.created_at * 1000).toLocaleString(),
|
||||||
|
detailPath: `/logs/responses/${response.id}`,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ResponsesTable({
|
||||||
|
data,
|
||||||
|
isLoading,
|
||||||
|
error,
|
||||||
|
}: ResponsesTableProps) {
|
||||||
|
const formattedData = data.map(formatResponseToRow);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<LogsTable
|
||||||
|
data={formattedData}
|
||||||
|
isLoading={isLoading}
|
||||||
|
error={error}
|
||||||
|
caption="A list of your recent responses."
|
||||||
|
emptyMessage="No responses found."
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
61
llama_stack/ui/components/responses/utils/item-types.ts
Normal file
61
llama_stack/ui/components/responses/utils/item-types.ts
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
/**
|
||||||
|
* Type guards for different item types in responses
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type {
|
||||||
|
ResponseInput,
|
||||||
|
ResponseOutput,
|
||||||
|
ResponseMessage,
|
||||||
|
ResponseToolCall,
|
||||||
|
} from "@/lib/types";
|
||||||
|
|
||||||
|
export interface BaseItem {
|
||||||
|
type: string;
|
||||||
|
[key: string]: unknown;
|
||||||
|
}
|
||||||
|
|
||||||
|
export type MessageItem = ResponseMessage;
|
||||||
|
export type FunctionCallItem = ResponseToolCall & { type: "function_call" };
|
||||||
|
export type WebSearchCallItem = ResponseToolCall & { type: "web_search_call" };
|
||||||
|
export type FunctionCallOutputItem = BaseItem & {
|
||||||
|
type: "function_call_output";
|
||||||
|
call_id: string;
|
||||||
|
output?: string | object;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type AnyResponseItem =
|
||||||
|
| ResponseInput
|
||||||
|
| ResponseOutput
|
||||||
|
| FunctionCallOutputItem;
|
||||||
|
|
||||||
|
export function isMessageInput(
|
||||||
|
item: ResponseInput,
|
||||||
|
): item is ResponseInput & { type: "message" } {
|
||||||
|
return item.type === "message";
|
||||||
|
}
|
||||||
|
|
||||||
|
export function isMessageItem(item: AnyResponseItem): item is MessageItem {
|
||||||
|
return item.type === "message" && "content" in item;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function isFunctionCallItem(
|
||||||
|
item: AnyResponseItem,
|
||||||
|
): item is FunctionCallItem {
|
||||||
|
return item.type === "function_call" && "name" in item;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function isWebSearchCallItem(
|
||||||
|
item: AnyResponseItem,
|
||||||
|
): item is WebSearchCallItem {
|
||||||
|
return item.type === "web_search_call";
|
||||||
|
}
|
||||||
|
|
||||||
|
export function isFunctionCallOutputItem(
|
||||||
|
item: AnyResponseItem,
|
||||||
|
): item is FunctionCallOutputItem {
|
||||||
|
return (
|
||||||
|
item.type === "function_call_output" &&
|
||||||
|
"call_id" in item &&
|
||||||
|
typeof (item as any).call_id === "string"
|
||||||
|
);
|
||||||
|
}
|
49
llama_stack/ui/components/ui/message-components.tsx
Normal file
49
llama_stack/ui/components/ui/message-components.tsx
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
import React from "react";
|
||||||
|
|
||||||
|
export interface MessageBlockProps {
|
||||||
|
label: string;
|
||||||
|
labelDetail?: string;
|
||||||
|
content: React.ReactNode;
|
||||||
|
className?: string;
|
||||||
|
contentClassName?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const MessageBlock: React.FC<MessageBlockProps> = ({
|
||||||
|
label,
|
||||||
|
labelDetail,
|
||||||
|
content,
|
||||||
|
className = "",
|
||||||
|
contentClassName = "",
|
||||||
|
}) => {
|
||||||
|
return (
|
||||||
|
<div className={`mb-4 ${className}`}>
|
||||||
|
<p className="py-1 font-semibold text-gray-800 mb-1">
|
||||||
|
{label}
|
||||||
|
{labelDetail && (
|
||||||
|
<span className="text-xs text-gray-500 font-normal ml-1">
|
||||||
|
{labelDetail}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</p>
|
||||||
|
<div className={`py-1 whitespace-pre-wrap ${contentClassName}`}>
|
||||||
|
{content}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export interface ToolCallBlockProps {
|
||||||
|
children: React.ReactNode;
|
||||||
|
className?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const ToolCallBlock = ({ children, className }: ToolCallBlockProps) => {
|
||||||
|
const baseClassName =
|
||||||
|
"p-3 bg-slate-50 border border-slate-200 rounded-md text-sm";
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={`${baseClassName} ${className || ""}`}>
|
||||||
|
<pre className="whitespace-pre-wrap text-xs">{children}</pre>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
12
llama_stack/ui/lib/client.ts
Normal file
12
llama_stack/ui/lib/client.ts
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
import LlamaStackClient from "llama-stack-client";
|
||||||
|
import OpenAI from "openai";
|
||||||
|
|
||||||
|
export const client =
|
||||||
|
process.env.NEXT_PUBLIC_USE_OPENAI_CLIENT === "true" // useful for testing
|
||||||
|
? new OpenAI({
|
||||||
|
apiKey: process.env.NEXT_PUBLIC_OPENAI_API_KEY,
|
||||||
|
dangerouslyAllowBrowser: true,
|
||||||
|
})
|
||||||
|
: new LlamaStackClient({
|
||||||
|
baseURL: process.env.NEXT_PUBLIC_LLAMA_STACK_BASE_URL,
|
||||||
|
});
|
|
@ -43,10 +43,14 @@ export function extractDisplayableText(
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
let textPart = extractTextFromContentPart(message.content);
|
const textPart = extractTextFromContentPart(message.content);
|
||||||
let toolCallPart = "";
|
let toolCallPart = "";
|
||||||
|
|
||||||
if (message.tool_calls && message.tool_calls.length > 0) {
|
if (
|
||||||
|
message.tool_calls &&
|
||||||
|
Array.isArray(message.tool_calls) &&
|
||||||
|
message.tool_calls.length > 0
|
||||||
|
) {
|
||||||
// For summary, usually the first tool call is sufficient
|
// For summary, usually the first tool call is sufficient
|
||||||
toolCallPart = formatToolCallToString(message.tool_calls[0]);
|
toolCallPart = formatToolCallToString(message.tool_calls[0]);
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,20 +18,20 @@ export interface ImageUrlContentBlock {
|
||||||
export type ChatMessageContentPart =
|
export type ChatMessageContentPart =
|
||||||
| TextContentBlock
|
| TextContentBlock
|
||||||
| ImageUrlContentBlock
|
| ImageUrlContentBlock
|
||||||
| { type: string; [key: string]: any }; // Fallback for other potential types
|
| { type: string; [key: string]: unknown }; // Fallback for other potential types
|
||||||
|
|
||||||
export interface ChatMessage {
|
export interface ChatMessage {
|
||||||
role: string;
|
role: string;
|
||||||
content: string | ChatMessageContentPart[]; // Updated content type
|
content: string | ChatMessageContentPart[]; // Updated content type
|
||||||
name?: string | null;
|
name?: string | null;
|
||||||
tool_calls?: any | null; // This could also be refined to a more specific ToolCall[] type
|
tool_calls?: unknown | null; // This could also be refined to a more specific ToolCall[] type
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface Choice {
|
export interface Choice {
|
||||||
message: ChatMessage;
|
message: ChatMessage;
|
||||||
finish_reason: string;
|
finish_reason: string;
|
||||||
index: number;
|
index: number;
|
||||||
logprobs?: any | null;
|
logprobs?: unknown | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ChatCompletion {
|
export interface ChatCompletion {
|
||||||
|
@ -42,3 +42,62 @@ export interface ChatCompletion {
|
||||||
model: string;
|
model: string;
|
||||||
input_messages: ChatMessage[];
|
input_messages: ChatMessage[];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Response types for OpenAI Responses API
|
||||||
|
export interface ResponseInputMessageContent {
|
||||||
|
text?: string;
|
||||||
|
type: "input_text" | "input_image" | "output_text";
|
||||||
|
image_url?: string;
|
||||||
|
detail?: "low" | "high" | "auto";
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ResponseMessage {
|
||||||
|
content: string | ResponseInputMessageContent[];
|
||||||
|
role: "system" | "developer" | "user" | "assistant";
|
||||||
|
type: "message";
|
||||||
|
id?: string;
|
||||||
|
status?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ResponseToolCall {
|
||||||
|
id: string;
|
||||||
|
status: string;
|
||||||
|
type: "web_search_call" | "function_call";
|
||||||
|
arguments?: string;
|
||||||
|
call_id?: string;
|
||||||
|
name?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export type ResponseOutput = ResponseMessage | ResponseToolCall;
|
||||||
|
|
||||||
|
export interface ResponseInput {
|
||||||
|
type: string;
|
||||||
|
content?: string | ResponseInputMessageContent[];
|
||||||
|
role?: string;
|
||||||
|
[key: string]: unknown; // Flexible for various input types
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface OpenAIResponse {
|
||||||
|
id: string;
|
||||||
|
created_at: number;
|
||||||
|
model: string;
|
||||||
|
object: "response";
|
||||||
|
status: string;
|
||||||
|
output: ResponseOutput[];
|
||||||
|
input: ResponseInput[];
|
||||||
|
error?: {
|
||||||
|
code: string;
|
||||||
|
message: string;
|
||||||
|
};
|
||||||
|
parallel_tool_calls?: boolean;
|
||||||
|
previous_response_id?: string;
|
||||||
|
temperature?: number;
|
||||||
|
top_p?: number;
|
||||||
|
truncation?: string;
|
||||||
|
user?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface InputItemListResponse {
|
||||||
|
data: ResponseInput[];
|
||||||
|
object: "list";
|
||||||
|
}
|
||||||
|
|
52
llama_stack/ui/package-lock.json
generated
52
llama_stack/ui/package-lock.json
generated
|
@ -19,6 +19,7 @@
|
||||||
"lucide-react": "^0.510.0",
|
"lucide-react": "^0.510.0",
|
||||||
"next": "15.3.2",
|
"next": "15.3.2",
|
||||||
"next-themes": "^0.4.6",
|
"next-themes": "^0.4.6",
|
||||||
|
"openai": "^4.103.0",
|
||||||
"react": "^19.0.0",
|
"react": "^19.0.0",
|
||||||
"react-dom": "^19.0.0",
|
"react-dom": "^19.0.0",
|
||||||
"tailwind-merge": "^3.3.0"
|
"tailwind-merge": "^3.3.0"
|
||||||
|
@ -9092,7 +9093,7 @@
|
||||||
},
|
},
|
||||||
"node_modules/llama-stack-client": {
|
"node_modules/llama-stack-client": {
|
||||||
"version": "0.0.1-alpha.0",
|
"version": "0.0.1-alpha.0",
|
||||||
"resolved": "git+ssh://git@github.com/stainless-sdks/llama-stack-node.git#efa814980d44b3b2c92944377a086915137b2134",
|
"resolved": "git+ssh://git@github.com/stainless-sdks/llama-stack-node.git#5d34d229fb53b6dad02da0f19f4b310b529c6b15",
|
||||||
"license": "Apache-2.0",
|
"license": "Apache-2.0",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@types/node": "^18.11.18",
|
"@types/node": "^18.11.18",
|
||||||
|
@ -9804,6 +9805,51 @@
|
||||||
"url": "https://github.com/sponsors/sindresorhus"
|
"url": "https://github.com/sponsors/sindresorhus"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/openai": {
|
||||||
|
"version": "4.103.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/openai/-/openai-4.103.0.tgz",
|
||||||
|
"integrity": "sha512-eWcz9kdurkGOFDtd5ySS5y251H2uBgq9+1a2lTBnjMMzlexJ40Am5t6Mu76SSE87VvitPa0dkIAp75F+dZVC0g==",
|
||||||
|
"license": "Apache-2.0",
|
||||||
|
"dependencies": {
|
||||||
|
"@types/node": "^18.11.18",
|
||||||
|
"@types/node-fetch": "^2.6.4",
|
||||||
|
"abort-controller": "^3.0.0",
|
||||||
|
"agentkeepalive": "^4.2.1",
|
||||||
|
"form-data-encoder": "1.7.2",
|
||||||
|
"formdata-node": "^4.3.2",
|
||||||
|
"node-fetch": "^2.6.7"
|
||||||
|
},
|
||||||
|
"bin": {
|
||||||
|
"openai": "bin/cli"
|
||||||
|
},
|
||||||
|
"peerDependencies": {
|
||||||
|
"ws": "^8.18.0",
|
||||||
|
"zod": "^3.23.8"
|
||||||
|
},
|
||||||
|
"peerDependenciesMeta": {
|
||||||
|
"ws": {
|
||||||
|
"optional": true
|
||||||
|
},
|
||||||
|
"zod": {
|
||||||
|
"optional": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/openai/node_modules/@types/node": {
|
||||||
|
"version": "18.19.103",
|
||||||
|
"resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.103.tgz",
|
||||||
|
"integrity": "sha512-hHTHp+sEz6SxFsp+SA+Tqrua3AbmlAw+Y//aEwdHrdZkYVRWdvWD3y5uPZ0flYOkgskaFWqZ/YGFm3FaFQ0pRw==",
|
||||||
|
"license": "MIT",
|
||||||
|
"dependencies": {
|
||||||
|
"undici-types": "~5.26.4"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/openai/node_modules/undici-types": {
|
||||||
|
"version": "5.26.5",
|
||||||
|
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
|
||||||
|
"integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
"node_modules/optionator": {
|
"node_modules/optionator": {
|
||||||
"version": "0.9.4",
|
"version": "0.9.4",
|
||||||
"resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz",
|
"resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz",
|
||||||
|
@ -12223,7 +12269,7 @@
|
||||||
"version": "8.18.2",
|
"version": "8.18.2",
|
||||||
"resolved": "https://registry.npmjs.org/ws/-/ws-8.18.2.tgz",
|
"resolved": "https://registry.npmjs.org/ws/-/ws-8.18.2.tgz",
|
||||||
"integrity": "sha512-DMricUmwGZUVr++AEAe2uiVM7UoO9MAVZMDu05UQOaUII0lp+zOzLLU4Xqh/JvTqklB1T4uELaaPBKyjE1r4fQ==",
|
"integrity": "sha512-DMricUmwGZUVr++AEAe2uiVM7UoO9MAVZMDu05UQOaUII0lp+zOzLLU4Xqh/JvTqklB1T4uELaaPBKyjE1r4fQ==",
|
||||||
"dev": true,
|
"devOptional": true,
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"engines": {
|
"engines": {
|
||||||
"node": ">=10.0.0"
|
"node": ">=10.0.0"
|
||||||
|
@ -12334,7 +12380,7 @@
|
||||||
"version": "3.24.4",
|
"version": "3.24.4",
|
||||||
"resolved": "https://registry.npmjs.org/zod/-/zod-3.24.4.tgz",
|
"resolved": "https://registry.npmjs.org/zod/-/zod-3.24.4.tgz",
|
||||||
"integrity": "sha512-OdqJE9UDRPwWsrHjLN2F8bPxvwJBK22EHLWtanu0LSYr5YqzsaaW3RMgmjwr8Rypg5k+meEJdSPXJZXE/yqOMg==",
|
"integrity": "sha512-OdqJE9UDRPwWsrHjLN2F8bPxvwJBK22EHLWtanu0LSYr5YqzsaaW3RMgmjwr8Rypg5k+meEJdSPXJZXE/yqOMg==",
|
||||||
"dev": true,
|
"devOptional": true,
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"funding": {
|
"funding": {
|
||||||
"url": "https://github.com/sponsors/colinhacks"
|
"url": "https://github.com/sponsors/colinhacks"
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
"@radix-ui/react-tooltip": "^1.2.6",
|
"@radix-ui/react-tooltip": "^1.2.6",
|
||||||
"class-variance-authority": "^0.7.1",
|
"class-variance-authority": "^0.7.1",
|
||||||
"clsx": "^2.1.1",
|
"clsx": "^2.1.1",
|
||||||
"llama-stack-client": "github:stainless-sdks/llama-stack-node#ehhuang/dev",
|
"llama-stack-client": "0.2.8",
|
||||||
"lucide-react": "^0.510.0",
|
"lucide-react": "^0.510.0",
|
||||||
"next": "15.3.2",
|
"next": "15.3.2",
|
||||||
"next-themes": "^0.4.6",
|
"next-themes": "^0.4.6",
|
||||||
|
|
|
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "llama_stack"
|
name = "llama_stack"
|
||||||
version = "0.2.7"
|
version = "0.2.8"
|
||||||
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
|
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
|
||||||
description = "Llama Stack"
|
description = "Llama Stack"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
@ -21,13 +21,13 @@ classifiers = [
|
||||||
"Topic :: Scientific/Engineering :: Information Analysis",
|
"Topic :: Scientific/Engineering :: Information Analysis",
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"blobfile",
|
"aiohttp",
|
||||||
"fire",
|
"fire",
|
||||||
"httpx",
|
"httpx",
|
||||||
"huggingface-hub",
|
"huggingface-hub",
|
||||||
"jinja2>=3.1.6",
|
"jinja2>=3.1.6",
|
||||||
"jsonschema",
|
"jsonschema",
|
||||||
"llama-stack-client>=0.2.7",
|
"llama-stack-client>=0.2.8",
|
||||||
"openai>=1.66",
|
"openai>=1.66",
|
||||||
"prompt-toolkit",
|
"prompt-toolkit",
|
||||||
"python-dotenv",
|
"python-dotenv",
|
||||||
|
@ -36,6 +36,7 @@ dependencies = [
|
||||||
"requests",
|
"requests",
|
||||||
"rich",
|
"rich",
|
||||||
"setuptools",
|
"setuptools",
|
||||||
|
"starlette",
|
||||||
"termcolor",
|
"termcolor",
|
||||||
"tiktoken",
|
"tiktoken",
|
||||||
"pillow",
|
"pillow",
|
||||||
|
@ -43,6 +44,14 @@ dependencies = [
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
ui = [
|
||||||
|
"streamlit",
|
||||||
|
"pandas",
|
||||||
|
"llama-stack-client>=0.2.8",
|
||||||
|
"streamlit-option-menu",
|
||||||
|
]
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
"pytest",
|
"pytest",
|
||||||
"pytest-timeout",
|
"pytest-timeout",
|
||||||
|
@ -73,10 +82,11 @@ unit = [
|
||||||
"opentelemetry-exporter-otlp-proto-http",
|
"opentelemetry-exporter-otlp-proto-http",
|
||||||
"sqlalchemy",
|
"sqlalchemy",
|
||||||
"sqlalchemy[asyncio]>=2.0.41",
|
"sqlalchemy[asyncio]>=2.0.41",
|
||||||
|
"blobfile",
|
||||||
]
|
]
|
||||||
# These are the core dependencies required for running integration tests. They are shared across all
|
# These are the core dependencies required for running integration tests. They are shared across all
|
||||||
# providers. If a provider requires additional dependencies, please add them to your environment
|
# providers. If a provider requires additional dependencies, please add them to your environment
|
||||||
# separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra
|
# separately. If you are using "uv" to execute your tests, you can use the "--group" flag to specify extra
|
||||||
# dependencies.
|
# dependencies.
|
||||||
test = [
|
test = [
|
||||||
"openai",
|
"openai",
|
||||||
|
@ -112,12 +122,6 @@ docs = [
|
||||||
"sphinxcontrib.openapi",
|
"sphinxcontrib.openapi",
|
||||||
]
|
]
|
||||||
codegen = ["rich", "pydantic", "jinja2>=3.1.6"]
|
codegen = ["rich", "pydantic", "jinja2>=3.1.6"]
|
||||||
ui = [
|
|
||||||
"streamlit",
|
|
||||||
"pandas",
|
|
||||||
"llama-stack-client>=0.2.7",
|
|
||||||
"streamlit-option-menu",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
Homepage = "https://github.com/meta-llama/llama-stack"
|
Homepage = "https://github.com/meta-llama/llama-stack"
|
||||||
|
@ -138,7 +142,6 @@ explicit = true
|
||||||
[tool.uv.sources]
|
[tool.uv.sources]
|
||||||
torch = [{ index = "pytorch-cpu" }]
|
torch = [{ index = "pytorch-cpu" }]
|
||||||
torchvision = [{ index = "pytorch-cpu" }]
|
torchvision = [{ index = "pytorch-cpu" }]
|
||||||
llama-stack = { workspace = true }
|
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 120
|
line-length = 120
|
||||||
|
@ -332,10 +335,5 @@ init_forbid_extra = true
|
||||||
init_typed = true
|
init_typed = true
|
||||||
warn_required_dynamic_aliases = true
|
warn_required_dynamic_aliases = true
|
||||||
|
|
||||||
[dependency-groups]
|
|
||||||
dev = [
|
|
||||||
"llama-stack",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.lint.pep8-naming]
|
[tool.ruff.lint.pep8-naming]
|
||||||
classmethod-decorators = ["classmethod", "pydantic.field_validator"]
|
classmethod-decorators = ["classmethod", "pydantic.field_validator"]
|
||||||
|
|
150
requirements.txt
150
requirements.txt
|
@ -1,63 +1,203 @@
|
||||||
# This file was autogenerated by uv via the following command:
|
# This file was autogenerated by uv via the following command:
|
||||||
# uv export --frozen --no-hashes --no-emit-project --output-file=requirements.txt
|
# uv export --frozen --no-hashes --no-emit-project --no-default-groups --output-file=requirements.txt
|
||||||
|
aiohappyeyeballs==2.5.0
|
||||||
|
# via aiohttp
|
||||||
|
aiohttp==3.11.13
|
||||||
|
# via llama-stack
|
||||||
|
aiosignal==1.3.2
|
||||||
|
# via aiohttp
|
||||||
annotated-types==0.7.0
|
annotated-types==0.7.0
|
||||||
|
# via pydantic
|
||||||
anyio==4.8.0
|
anyio==4.8.0
|
||||||
|
# via
|
||||||
|
# httpx
|
||||||
|
# llama-stack-client
|
||||||
|
# openai
|
||||||
|
# starlette
|
||||||
|
async-timeout==5.0.1 ; python_full_version < '3.11'
|
||||||
|
# via aiohttp
|
||||||
attrs==25.1.0
|
attrs==25.1.0
|
||||||
blobfile==3.0.0
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# jsonschema
|
||||||
|
# referencing
|
||||||
certifi==2025.1.31
|
certifi==2025.1.31
|
||||||
|
# via
|
||||||
|
# httpcore
|
||||||
|
# httpx
|
||||||
|
# requests
|
||||||
charset-normalizer==3.4.1
|
charset-normalizer==3.4.1
|
||||||
|
# via requests
|
||||||
click==8.1.8
|
click==8.1.8
|
||||||
|
# via llama-stack-client
|
||||||
colorama==0.4.6 ; sys_platform == 'win32'
|
colorama==0.4.6 ; sys_platform == 'win32'
|
||||||
|
# via
|
||||||
|
# click
|
||||||
|
# tqdm
|
||||||
distro==1.9.0
|
distro==1.9.0
|
||||||
|
# via
|
||||||
|
# llama-stack-client
|
||||||
|
# openai
|
||||||
ecdsa==0.19.1
|
ecdsa==0.19.1
|
||||||
|
# via python-jose
|
||||||
exceptiongroup==1.2.2 ; python_full_version < '3.11'
|
exceptiongroup==1.2.2 ; python_full_version < '3.11'
|
||||||
|
# via anyio
|
||||||
filelock==3.17.0
|
filelock==3.17.0
|
||||||
|
# via huggingface-hub
|
||||||
fire==0.7.0
|
fire==0.7.0
|
||||||
|
# via llama-stack
|
||||||
|
frozenlist==1.5.0
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# aiosignal
|
||||||
fsspec==2024.12.0
|
fsspec==2024.12.0
|
||||||
|
# via huggingface-hub
|
||||||
h11==0.16.0
|
h11==0.16.0
|
||||||
|
# via
|
||||||
|
# httpcore
|
||||||
|
# llama-stack
|
||||||
httpcore==1.0.9
|
httpcore==1.0.9
|
||||||
|
# via httpx
|
||||||
httpx==0.28.1
|
httpx==0.28.1
|
||||||
|
# via
|
||||||
|
# llama-stack
|
||||||
|
# llama-stack-client
|
||||||
|
# openai
|
||||||
huggingface-hub==0.29.0
|
huggingface-hub==0.29.0
|
||||||
|
# via llama-stack
|
||||||
idna==3.10
|
idna==3.10
|
||||||
|
# via
|
||||||
|
# anyio
|
||||||
|
# httpx
|
||||||
|
# requests
|
||||||
|
# yarl
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
|
# via llama-stack
|
||||||
jiter==0.8.2
|
jiter==0.8.2
|
||||||
|
# via openai
|
||||||
jsonschema==4.23.0
|
jsonschema==4.23.0
|
||||||
|
# via llama-stack
|
||||||
jsonschema-specifications==2024.10.1
|
jsonschema-specifications==2024.10.1
|
||||||
llama-stack-client==0.2.7
|
# via jsonschema
|
||||||
lxml==5.3.1
|
llama-stack-client==0.2.8
|
||||||
|
# via llama-stack
|
||||||
markdown-it-py==3.0.0
|
markdown-it-py==3.0.0
|
||||||
|
# via rich
|
||||||
markupsafe==3.0.2
|
markupsafe==3.0.2
|
||||||
|
# via jinja2
|
||||||
mdurl==0.1.2
|
mdurl==0.1.2
|
||||||
|
# via markdown-it-py
|
||||||
|
multidict==6.1.0
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# yarl
|
||||||
numpy==2.2.3
|
numpy==2.2.3
|
||||||
|
# via pandas
|
||||||
openai==1.71.0
|
openai==1.71.0
|
||||||
|
# via llama-stack
|
||||||
packaging==24.2
|
packaging==24.2
|
||||||
|
# via huggingface-hub
|
||||||
pandas==2.2.3
|
pandas==2.2.3
|
||||||
|
# via llama-stack-client
|
||||||
pillow==11.1.0
|
pillow==11.1.0
|
||||||
|
# via llama-stack
|
||||||
prompt-toolkit==3.0.50
|
prompt-toolkit==3.0.50
|
||||||
|
# via
|
||||||
|
# llama-stack
|
||||||
|
# llama-stack-client
|
||||||
|
propcache==0.3.0
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# yarl
|
||||||
pyaml==25.1.0
|
pyaml==25.1.0
|
||||||
|
# via llama-stack-client
|
||||||
pyasn1==0.4.8
|
pyasn1==0.4.8
|
||||||
pycryptodomex==3.21.0
|
# via
|
||||||
|
# python-jose
|
||||||
|
# rsa
|
||||||
pydantic==2.10.6
|
pydantic==2.10.6
|
||||||
|
# via
|
||||||
|
# llama-stack
|
||||||
|
# llama-stack-client
|
||||||
|
# openai
|
||||||
pydantic-core==2.27.2
|
pydantic-core==2.27.2
|
||||||
|
# via pydantic
|
||||||
pygments==2.19.1
|
pygments==2.19.1
|
||||||
|
# via rich
|
||||||
python-dateutil==2.9.0.post0
|
python-dateutil==2.9.0.post0
|
||||||
|
# via pandas
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
|
# via llama-stack
|
||||||
python-jose==3.4.0
|
python-jose==3.4.0
|
||||||
|
# via llama-stack
|
||||||
pytz==2025.1
|
pytz==2025.1
|
||||||
|
# via pandas
|
||||||
pyyaml==6.0.2
|
pyyaml==6.0.2
|
||||||
|
# via
|
||||||
|
# huggingface-hub
|
||||||
|
# pyaml
|
||||||
referencing==0.36.2
|
referencing==0.36.2
|
||||||
|
# via
|
||||||
|
# jsonschema
|
||||||
|
# jsonschema-specifications
|
||||||
regex==2024.11.6
|
regex==2024.11.6
|
||||||
|
# via tiktoken
|
||||||
requests==2.32.3
|
requests==2.32.3
|
||||||
|
# via
|
||||||
|
# huggingface-hub
|
||||||
|
# llama-stack
|
||||||
|
# tiktoken
|
||||||
rich==13.9.4
|
rich==13.9.4
|
||||||
|
# via
|
||||||
|
# llama-stack
|
||||||
|
# llama-stack-client
|
||||||
rpds-py==0.22.3
|
rpds-py==0.22.3
|
||||||
|
# via
|
||||||
|
# jsonschema
|
||||||
|
# referencing
|
||||||
rsa==4.9
|
rsa==4.9
|
||||||
|
# via python-jose
|
||||||
setuptools==80.8.0
|
setuptools==80.8.0
|
||||||
|
# via llama-stack
|
||||||
six==1.17.0
|
six==1.17.0
|
||||||
|
# via
|
||||||
|
# ecdsa
|
||||||
|
# python-dateutil
|
||||||
sniffio==1.3.1
|
sniffio==1.3.1
|
||||||
|
# via
|
||||||
|
# anyio
|
||||||
|
# llama-stack-client
|
||||||
|
# openai
|
||||||
|
starlette==0.45.3
|
||||||
|
# via llama-stack
|
||||||
termcolor==2.5.0
|
termcolor==2.5.0
|
||||||
|
# via
|
||||||
|
# fire
|
||||||
|
# llama-stack
|
||||||
|
# llama-stack-client
|
||||||
tiktoken==0.9.0
|
tiktoken==0.9.0
|
||||||
|
# via llama-stack
|
||||||
tqdm==4.67.1
|
tqdm==4.67.1
|
||||||
|
# via
|
||||||
|
# huggingface-hub
|
||||||
|
# llama-stack-client
|
||||||
|
# openai
|
||||||
typing-extensions==4.12.2
|
typing-extensions==4.12.2
|
||||||
|
# via
|
||||||
|
# anyio
|
||||||
|
# huggingface-hub
|
||||||
|
# llama-stack-client
|
||||||
|
# multidict
|
||||||
|
# openai
|
||||||
|
# pydantic
|
||||||
|
# pydantic-core
|
||||||
|
# referencing
|
||||||
|
# rich
|
||||||
tzdata==2025.1
|
tzdata==2025.1
|
||||||
|
# via pandas
|
||||||
urllib3==2.3.0
|
urllib3==2.3.0
|
||||||
|
# via requests
|
||||||
wcwidth==0.2.13
|
wcwidth==0.2.13
|
||||||
|
# via prompt-toolkit
|
||||||
|
yarl==1.18.3
|
||||||
|
# via aiohttp
|
||||||
|
|
|
@ -7,7 +7,6 @@
|
||||||
|
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
@ -108,21 +107,6 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[
|
||||||
return None, []
|
return None, []
|
||||||
|
|
||||||
|
|
||||||
def generate_dependencies_file(change_tracker: ChangedPathTracker):
|
|
||||||
templates_dir = REPO_ROOT / "llama_stack" / "templates"
|
|
||||||
distribution_deps = {}
|
|
||||||
|
|
||||||
for template_dir in find_template_dirs(templates_dir):
|
|
||||||
name, deps = collect_template_dependencies(template_dir)
|
|
||||||
if name:
|
|
||||||
distribution_deps[name] = deps
|
|
||||||
|
|
||||||
deps_file = REPO_ROOT / "llama_stack" / "templates" / "dependencies.json"
|
|
||||||
change_tracker.add_paths(deps_file)
|
|
||||||
with open(deps_file, "w") as f:
|
|
||||||
f.write(json.dumps(distribution_deps, indent=2) + "\n")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
templates_dir = REPO_ROOT / "llama_stack" / "templates"
|
templates_dir = REPO_ROOT / "llama_stack" / "templates"
|
||||||
change_tracker = ChangedPathTracker()
|
change_tracker = ChangedPathTracker()
|
||||||
|
@ -143,8 +127,6 @@ def main():
|
||||||
list(executor.map(process_func, template_dirs))
|
list(executor.map(process_func, template_dirs))
|
||||||
progress.update(task, advance=len(template_dirs))
|
progress.update(task, advance=len(template_dirs))
|
||||||
|
|
||||||
generate_dependencies_file(change_tracker)
|
|
||||||
|
|
||||||
if check_for_changes(change_tracker):
|
if check_for_changes(change_tracker):
|
||||||
print(
|
print(
|
||||||
"Distribution template changes detected. Please commit the changes.",
|
"Distribution template changes detected. Please commit the changes.",
|
||||||
|
|
|
@ -10,10 +10,10 @@ PYTHON_VERSION=${PYTHON_VERSION:-3.10}
|
||||||
|
|
||||||
command -v uv >/dev/null 2>&1 || { echo >&2 "uv is required but it's not installed. Exiting."; exit 1; }
|
command -v uv >/dev/null 2>&1 || { echo >&2 "uv is required but it's not installed. Exiting."; exit 1; }
|
||||||
|
|
||||||
uv python find $PYTHON_VERSION
|
uv python find "$PYTHON_VERSION"
|
||||||
FOUND_PYTHON=$?
|
FOUND_PYTHON=$?
|
||||||
if [ $FOUND_PYTHON -ne 0 ]; then
|
if [ $FOUND_PYTHON -ne 0 ]; then
|
||||||
uv python install $PYTHON_VERSION
|
uv python install "$PYTHON_VERSION"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
uv run --python $PYTHON_VERSION --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest --asyncio-mode=auto -s -v tests/unit/ $@
|
uv run --python "$PYTHON_VERSION" --with-editable . --group unit pytest --asyncio-mode=auto -s -v tests/unit/ $@
|
||||||
|
|
|
@ -6,7 +6,6 @@ dependencies = [
|
||||||
"aiohttp",
|
"aiohttp",
|
||||||
"aiosqlite",
|
"aiosqlite",
|
||||||
"autoevals",
|
"autoevals",
|
||||||
"blobfile",
|
|
||||||
"chardet",
|
"chardet",
|
||||||
"chromadb-client",
|
"chromadb-client",
|
||||||
"datasets",
|
"datasets",
|
||||||
|
|
|
@ -41,7 +41,6 @@ def openai_client(client_with_models):
|
||||||
],
|
],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.skip(reason="Very flaky, sometimes there is a message not a function call, standard tool calling issues")
|
|
||||||
def test_responses_store(openai_client, client_with_models, text_model_id, stream, tools):
|
def test_responses_store(openai_client, client_with_models, text_model_id, stream, tools):
|
||||||
if isinstance(client_with_models, LlamaStackAsLibraryClient):
|
if isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||||
pytest.skip("OpenAI responses are not supported when testing with library client yet.")
|
pytest.skip("OpenAI responses are not supported when testing with library client yet.")
|
||||||
|
@ -68,13 +67,15 @@ def test_responses_store(openai_client, client_with_models, text_model_id, strea
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
if response_id is None:
|
if response_id is None:
|
||||||
response_id = chunk.response.id
|
response_id = chunk.response.id
|
||||||
if not tools:
|
|
||||||
if chunk.type == "response.completed":
|
if chunk.type == "response.completed":
|
||||||
response_id = chunk.response.id
|
response_id = chunk.response.id
|
||||||
|
output_type = chunk.response.output[0].type
|
||||||
|
if output_type == "message":
|
||||||
content = chunk.response.output[0].content[0].text
|
content = chunk.response.output[0].content[0].text
|
||||||
else:
|
else:
|
||||||
response_id = response.id
|
response_id = response.id
|
||||||
if not tools:
|
output_type = response.output[0].type
|
||||||
|
if output_type == "message":
|
||||||
content = response.output[0].content[0].text
|
content = response.output[0].content[0].text
|
||||||
|
|
||||||
# list responses - use the underlying HTTP client for endpoints not in SDK
|
# list responses - use the underlying HTTP client for endpoints not in SDK
|
||||||
|
@ -87,9 +88,8 @@ def test_responses_store(openai_client, client_with_models, text_model_id, strea
|
||||||
retrieved_response = client.responses.retrieve(response_id)
|
retrieved_response = client.responses.retrieve(response_id)
|
||||||
assert retrieved_response.id == response_id
|
assert retrieved_response.id == response_id
|
||||||
assert retrieved_response.model == text_model_id
|
assert retrieved_response.model == text_model_id
|
||||||
if tools:
|
assert retrieved_response.output[0].type == output_type, retrieved_response
|
||||||
assert retrieved_response.output[0].type == "function_call"
|
if output_type == "message":
|
||||||
else:
|
|
||||||
assert retrieved_response.output[0].content[0].text == content
|
assert retrieved_response.output[0].content[0].text == content
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -224,6 +224,43 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
|
||||||
assert expected.lower() in "".join(streamed_content)
|
assert expected.lower() in "".join(streamed_content)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_case",
|
||||||
|
[
|
||||||
|
"inference:chat_completion:streaming_01",
|
||||||
|
"inference:chat_completion:streaming_02",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
|
||||||
|
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
||||||
|
|
||||||
|
provider = provider_from_model(client_with_models, text_model_id)
|
||||||
|
if provider.provider_type == "remote::ollama":
|
||||||
|
pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.")
|
||||||
|
|
||||||
|
tc = TestCase(test_case)
|
||||||
|
question = tc["question"]
|
||||||
|
expected = tc["expected"]
|
||||||
|
|
||||||
|
response = compat_client.chat.completions.create(
|
||||||
|
model=text_model_id,
|
||||||
|
messages=[{"role": "user", "content": question}],
|
||||||
|
stream=True,
|
||||||
|
timeout=120, # Increase timeout to 2 minutes for large conversation history,
|
||||||
|
n=2,
|
||||||
|
)
|
||||||
|
streamed_content = {}
|
||||||
|
for chunk in response:
|
||||||
|
for choice in chunk.choices:
|
||||||
|
if choice.delta.content:
|
||||||
|
streamed_content[choice.index] = (
|
||||||
|
streamed_content.get(choice.index, "") + choice.delta.content.lower().strip()
|
||||||
|
)
|
||||||
|
assert len(streamed_content) == 2
|
||||||
|
for i, content in streamed_content.items():
|
||||||
|
assert expected.lower() in content, f"Choice {i}: Expected {expected.lower()} in {content}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"stream",
|
"stream",
|
||||||
[
|
[
|
||||||
|
@ -253,6 +290,7 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
if response_id is None:
|
if response_id is None:
|
||||||
response_id = chunk.id
|
response_id = chunk.id
|
||||||
|
if chunk.choices[0].delta.content:
|
||||||
content += chunk.choices[0].delta.content
|
content += chunk.choices[0].delta.content
|
||||||
else:
|
else:
|
||||||
response_id = response.id
|
response_id = response.id
|
||||||
|
@ -263,8 +301,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
|
||||||
|
|
||||||
retrieved_response = client.chat.completions.retrieve(response_id)
|
retrieved_response = client.chat.completions.retrieve(response_id)
|
||||||
assert retrieved_response.id == response_id
|
assert retrieved_response.id == response_id
|
||||||
assert retrieved_response.input_messages[0]["content"] == message
|
assert retrieved_response.input_messages[0]["content"] == message, retrieved_response
|
||||||
assert retrieved_response.choices[0].message.content == content
|
assert retrieved_response.choices[0].message.content == content, retrieved_response
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -274,7 +312,6 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
|
||||||
False,
|
False,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.skip(reason="Very flaky, tool calling really wacky on CI")
|
|
||||||
def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
|
def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
|
||||||
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
||||||
client = openai_client
|
client = openai_client
|
||||||
|
@ -312,7 +349,9 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
if response_id is None:
|
if response_id is None:
|
||||||
response_id = chunk.id
|
response_id = chunk.id
|
||||||
content += chunk.choices[0].delta.content
|
if delta := chunk.choices[0].delta:
|
||||||
|
if delta.content:
|
||||||
|
content += delta.content
|
||||||
else:
|
else:
|
||||||
response_id = response.id
|
response_id = response.id
|
||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
|
@ -323,5 +362,11 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
|
||||||
retrieved_response = client.chat.completions.retrieve(response_id)
|
retrieved_response = client.chat.completions.retrieve(response_id)
|
||||||
assert retrieved_response.id == response_id
|
assert retrieved_response.id == response_id
|
||||||
assert retrieved_response.input_messages[0]["content"] == message
|
assert retrieved_response.input_messages[0]["content"] == message
|
||||||
assert retrieved_response.choices[0].message.tool_calls[0].function.name == "get_weather"
|
tool_calls = retrieved_response.choices[0].message.tool_calls
|
||||||
assert retrieved_response.choices[0].message.tool_calls[0].function.arguments == '{"city":"Tokyo"}'
|
# sometimes model doesn't ouptut tool calls, but we still want to test that the tool was called
|
||||||
|
if tool_calls:
|
||||||
|
assert len(tool_calls) == 1
|
||||||
|
assert tool_calls[0].function.name == "get_weather"
|
||||||
|
assert "tokyo" in tool_calls[0].function.arguments.lower()
|
||||||
|
else:
|
||||||
|
assert retrieved_response.choices[0].message.content == content
|
||||||
|
|
|
@ -25,10 +25,12 @@ def test_web_search_tool(llama_stack_client, sample_search_query):
|
||||||
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
||||||
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
||||||
|
|
||||||
|
tools = llama_stack_client.tool_runtime.list_tools()
|
||||||
|
assert any(tool.identifier == "web_search" for tool in tools)
|
||||||
|
|
||||||
response = llama_stack_client.tool_runtime.invoke_tool(
|
response = llama_stack_client.tool_runtime.invoke_tool(
|
||||||
tool_name="web_search", kwargs={"query": sample_search_query}
|
tool_name="web_search", kwargs={"query": sample_search_query}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify the response
|
# Verify the response
|
||||||
assert response.content is not None
|
assert response.content is not None
|
||||||
assert len(response.content) > 0
|
assert len(response.content) > 0
|
||||||
|
@ -49,11 +51,12 @@ def test_wolfram_alpha_tool(llama_stack_client, sample_wolfram_alpha_query):
|
||||||
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
|
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
|
||||||
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
|
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
|
||||||
|
|
||||||
|
tools = llama_stack_client.tool_runtime.list_tools()
|
||||||
|
assert any(tool.identifier == "wolfram_alpha" for tool in tools)
|
||||||
response = llama_stack_client.tool_runtime.invoke_tool(
|
response = llama_stack_client.tool_runtime.invoke_tool(
|
||||||
tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
|
tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
|
||||||
)
|
)
|
||||||
|
|
||||||
print(response.content)
|
|
||||||
assert response.content is not None
|
assert response.content is not None
|
||||||
assert len(response.content) > 0
|
assert len(response.content) > 0
|
||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
|
|
|
@ -31,8 +31,7 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
|
||||||
test_toolgroup_id = MCP_TOOLGROUP_ID
|
test_toolgroup_id = MCP_TOOLGROUP_ID
|
||||||
uri = mcp_server["server_url"]
|
uri = mcp_server["server_url"]
|
||||||
|
|
||||||
# registering itself should fail since it requires listing tools
|
# registering should not raise an error anymore even if you don't specify the auth token
|
||||||
with pytest.raises(Exception, match="Unauthorized"):
|
|
||||||
llama_stack_client.toolgroups.register(
|
llama_stack_client.toolgroups.register(
|
||||||
toolgroup_id=test_toolgroup_id,
|
toolgroup_id=test_toolgroup_id,
|
||||||
provider_id="model-context-protocol",
|
provider_id="model-context-protocol",
|
||||||
|
@ -41,27 +40,18 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
|
||||||
|
|
||||||
provider_data = {
|
provider_data = {
|
||||||
"mcp_headers": {
|
"mcp_headers": {
|
||||||
uri: [
|
uri: {
|
||||||
f"Authorization: Bearer {AUTH_TOKEN}",
|
"Authorization": f"Bearer {AUTH_TOKEN}",
|
||||||
],
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
auth_headers = {
|
auth_headers = {
|
||||||
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
|
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
with pytest.raises(Exception, match="Unauthorized"):
|
||||||
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers)
|
llama_stack_client.tools.list()
|
||||||
except Exception as e:
|
|
||||||
# An error is OK since the toolgroup may not exist
|
|
||||||
print(f"Error unregistering toolgroup: {e}")
|
|
||||||
|
|
||||||
llama_stack_client.toolgroups.register(
|
|
||||||
toolgroup_id=test_toolgroup_id,
|
|
||||||
provider_id="model-context-protocol",
|
|
||||||
mcp_endpoint=dict(uri=uri),
|
|
||||||
extra_headers=auth_headers,
|
|
||||||
)
|
|
||||||
response = llama_stack_client.tools.list(
|
response = llama_stack_client.tools.list(
|
||||||
toolgroup_id=test_toolgroup_id,
|
toolgroup_id=test_toolgroup_id,
|
||||||
extra_headers=auth_headers,
|
extra_headers=auth_headers,
|
||||||
|
|
|
@ -51,7 +51,5 @@ def test_register_and_unregister_toolgroup(llama_stack_client):
|
||||||
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
|
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
|
||||||
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
|
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
|
||||||
|
|
||||||
# Verify tools are also unregistered
|
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
|
||||||
unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
|
llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
|
||||||
assert isinstance(unregister_tools_list_response, list)
|
|
||||||
assert not unregister_tools_list_response
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataS
|
||||||
from llama_stack.apis.datatypes import Api
|
from llama_stack.apis.datatypes import Api
|
||||||
from llama_stack.apis.models.models import Model, ModelType
|
from llama_stack.apis.models.models import Model, ModelType
|
||||||
from llama_stack.apis.shields.shields import Shield
|
from llama_stack.apis.shields.shields import Shield
|
||||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter
|
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
|
||||||
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
|
||||||
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
|
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
|
||||||
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
|
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
|
||||||
|
@ -101,11 +101,11 @@ class ToolGroupsImpl(Impl):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(Api.tool_runtime)
|
super().__init__(Api.tool_runtime)
|
||||||
|
|
||||||
async def register_tool(self, tool):
|
async def register_toolgroup(self, toolgroup: ToolGroup):
|
||||||
return tool
|
return toolgroup
|
||||||
|
|
||||||
async def unregister_tool(self, tool_name: str):
|
async def unregister_toolgroup(self, toolgroup_id: str):
|
||||||
return tool_name
|
return toolgroup_id
|
||||||
|
|
||||||
async def list_runtime_tools(self, toolgroup_id, mcp_endpoint):
|
async def list_runtime_tools(self, toolgroup_id, mcp_endpoint):
|
||||||
return ListToolDefsResponse(
|
return ListToolDefsResponse(
|
||||||
|
|
|
@ -232,9 +232,17 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
|
||||||
|
|
||||||
# Check that we got the content from our mocked tool execution result
|
# Check that we got the content from our mocked tool execution result
|
||||||
chunks = [chunk async for chunk in result]
|
chunks = [chunk async for chunk in result]
|
||||||
assert len(chunks) > 0
|
assert len(chunks) == 2 # Should have response.created and response.completed
|
||||||
assert chunks[0].response.output[0].type == "function_call"
|
|
||||||
assert chunks[0].response.output[0].name == "get_weather"
|
# Check response.created event (should have empty output)
|
||||||
|
assert chunks[0].type == "response.created"
|
||||||
|
assert len(chunks[0].response.output) == 0
|
||||||
|
|
||||||
|
# Check response.completed event (should have the tool call)
|
||||||
|
assert chunks[1].type == "response.completed"
|
||||||
|
assert len(chunks[1].response.output) == 1
|
||||||
|
assert chunks[1].response.output[0].type == "function_call"
|
||||||
|
assert chunks[1].response.output[0].name == "get_weather"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -620,3 +628,69 @@ async def test_responses_store_list_input_items_logic():
|
||||||
result = await responses_store.list_response_input_items("resp_123", limit=0, order=Order.asc)
|
result = await responses_store.list_response_input_items("resp_123", limit=0, order=Order.asc)
|
||||||
assert result.object == "list"
|
assert result.object == "list"
|
||||||
assert len(result.data) == 0 # Should return no items
|
assert len(result.data) == 0 # Should return no items
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_store_response_uses_rehydrated_input_with_previous_response(
|
||||||
|
openai_responses_impl, mock_responses_store, mock_inference_api
|
||||||
|
):
|
||||||
|
"""Test that _store_response uses the full re-hydrated input (including previous responses)
|
||||||
|
rather than just the original input when previous_response_id is provided."""
|
||||||
|
|
||||||
|
# Setup - Create a previous response that should be included in the stored input
|
||||||
|
previous_response = OpenAIResponseObjectWithInput(
|
||||||
|
id="resp-previous-123",
|
||||||
|
object="response",
|
||||||
|
created_at=1234567890,
|
||||||
|
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
status="completed",
|
||||||
|
input=[
|
||||||
|
OpenAIResponseMessage(
|
||||||
|
id="msg-prev-user", role="user", content=[OpenAIResponseInputMessageContentText(text="What is 2+2?")]
|
||||||
|
)
|
||||||
|
],
|
||||||
|
output=[
|
||||||
|
OpenAIResponseMessage(
|
||||||
|
id="msg-prev-assistant",
|
||||||
|
role="assistant",
|
||||||
|
content=[OpenAIResponseOutputMessageContentOutputText(text="2+2 equals 4.")],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_responses_store.get_response_object.return_value = previous_response
|
||||||
|
|
||||||
|
current_input = "Now what is 3+3?"
|
||||||
|
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
|
||||||
|
mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
|
||||||
|
|
||||||
|
# Execute - Create response with previous_response_id
|
||||||
|
result = await openai_responses_impl.create_openai_response(
|
||||||
|
input=current_input,
|
||||||
|
model=model,
|
||||||
|
previous_response_id="resp-previous-123",
|
||||||
|
store=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
store_call_args = mock_responses_store.store_response_object.call_args
|
||||||
|
stored_input = store_call_args.kwargs["input"]
|
||||||
|
|
||||||
|
# Verify that the stored input contains the full re-hydrated conversation:
|
||||||
|
# 1. Previous user message
|
||||||
|
# 2. Previous assistant response
|
||||||
|
# 3. Current user message
|
||||||
|
assert len(stored_input) == 3
|
||||||
|
|
||||||
|
assert stored_input[0].role == "user"
|
||||||
|
assert stored_input[0].content[0].text == "What is 2+2?"
|
||||||
|
|
||||||
|
assert stored_input[1].role == "assistant"
|
||||||
|
assert stored_input[1].content[0].text == "2+2 equals 4."
|
||||||
|
|
||||||
|
assert stored_input[2].role == "user"
|
||||||
|
assert stored_input[2].content == "Now what is 3+3?"
|
||||||
|
|
||||||
|
# Verify the response itself is correct
|
||||||
|
assert result.model == model
|
||||||
|
assert result.status == "completed"
|
||||||
|
|
|
@ -27,7 +27,7 @@ export TOGETHER_API_KEY=<your_together_api_key>
|
||||||
```
|
```
|
||||||
then run
|
then run
|
||||||
```bash
|
```bash
|
||||||
uv run --with-editable ".[dev]" python tests/verifications/generate_report.py --run-tests
|
uv run python tests/verifications/generate_report.py --run-tests
|
||||||
```
|
```
|
||||||
|
|
||||||
## Running Tests
|
## Running Tests
|
||||||
|
|
|
@ -10,17 +10,17 @@ from tests.verifications.openai_api.fixtures.fixtures import _load_all_verificat
|
||||||
def pytest_generate_tests(metafunc):
|
def pytest_generate_tests(metafunc):
|
||||||
"""Dynamically parametrize tests based on the selected provider and config."""
|
"""Dynamically parametrize tests based on the selected provider and config."""
|
||||||
if "model" in metafunc.fixturenames:
|
if "model" in metafunc.fixturenames:
|
||||||
|
model = metafunc.config.getoption("model")
|
||||||
|
if model:
|
||||||
|
metafunc.parametrize("model", [model])
|
||||||
|
return
|
||||||
|
|
||||||
provider = metafunc.config.getoption("provider")
|
provider = metafunc.config.getoption("provider")
|
||||||
if not provider:
|
if not provider:
|
||||||
print("Warning: --provider not specified. Skipping model parametrization.")
|
print("Warning: --provider not specified. Skipping model parametrization.")
|
||||||
metafunc.parametrize("model", [])
|
metafunc.parametrize("model", [])
|
||||||
return
|
return
|
||||||
|
|
||||||
model = metafunc.config.getoption("model")
|
|
||||||
if model:
|
|
||||||
metafunc.parametrize("model", [model])
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config_data = _load_all_verification_configs()
|
config_data = _load_all_verification_configs()
|
||||||
except (OSError, FileNotFoundError) as e:
|
except (OSError, FileNotFoundError) as e:
|
||||||
|
|
|
@ -77,11 +77,12 @@ test_response_image:
|
||||||
image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
|
image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
|
||||||
output: "llama"
|
output: "llama"
|
||||||
|
|
||||||
|
# the models are really poor at tool calling after seeing images :/
|
||||||
test_response_multi_turn_image:
|
test_response_multi_turn_image:
|
||||||
test_name: test_response_multi_turn_image
|
test_name: test_response_multi_turn_image
|
||||||
test_params:
|
test_params:
|
||||||
case:
|
case:
|
||||||
- case_id: "llama_image_search"
|
- case_id: "llama_image_understanding"
|
||||||
turns:
|
turns:
|
||||||
- input:
|
- input:
|
||||||
- role: user
|
- role: user
|
||||||
|
@ -91,7 +92,5 @@ test_response_multi_turn_image:
|
||||||
- type: input_image
|
- type: input_image
|
||||||
image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
|
image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
|
||||||
output: "llama"
|
output: "llama"
|
||||||
- input: "Search the web using the search tool for the animal from the previous response. Your search query should be a single phrase that includes the animal's name and the words 'maverick', 'scout' and 'llm'"
|
- input: "What country do you find this animal primarily in? What continent?"
|
||||||
tools:
|
output: "peru"
|
||||||
- type: web_search
|
|
||||||
output: "model"
|
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue