feat: small ollama package

Raghotham Murthy 2025-05-28 21:13:48 -07:00
commit 2d5d05a2b4
103 changed files with 7262 additions and 7422 deletions


@@ -1,10 +1,8 @@
 # What does this PR do?
-[Provide a short summary of what this PR does and why. Link to relevant issues if applicable.]
+<!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. -->
-[//]: # (If resolving an issue, uncomment and update the line below)
+<!-- If resolving an issue, uncomment and update the line below -->
-[//]: # (Closes #[issue-number])
+<!-- Closes #[issue-number] -->
 ## Test Plan
-[Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*]
+<!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* -->
-[//]: # (## Documentation)


@@ -13,7 +13,7 @@ runs:
 - name: Install dependencies
 shell: bash
 run: |
-uv sync --all-extras
+uv sync --all-groups
 uv pip install ollama faiss-cpu
 # always test against the latest version of the client
 # TODO: this is not necessarily a good idea. we need to test against both published and latest


@@ -53,7 +53,7 @@ repos:
 - black==24.3.0
 - repo: https://github.com/astral-sh/uv-pre-commit
-rev: 0.6.3
+rev: 0.7.8
 hooks:
 - id: uv-lock
 - id: uv-export
@@ -61,6 +61,7 @@ repos:
 "--frozen",
 "--no-hashes",
 "--no-emit-project",
+"--no-default-groups",
 "--output-file=requirements.txt"
 ]
@@ -88,8 +89,8 @@ repos:
 - id: distro-codegen
 name: Distribution Template Codegen
 additional_dependencies:
-- uv==0.6.0
+- uv==0.7.8
-entry: uv run --extra codegen ./scripts/distro_codegen.py
+entry: uv run --group codegen ./scripts/distro_codegen.py
 language: python
 pass_filenames: false
 require_serial: true
@@ -97,8 +98,8 @@ repos:
 - id: openapi-codegen
 name: API Spec Codegen
 additional_dependencies:
-- uv==0.6.2
+- uv==0.7.8
-entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
+entry: sh -c 'uv run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
 language: python
 pass_filenames: false
 require_serial: true
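After bumping the hook versions, the updated configuration can be exercised locally; a minimal sketch, assuming `pre-commit` is available through the project's uv environment:

```bash
# Run every configured hook against the whole tree (not just staged files)
uv run pre-commit run --all-files
```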


@@ -5,28 +5,21 @@
 # Required
 version: 2
+# Build documentation in the "docs/" directory with Sphinx
+sphinx:
+configuration: docs/source/conf.py
 # Set the OS, Python version and other tools you might need
 build:
 os: ubuntu-22.04
 tools:
 python: "3.12"
-# You can also specify other tool versions:
-# nodejs: "19"
-# rust: "1.64"
-# golang: "1.19"
-# Build documentation in the "docs/" directory with Sphinx
-sphinx:
-configuration: docs/source/conf.py
-# Optionally build your docs in additional formats such as PDF and ePub
-# formats:
-# - pdf
-# - epub
-# Optional but recommended, declare the Python requirements required
-# to build your documentation
-# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
-python:
+jobs:
+pre_create_environment:
+- asdf plugin add uv
+- asdf install uv latest
+- asdf global uv latest
+create_environment:
+- uv venv "${READTHEDOCS_VIRTUALENV_PATH}"
 install:
-- requirements: docs/requirements.txt
+- UV_PROJECT_ENVIRONMENT="${READTHEDOCS_VIRTUALENV_PATH}" uv sync --frozen --group docs
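For reference, the same docs environment can be reproduced outside Read the Docs; a minimal sketch, assuming uv 0.7.x and the `docs` dependency group defined in pyproject.toml:

```bash
# Create or refresh a local environment with just the docs dependency group
uv sync --frozen --group docs
# Build the Sphinx site the same way CONTRIBUTING.md describes below
uv run --group docs make -C docs/ html
```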


@@ -168,10 +168,10 @@ If you are making changes to the documentation at [https://llama-stack.readthedo
 ```bash
 # This rebuilds the documentation pages.
-uv run --with ".[docs]" make -C docs/ html
+uv run --group docs make -C docs/ html
 # This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
-uv run --with ".[docs]" sphinx-autobuild docs/source docs/build/html --write-all
+uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all
 ```
 ### Update API Documentation
@@ -179,7 +179,7 @@ uv run --with ".[docs]" sphinx-autobuild docs/source docs/build/html --write-all
 If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following command:
 ```bash
-uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh
+uv run ./docs/openapi_generator/run_openapi_generator.sh
 ```
 The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.


@@ -1,5 +1,4 @@
 include pyproject.toml
-include llama_stack/templates/dependencies.json
 include llama_stack/models/llama/llama3/tokenizer.model
 include llama_stack/models/llama/llama4/tokenizer.model
 include llama_stack/distribution/*.sh


@@ -107,26 +107,29 @@ By reducing friction and complexity, Llama Stack empowers developers to focus on
 ### API Providers
 Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack.
-| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
-|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|
-| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ |
-| SambaNova | Hosted | | ✅ | | ✅ | |
-| Cerebras | Hosted | | ✅ | | | |
-| Fireworks | Hosted | ✅ | ✅ | ✅ | | |
-| AWS Bedrock | Hosted | | ✅ | | ✅ | |
-| Together | Hosted | ✅ | ✅ | | ✅ | |
-| Groq | Hosted | | ✅ | | | |
-| Ollama | Single Node | | ✅ | | | |
-| TGI | Hosted and Single Node | | ✅ | | | |
-| NVIDIA NIM | Hosted and Single Node | | ✅ | | | |
-| Chroma | Single Node | | | ✅ | | |
-| PG Vector | Single Node | | | ✅ | | |
-| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
-| vLLM | Hosted and Single Node | | ✅ | | | |
-| OpenAI | Hosted | | ✅ | | | |
-| Anthropic | Hosted | | ✅ | | | |
-| Gemini | Hosted | | ✅ | | | |
-| watsonx | Hosted | | ✅ | | | |
+| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | **Post Training** |
+|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-----------------:|
+| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ | |
+| SambaNova | Hosted | | ✅ | | ✅ | | |
+| Cerebras | Hosted | | ✅ | | | | |
+| Fireworks | Hosted | ✅ | ✅ | ✅ | | | |
+| AWS Bedrock | Hosted | | ✅ | | ✅ | | |
+| Together | Hosted | ✅ | ✅ | | ✅ | | |
+| Groq | Hosted | | ✅ | | | | |
+| Ollama | Single Node | | ✅ | | | | |
+| TGI | Hosted and Single Node | | ✅ | | | | |
+| NVIDIA NIM | Hosted and Single Node | | ✅ | | | | |
+| Chroma | Single Node | | | ✅ | | | |
+| PG Vector | Single Node | | | ✅ | | | |
+| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | | |
+| vLLM | Hosted and Single Node | | ✅ | | | | |
+| OpenAI | Hosted | | ✅ | | | | |
+| Anthropic | Hosted | | ✅ | | | | |
+| Gemini | Hosted | | ✅ | | | | |
+| watsonx | Hosted | | ✅ | | | | |
+| HuggingFace | Single Node | | | | | | ✅ |
+| TorchTune | Single Node | | | | | | ✅ |
+| NVIDIA NEMO | Hosted | | | | | | ✅ |
 ### Distributions


@ -7540,6 +7540,9 @@
{ {
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated" "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated"
}, },
{
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta"
},
{ {
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted" "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
} }
@ -7548,6 +7551,7 @@
"propertyName": "type", "propertyName": "type",
"mapping": { "mapping": {
"response.created": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated", "response.created": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated",
"response.output_text.delta": "#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta",
"response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted" "response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
} }
} }
@ -7590,6 +7594,41 @@
], ],
"title": "OpenAIResponseObjectStreamResponseCreated" "title": "OpenAIResponseObjectStreamResponseCreated"
}, },
"OpenAIResponseObjectStreamResponseOutputTextDelta": {
"type": "object",
"properties": {
"content_index": {
"type": "integer"
},
"delta": {
"type": "string"
},
"item_id": {
"type": "string"
},
"output_index": {
"type": "integer"
},
"sequence_number": {
"type": "integer"
},
"type": {
"type": "string",
"const": "response.output_text.delta",
"default": "response.output_text.delta"
}
},
"additionalProperties": false,
"required": [
"content_index",
"delta",
"item_id",
"output_index",
"sequence_number",
"type"
],
"title": "OpenAIResponseObjectStreamResponseOutputTextDelta"
},
"CreateUploadSessionRequest": { "CreateUploadSessionRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -9555,9 +9594,6 @@
"toolgroup_id": { "toolgroup_id": {
"type": "string" "type": "string"
}, },
"tool_host": {
"$ref": "#/components/schemas/ToolHost"
},
"description": { "description": {
"type": "string" "type": "string"
}, },
@ -9599,21 +9635,11 @@
"provider_id", "provider_id",
"type", "type",
"toolgroup_id", "toolgroup_id",
"tool_host",
"description", "description",
"parameters" "parameters"
], ],
"title": "Tool" "title": "Tool"
}, },
"ToolHost": {
"type": "string",
"enum": [
"distribution",
"client",
"model_context_protocol"
],
"title": "ToolHost"
},
"ToolGroup": { "ToolGroup": {
"type": "object", "type": "object",
"properties": { "properties": {


@ -5294,11 +5294,13 @@ components:
OpenAIResponseObjectStream: OpenAIResponseObjectStream:
oneOf: oneOf:
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated' - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta'
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted' - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
discriminator: discriminator:
propertyName: type propertyName: type
mapping: mapping:
response.created: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated' response.created: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
response.output_text.delta: '#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta'
response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted' response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
"OpenAIResponseObjectStreamResponseCompleted": "OpenAIResponseObjectStreamResponseCompleted":
type: object type: object
@ -5330,6 +5332,33 @@ components:
- type - type
title: >- title: >-
OpenAIResponseObjectStreamResponseCreated OpenAIResponseObjectStreamResponseCreated
"OpenAIResponseObjectStreamResponseOutputTextDelta":
type: object
properties:
content_index:
type: integer
delta:
type: string
item_id:
type: string
output_index:
type: integer
sequence_number:
type: integer
type:
type: string
const: response.output_text.delta
default: response.output_text.delta
additionalProperties: false
required:
- content_index
- delta
- item_id
- output_index
- sequence_number
- type
title: >-
OpenAIResponseObjectStreamResponseOutputTextDelta
CreateUploadSessionRequest: CreateUploadSessionRequest:
type: object type: object
properties: properties:
@ -6713,8 +6742,6 @@ components:
default: tool default: tool
toolgroup_id: toolgroup_id:
type: string type: string
tool_host:
$ref: '#/components/schemas/ToolHost'
description: description:
type: string type: string
parameters: parameters:
@ -6737,17 +6764,9 @@ components:
- provider_id - provider_id
- type - type
- toolgroup_id - toolgroup_id
- tool_host
- description - description
- parameters - parameters
title: Tool title: Tool
ToolHost:
type: string
enum:
- distribution
- client
- model_context_protocol
title: ToolHost
ToolGroup: ToolGroup:
type: object type: object
properties: properties:

File diff suppressed because it is too large.


@@ -6,7 +6,7 @@ Here's a collection of comprehensive guides, examples, and resources for buildin
 From the llama-stack root directory, run the following command to render the docs locally:
 ```bash
-uv run --with ".[docs]" sphinx-autobuild docs/source docs/build/html --write-all
+uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all
 ```
 You can open up the docs in your browser at http://localhost:8000


@@ -30,6 +30,18 @@ Runs inference with an LLM.
 ## Post Training
 Fine-tunes a model.
+#### Post Training Providers
+The following providers are available for Post Training:
+```{toctree}
+:maxdepth: 1
+external
+post_training/huggingface
+post_training/torchtune
+post_training/nvidia_nemo
+```
 ## Safety
 Applies safety policies to the output at a Systems (not only model) level.


@ -0,0 +1,122 @@
---
orphan: true
---
# HuggingFace SFTTrainer
[HuggingFace SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) is an inline post training provider for Llama Stack. It allows you to run supervised fine-tuning on a variety of models using many datasets.
## Features
- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU support, CPU support, and MPS support (MacOS Metal Performance Shaders)
## Usage
To use the HF SFTTrainer in your Llama Stack project, follow these steps:
1. Configure your Llama Stack project to use this provider.
2. Kick off a SFT job using the Llama Stack post_training API.
## Setup
You can access the HuggingFace trainer via the `ollama` distribution:
```bash
llama stack build --template ollama --image-type venv
llama stack run --image-type venv ~/.llama/distributions/ollama/ollama-run.yaml
```
## Run Training
You can access the provider and the `supervised_fine_tune` method via the post_training API:
```python
import time
import uuid
from llama_stack_client.types import (
post_training_supervised_fine_tune_params,
algorithm_config_param,
)
def create_http_client():
from llama_stack_client import LlamaStackClient
return LlamaStackClient(base_url="http://localhost:8321")
client = create_http_client()
# Example Dataset
client.datasets.register(
purpose="post-training/messages",
source={
"type": "uri",
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
dataset_id="simpleqa",
)
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
batch_size=32,
data_format="instruct",
dataset_id="simpleqa",
shuffle=True,
),
gradient_accumulation_steps=1,
max_steps_per_epoch=0,
max_validation_steps=1,
n_epochs=4,
)
algorithm_config = algorithm_config_param.LoraFinetuningConfig( # this config is also currently mandatory but should not be
alpha=1,
apply_lora_to_mlp=True,
apply_lora_to_output=False,
lora_attn_modules=["q_proj"],
rank=1,
type="LoRA",
)
job_uuid = f"test-job{uuid.uuid4()}"
# Example Model
training_model = "ibm-granite/granite-3.3-8b-instruct"
start_time = time.time()
response = client.post_training.supervised_fine_tune(
job_uuid=job_uuid,
logger_config={},
model=training_model,
hyperparam_search_config={},
training_config=training_config,
algorithm_config=algorithm_config,
checkpoint_dir="output",
)
print("Job: ", job_uuid)
# Wait for the job to complete!
while True:
status = client.post_training.job.status(job_uuid=job_uuid)
if not status:
print("Job not found")
break
print(status)
if status.status == "completed":
break
print("Waiting for job to complete...")
time.sleep(5)
end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")
print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```


@ -0,0 +1,163 @@
---
orphan: true
---
# NVIDIA NEMO
[NVIDIA NEMO](https://developer.nvidia.com/nemo-framework) is a remote post training provider for Llama Stack. It provides enterprise-grade fine-tuning capabilities through NVIDIA's NeMo Customizer service.
## Features
- Enterprise-grade fine-tuning capabilities
- Support for LoRA and SFT fine-tuning
- Integration with NVIDIA's NeMo Customizer service
- Support for various NVIDIA-optimized models
- Efficient training with NVIDIA hardware acceleration
## Usage
To use NVIDIA NEMO in your Llama Stack project, follow these steps:
1. Configure your Llama Stack project to use this provider.
2. Set up your NVIDIA API credentials.
3. Kick off a fine-tuning job using the Llama Stack post_training API.
## Setup
You'll need to set the following environment variables:
```bash
export NVIDIA_API_KEY="your-api-key"
export NVIDIA_DATASET_NAMESPACE="default"
export NVIDIA_CUSTOMIZER_URL="your-customizer-url"
export NVIDIA_PROJECT_ID="your-project-id"
export NVIDIA_OUTPUT_MODEL_DIR="your-output-model-dir"
```
## Run Training
You can access the provider and the `supervised_fine_tune` method via the post_training API:
```python
import time
import uuid
from llama_stack_client.types import (
post_training_supervised_fine_tune_params,
algorithm_config_param,
)
def create_http_client():
from llama_stack_client import LlamaStackClient
return LlamaStackClient(base_url="http://localhost:8321")
client = create_http_client()
# Example Dataset
client.datasets.register(
purpose="post-training/messages",
source={
"type": "uri",
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
dataset_id="simpleqa",
)
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
batch_size=8, # Default batch size for NEMO
data_format="instruct",
dataset_id="simpleqa",
shuffle=True,
),
n_epochs=50, # Default epochs for NEMO
optimizer_config=post_training_supervised_fine_tune_params.TrainingConfigOptimizerConfig(
lr=0.0001, # Default learning rate
weight_decay=0.01, # NEMO-specific parameter
),
# NEMO-specific parameters
log_every_n_steps=None,
val_check_interval=0.25,
sequence_packing_enabled=False,
hidden_dropout=None,
attention_dropout=None,
ffn_dropout=None,
)
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
alpha=16, # Default alpha for NEMO
type="LoRA",
)
job_uuid = f"test-job{uuid.uuid4()}"
# Example Model - must be a supported NEMO model
training_model = "meta/llama-3.1-8b-instruct"
start_time = time.time()
response = client.post_training.supervised_fine_tune(
job_uuid=job_uuid,
logger_config={},
model=training_model,
hyperparam_search_config={},
training_config=training_config,
algorithm_config=algorithm_config,
checkpoint_dir="output",
)
print("Job: ", job_uuid)
# Wait for the job to complete!
while True:
status = client.post_training.job.status(job_uuid=job_uuid)
if not status:
print("Job not found")
break
print(status)
if status.status == "completed":
break
print("Waiting for job to complete...")
time.sleep(5)
end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")
print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
## Supported Models
Currently supports the following models:
- meta/llama-3.1-8b-instruct
- meta/llama-3.2-1b-instruct
## Supported Parameters
### TrainingConfig
- n_epochs (default: 50)
- data_config
- optimizer_config
- log_every_n_steps
- val_check_interval (default: 0.25)
- sequence_packing_enabled (default: False)
- hidden_dropout (0.0-1.0)
- attention_dropout (0.0-1.0)
- ffn_dropout (0.0-1.0)
### DataConfig
- dataset_id
- batch_size (default: 8)
### OptimizerConfig
- lr (default: 0.0001)
- weight_decay (default: 0.01)
### LoRA Config
- alpha (default: 16)
- type (must be "LoRA")
Note: Some parameters from the standard Llama Stack API are not supported and will be ignored with a warning.


@ -0,0 +1,125 @@
---
orphan: true
---
# TorchTune
[TorchTune](https://github.com/pytorch/torchtune) is an inline post training provider for Llama Stack. It provides a simple and efficient way to fine-tune language models using PyTorch.
## Features
- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU support and single device capabilities.
- Support for LoRA
## Usage
To use TorchTune in your Llama Stack project, follow these steps:
1. Configure your Llama Stack project to use this provider.
2. Kick off a fine-tuning job using the Llama Stack post_training API.
## Setup
You can access the TorchTune trainer by writing your own yaml pointing to the provider:
```yaml
post_training:
- provider_id: torchtune
provider_type: inline::torchtune
config: {}
```
You can then build and run your own stack with this provider.
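A minimal sketch of that build-and-run flow, assuming a hypothetical template named `experimental-post-training` that includes the `inline::torchtune` provider (substitute your own template or build config):
```bash
llama stack build --template experimental-post-training --image-type venv
llama stack run --image-type venv \
  ~/.llama/distributions/experimental-post-training/experimental-post-training-run.yaml
```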
## Run Training
You can access the provider and the `supervised_fine_tune` method via the post_training API:
```python
import time
import uuid
from llama_stack_client.types import (
post_training_supervised_fine_tune_params,
algorithm_config_param,
)
def create_http_client():
from llama_stack_client import LlamaStackClient
return LlamaStackClient(base_url="http://localhost:8321")
client = create_http_client()
# Example Dataset
client.datasets.register(
purpose="post-training/messages",
source={
"type": "uri",
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
dataset_id="simpleqa",
)
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
batch_size=32,
data_format="instruct",
dataset_id="simpleqa",
shuffle=True,
),
gradient_accumulation_steps=1,
max_steps_per_epoch=0,
max_validation_steps=1,
n_epochs=4,
)
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
alpha=1,
apply_lora_to_mlp=True,
apply_lora_to_output=False,
lora_attn_modules=["q_proj"],
rank=1,
type="LoRA",
)
job_uuid = f"test-job{uuid.uuid4()}"
# Example Model
training_model = "meta-llama/Llama-2-7b-hf"
start_time = time.time()
response = client.post_training.supervised_fine_tune(
job_uuid=job_uuid,
logger_config={},
model=training_model,
hyperparam_search_config={},
training_config=training_config,
algorithm_config=algorithm_config,
checkpoint_dir="output",
)
print("Job: ", job_uuid)
# Wait for the job to complete!
while True:
status = client.post_training.job.status(job_uuid=job_uuid)
if not status:
print("Job not found")
break
print(status)
if status.status == "completed":
break
print("Waiting for job to complete...")
time.sleep(5)
end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")
print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```


@@ -149,6 +149,16 @@ class OpenAIResponseObjectStreamResponseCreated(BaseModel):
     type: Literal["response.created"] = "response.created"
+@json_schema_type
+class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
+    content_index: int
+    delta: str
+    item_id: str
+    output_index: int
+    sequence_number: int
+    type: Literal["response.output_text.delta"] = "response.output_text.delta"
 @json_schema_type
 class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
     response: OpenAIResponseObject
@@ -156,7 +166,9 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
 OpenAIResponseObjectStream = Annotated[
-    OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted,
+    OpenAIResponseObjectStreamResponseCreated
+    | OpenAIResponseObjectStreamResponseOutputTextDelta
+    | OpenAIResponseObjectStreamResponseCompleted,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
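As a usage illustration (not part of this file), here is a minimal sketch of a consumer that folds the new delta events into the final response text; it assumes `stream` is an async iterator yielding members of the `OpenAIResponseObjectStream` union defined above:

```python
from collections.abc import AsyncIterator

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseObjectStream,
    OpenAIResponseObjectStreamResponseCompleted,
    OpenAIResponseObjectStreamResponseOutputTextDelta,
)


async def collect_output_text(stream: AsyncIterator[OpenAIResponseObjectStream]) -> str:
    # Accumulate incremental text until the terminal "response.completed" event arrives.
    chunks: list[str] = []
    async for event in stream:
        if isinstance(event, OpenAIResponseObjectStreamResponseOutputTextDelta):
            chunks.append(event.delta)
        elif isinstance(event, OpenAIResponseObjectStreamResponseCompleted):
            break
    return "".join(chunks)
```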


@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any
from pydantic import BaseModel
from llama_stack.apis.common.content_types import URL
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class RestAPIMethod(Enum):
GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
@json_schema_type
class RestAPIExecutionConfig(BaseModel):
url: URL
method: RestAPIMethod
params: dict[str, Any] | None = None
headers: dict[str, Any] | None = None
body: dict[str, Any] | None = None


@@ -27,18 +27,10 @@ class ToolParameter(BaseModel):
     default: Any | None = None
-@json_schema_type
-class ToolHost(Enum):
-    distribution = "distribution"
-    client = "client"
-    model_context_protocol = "model_context_protocol"
 @json_schema_type
 class Tool(Resource):
     type: Literal[ResourceType.tool] = ResourceType.tool
     toolgroup_id: str
-    tool_host: ToolHost
     description: str
     parameters: list[ToolParameter]
     metadata: dict[str, Any] | None = None


@@ -267,8 +267,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
     if args.run:
         config_dict = yaml.safe_load(run_config.read_text())
         config = parse_and_maybe_upgrade_config(config_dict)
-        if not os.path.exists(config.external_providers_dir):
-            os.makedirs(config.external_providers_dir, exist_ok=True)
+        if config.external_providers_dir and not config.external_providers_dir.exists():
+            config.external_providers_dir.mkdir(exist_ok=True)
         run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
         run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
         run_command(run_args)


@@ -125,7 +125,6 @@ RUN apt-get update && apt-get install -y \
 curl wget telnet git\
 procps psmisc lsof \
 traceroute \
-bubblewrap \
 gcc \
 && rm -rf /var/lib/apt/lists/*


@ -16,7 +16,7 @@ from llama_stack.apis.inspect import (
VersionInfo, VersionInfo,
) )
from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.server.routes import get_all_api_routes
from llama_stack.providers.datatypes import HealthStatus from llama_stack.providers.datatypes import HealthStatus
@ -42,15 +42,15 @@ class DistributionInspectImpl(Inspect):
run_config: StackRunConfig = self.config.run_config run_config: StackRunConfig = self.config.run_config
ret = [] ret = []
all_endpoints = get_all_api_endpoints() all_endpoints = get_all_api_routes()
for api, endpoints in all_endpoints.items(): for api, endpoints in all_endpoints.items():
# Always include provider and inspect APIs, filter others based on run config # Always include provider and inspect APIs, filter others based on run config
if api.value in ["providers", "inspect"]: if api.value in ["providers", "inspect"]:
ret.extend( ret.extend(
[ [
RouteInfo( RouteInfo(
route=e.route, route=e.path,
method=e.method, method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
) )
for e in endpoints for e in endpoints
@ -62,8 +62,8 @@ class DistributionInspectImpl(Inspect):
ret.extend( ret.extend(
[ [
RouteInfo( RouteInfo(
route=e.route, route=e.path,
method=e.method, method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[p.provider_type for p in providers], provider_types=[p.provider_type for p in providers],
) )
for e in endpoints for e in endpoints


@ -37,10 +37,7 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context, request_provider_data_context,
) )
from llama_stack.distribution.resolver import ProviderRegistry from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import ( from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
find_matching_endpoint,
initialize_endpoint_impls,
)
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
construct_stack, construct_stack,
get_stack_run_config_from_template, get_stack_run_config_from_template,
@ -208,7 +205,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
async def initialize(self) -> bool: async def initialize(self) -> bool:
try: try:
self.endpoint_impls = None self.route_impls = None
self.impls = await construct_stack(self.config, self.custom_provider_registry) self.impls = await construct_stack(self.config, self.custom_provider_registry)
except ModuleNotFoundError as _e: except ModuleNotFoundError as _e:
cprint(_e.msg, color="red", file=sys.stderr) cprint(_e.msg, color="red", file=sys.stderr)
@ -254,7 +251,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
safe_config = redact_sensitive_fields(self.config.model_dump()) safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2)) console.print(yaml.dump(safe_config, indent=2))
self.endpoint_impls = initialize_endpoint_impls(self.impls) self.route_impls = initialize_route_impls(self.impls)
return True return True
async def request( async def request(
@ -265,7 +262,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
stream=False, stream=False,
stream_cls=None, stream_cls=None,
): ):
if not self.endpoint_impls: if not self.route_impls:
raise ValueError("Client not initialized") raise ValueError("Client not initialized")
# Create headers with provider data if available # Create headers with provider data if available
@ -296,11 +293,14 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
cast_to: Any, cast_to: Any,
options: Any, options: Any,
): ):
if self.route_impls is None:
raise ValueError("Client not initialized")
path = options.url path = options.url
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls) matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
body |= path_params body |= path_params
body = self._convert_body(path, options.method, body) body = self._convert_body(path, options.method, body)
await start_trace(route, {"__location__": "library_client"}) await start_trace(route, {"__location__": "library_client"})
@ -342,10 +342,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
options: Any, options: Any,
stream_cls: Any, stream_cls: Any,
): ):
if self.route_impls is None:
raise ValueError("Client not initialized")
path = options.url path = options.url
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls) func, path_params, route = find_matching_route(options.method, path, self.route_impls)
body |= path_params body |= path_params
body = self._convert_body(path, options.method, body) body = self._convert_body(path, options.method, body)
@ -397,7 +400,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not body: if not body:
return {} return {}
func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls) if self.route_impls is None:
raise ValueError("Client not initialized")
func, _, _ = find_matching_route(method, path, self.route_impls)
sig = inspect.signature(func) sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature # Strip NOT_GIVENs to use the defaults in signature


@@ -47,7 +47,7 @@ from llama_stack.providers.datatypes import (
     RemoteProviderSpec,
     ScoringFunctionsProtocolPrivate,
     ShieldsProtocolPrivate,
-    ToolsProtocolPrivate,
+    ToolGroupsProtocolPrivate,
     VectorDBsProtocolPrivate,
 )
@@ -93,7 +93,7 @@ def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
 def additional_protocols_map() -> dict[Api, Any]:
     return {
         Api.inference: (ModelsProtocolPrivate, Models, Api.models),
-        Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
+        Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
         Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
         Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
         Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),


@ -11,7 +11,7 @@ from llama_stack.apis.common.content_types import (
InterleavedContent, InterleavedContent,
) )
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse, ListToolsResponse,
RAGDocument, RAGDocument,
RAGQueryConfig, RAGQueryConfig,
RAGQueryResult, RAGQueryResult,
@ -19,7 +19,8 @@ from llama_stack.apis.tools import (
ToolRuntime, ToolRuntime,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core")
@ -28,7 +29,7 @@ class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime): class RagToolImpl(RAGToolRuntime):
def __init__( def __init__(
self, self,
routing_table: RoutingTable, routing_table: ToolGroupsRoutingTable,
) -> None: ) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl") logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table self.routing_table = routing_table
@ -59,7 +60,7 @@ class ToolRuntimeRouter(ToolRuntime):
def __init__( def __init__(
self, self,
routing_table: RoutingTable, routing_table: ToolGroupsRoutingTable,
) -> None: ) -> None:
logger.debug("Initializing ToolRuntimeRouter") logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table self.routing_table = routing_table
@ -86,6 +87,6 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse: ) -> ListToolsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) return await self.routing_table.list_tools(tool_group_id)


@ -46,7 +46,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
elif api == Api.eval: elif api == Api.eval:
return await p.register_benchmark(obj) return await p.register_benchmark(obj)
elif api == Api.tool_runtime: elif api == Api.tool_runtime:
return await p.register_tool(obj) return await p.register_toolgroup(obj)
else: else:
raise ValueError(f"Unknown API {api} for registering object with provider") raise ValueError(f"Unknown API {api} for registering object with provider")
@ -60,7 +60,7 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
elif api == Api.datasetio: elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier) return await p.unregister_dataset(obj.identifier)
elif api == Api.tool_runtime: elif api == Api.tool_runtime:
return await p.unregister_tool(obj.identifier) return await p.unregister_toolgroup(obj.identifier)
else: else:
raise ValueError(f"Unregister not supported for {api}") raise ValueError(f"Unregister not supported for {api}")
@ -136,7 +136,7 @@ class CommonRoutingTableImpl(RoutingTable):
elif isinstance(self, BenchmarksRoutingTable): elif isinstance(self, BenchmarksRoutingTable):
return ("Eval", "benchmark") return ("Eval", "benchmark")
elif isinstance(self, ToolGroupsRoutingTable): elif isinstance(self, ToolGroupsRoutingTable):
return ("Tools", "tool") return ("ToolGroups", "tool_group")
else: else:
raise ValueError("Unknown routing table type") raise ValueError("Unknown routing table type")


@ -7,11 +7,8 @@
from typing import Any from typing import Any
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups, ToolHost from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import ToolGroupWithACL
ToolGroupWithACL,
ToolWithACL,
)
from llama_stack.log import get_logger from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl from .common import CommonRoutingTableImpl
@ -19,12 +16,70 @@ from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core")
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
# handle the funny case like "builtin::rag/knowledge_search"
parts = toolgroup_name_with_maybe_tool_name.split("/")
if len(parts) == 2:
return parts[0]
else:
return None
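# For example: "builtin::rag/knowledge_search" -> "builtin::rag",
# while a bare "builtin::rag" (no tool name suffix) -> None.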
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: toolgroups_to_tools: dict[str, list[Tool]] = {}
tools = await self.get_all_with_type("tool") tool_to_toolgroup: dict[str, str] = {}
# overridden
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
toolgroup_id = parse_toolgroup_from_toolgroup_name_pair(routing_key)
if toolgroup_id: if toolgroup_id:
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id] routing_key = toolgroup_id
return ListToolsResponse(data=tools)
if routing_key in self.tool_to_toolgroup:
routing_key = self.tool_to_toolgroup[routing_key]
return super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
if toolgroup_id:
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
toolgroup_id = group_id
toolgroups = [await self.get_tool_group(toolgroup_id)]
else:
toolgroups = await self.get_all_with_type("tool_group")
all_tools = []
for toolgroup in toolgroups:
if toolgroup.identifier not in self.toolgroups_to_tools:
await self._index_tools(toolgroup)
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
return ListToolsResponse(data=all_tools)
async def _index_tools(self, toolgroup: ToolGroup):
provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
# TODO: kill this Tool vs ToolDef distinction
tooldefs = tooldefs_response.data
tools = []
for t in tooldefs:
tools.append(
Tool(
identifier=t.name,
toolgroup_id=toolgroup.identifier,
description=t.description or "",
parameters=t.parameters or [],
metadata=t.metadata,
provider_id=toolgroup.provider_id,
)
)
self.toolgroups_to_tools[toolgroup.identifier] = tools
for tool in tools:
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
async def list_tool_groups(self) -> ListToolGroupsResponse: async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
@ -36,7 +91,13 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return tool_group return tool_group
async def get_tool(self, tool_name: str) -> Tool: async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name) if tool_name in self.tool_to_toolgroup:
toolgroup_id = self.tool_to_toolgroup[tool_name]
tools = self.toolgroups_to_tools[toolgroup_id]
for tool in tools:
if tool.identifier == tool_name:
return tool
raise ValueError(f"Tool '{tool_name}' not found")
async def register_tool_group( async def register_tool_group(
self, self,
@ -45,53 +106,26 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
mcp_endpoint: URL | None = None, mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None, args: dict[str, Any] | None = None,
) -> None: ) -> None:
tools = [] toolgroup = ToolGroupWithACL(
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
for tool_def in tool_defs.data:
tools.append(
ToolWithACL(
identifier=tool_def.name,
toolgroup_id=toolgroup_id,
description=tool_def.description or "",
parameters=tool_def.parameters or [],
provider_id=provider_id,
provider_resource_id=tool_def.name,
metadata=tool_def.metadata,
tool_host=tool_host,
)
)
for tool in tools:
existing_tool = await self.get_tool(tool.identifier)
# Compare existing and new object if one exists
if existing_tool:
existing_dict = existing_tool.model_dump()
new_dict = tool.model_dump()
if existing_dict != new_dict:
raise ValueError(
f"Object {tool.identifier} already exists in registry. Please use a different identifier."
)
await self.register_object(tool)
await self.dist_registry.register(
ToolGroupWithACL(
identifier=toolgroup_id, identifier=toolgroup_id,
provider_id=provider_id, provider_id=provider_id,
provider_resource_id=toolgroup_id, provider_resource_id=toolgroup_id,
mcp_endpoint=mcp_endpoint, mcp_endpoint=mcp_endpoint,
args=args, args=args,
) )
) await self.register_object(toolgroup)
# ideally, indexing of the tools should not be necessary because anyone using
# the tools should first list the tools and then use them. but there are assumptions
# baked in some of the code and tests right now.
if not toolgroup.mcp_endpoint:
await self._index_tools(toolgroup)
return toolgroup
async def unregister_toolgroup(self, toolgroup_id: str) -> None: async def unregister_toolgroup(self, toolgroup_id: str) -> None:
tool_group = await self.get_tool_group(toolgroup_id) tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None: if tool_group is None:
raise ValueError(f"Tool group {toolgroup_id} not found") raise ValueError(f"Tool group {toolgroup_id} not found")
tools = await self.list_tools(toolgroup_id)
for tool in getattr(tools, "data", []):
await self.unregister_object(tool)
await self.unregister_object(tool_group) await self.unregister_object(tool_group)
async def shutdown(self) -> None: async def shutdown(self) -> None:


@ -6,20 +6,23 @@
import inspect import inspect
import re import re
from collections.abc import Callable
from typing import Any
from pydantic import BaseModel from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.version import LLAMA_STACK_API_VERSION from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.resolver import api_protocol_map from llama_stack.distribution.resolver import api_protocol_map
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
EndpointFunc = Callable[..., Any]
class ApiEndpoint(BaseModel): PathParams = dict[str, str]
route: str RouteInfo = tuple[EndpointFunc, str]
method: str PathImpl = dict[str, RouteInfo]
name: str RouteImpls = dict[str, PathImpl]
descriptive_name: str | None = None RouteMatch = tuple[EndpointFunc, PathParams, str]
def toolgroup_protocol_map(): def toolgroup_protocol_map():
@ -28,13 +31,13 @@ def toolgroup_protocol_map():
} }
def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]: def get_all_api_routes() -> dict[Api, list[Route]]:
apis = {} apis = {}
protocols = api_protocol_map() protocols = api_protocol_map()
toolgroup_protocols = toolgroup_protocol_map() toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items(): for api, protocol in protocols.items():
endpoints = [] routes = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction) protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
# HACK ALERT # HACK ALERT
@ -51,26 +54,28 @@ def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
if not hasattr(method, "__webmethod__"): if not hasattr(method, "__webmethod__"):
continue continue
webmethod = method.__webmethod__ # The __webmethod__ attribute is dynamically added by the @webmethod decorator
route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}" # mypy doesn't know about this dynamic attribute, so we ignore the attr-defined error
if webmethod.method == "GET": webmethod = method.__webmethod__ # type: ignore[attr-defined]
method = "get" path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
elif webmethod.method == "DELETE": if webmethod.method == hdrs.METH_GET:
method = "delete" http_method = hdrs.METH_GET
elif webmethod.method == hdrs.METH_DELETE:
http_method = hdrs.METH_DELETE
else: else:
method = "post" http_method = hdrs.METH_POST
endpoints.append( routes.append(
ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name) Route(path=path, methods=[http_method], name=name, endpoint=None)
) ) # setting endpoint to None since don't use a Router object
apis[api] = endpoints apis[api] = routes
return apis return apis
def initialize_endpoint_impls(impls): def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
endpoints = get_all_api_endpoints() routes = get_all_api_routes()
endpoint_impls = {} route_impls: RouteImpls = {}
def _convert_path_to_regex(path: str) -> str: def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups # Convert {param} to named capture groups
@ -83,29 +88,34 @@ def initialize_endpoint_impls(impls):
return f"^{pattern}$" return f"^{pattern}$"
for api, api_endpoints in endpoints.items(): for api, api_routes in routes.items():
if api not in impls: if api not in impls:
continue continue
for endpoint in api_endpoints: for route in api_routes:
impl = impls[api] impl = impls[api]
func = getattr(impl, endpoint.name) func = getattr(impl, route.name)
if endpoint.method not in endpoint_impls: # Get the first (and typically only) method from the set, filtering out HEAD
endpoint_impls[endpoint.method] = {} available_methods = [m for m in route.methods if m != "HEAD"]
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = ( if not available_methods:
continue # Skip if only HEAD method is available
method = available_methods[0].lower()
if method not in route_impls:
route_impls[method] = {}
route_impls[method][_convert_path_to_regex(route.path)] = (
func, func,
endpoint.descriptive_name or endpoint.route, route.path,
) )
return endpoint_impls return route_impls
def find_matching_endpoint(method, path, endpoint_impls): def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch:
"""Find the matching endpoint implementation for a given method and path. """Find the matching endpoint implementation for a given method and path.
Args: Args:
method: HTTP method (GET, POST, etc.) method: HTTP method (GET, POST, etc.)
path: URL path to match against path: URL path to match against
endpoint_impls: A dictionary of endpoint implementations route_impls: A dictionary of endpoint implementations
Returns: Returns:
A tuple of (endpoint_function, path_params, descriptive_name) A tuple of (endpoint_function, path_params, descriptive_name)
@ -113,7 +123,7 @@ def find_matching_endpoint(method, path, endpoint_impls):
Raises: Raises:
ValueError: If no matching endpoint is found ValueError: If no matching endpoint is found
""" """
impls = endpoint_impls.get(method.lower()) impls = route_impls.get(method.lower())
if not impls: if not impls:
raise ValueError(f"No endpoint found for {path}") raise ValueError(f"No endpoint found for {path}")


@ -6,6 +6,7 @@
import argparse import argparse
import asyncio import asyncio
import functools
import inspect import inspect
import json import json
import os import os
@ -13,6 +14,7 @@ import ssl
import sys import sys
import traceback import traceback
import warnings import warnings
from collections.abc import Callable
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version from importlib.metadata import version as parse_version
from pathlib import Path from pathlib import Path
@ -20,6 +22,7 @@ from typing import Annotated, Any
import rich.pretty import rich.pretty
import yaml import yaml
from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
@ -35,9 +38,10 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context, request_provider_data_context,
) )
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.endpoints import ( from llama_stack.distribution.server.routes import (
find_matching_endpoint, find_matching_route,
initialize_endpoint_impls, get_all_api_routes,
initialize_route_impls,
) )
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
construct_stack, construct_stack,
@ -60,7 +64,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
) )
from .auth import AuthenticationMiddleware from .auth import AuthenticationMiddleware
from .endpoints import get_all_api_endpoints
from .quota import QuotaMiddleware from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -209,8 +212,9 @@ async def log_request_pre_validation(request: Request):
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}") logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
def create_dynamic_typed_route(func: Any, method: str, route: str): def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
async def endpoint(request: Request, **kwargs): @functools.wraps(func)
async def route_handler(request: Request, **kwargs):
# Get auth attributes from the request scope # Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {}) user_attributes = request.scope.get("user_attributes", {})
@ -250,9 +254,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
for param in new_params[1:] for param in new_params[1:]
] ]
endpoint.__signature__ = sig.replace(parameters=new_params) route_handler.__signature__ = sig.replace(parameters=new_params)
return endpoint return route_handler
class TracingMiddleware: class TracingMiddleware:
@ -274,14 +278,14 @@ class TracingMiddleware:
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}") logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
if not hasattr(self, "endpoint_impls"): if not hasattr(self, "route_impls"):
self.endpoint_impls = initialize_endpoint_impls(self.impls) self.route_impls = initialize_route_impls(self.impls)
try: try:
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls) _, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
except ValueError: except ValueError:
# If no matching endpoint is found, pass through to FastAPI # If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI") logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
trace_attributes = {"__location__": "server", "raw_path": path} trace_attributes = {"__location__": "server", "raw_path": path}
@ -423,7 +427,7 @@ def main(args: argparse.Namespace | None = None):
logger.info("Run configuration:") logger.info("Run configuration:")
safe_config = redact_sensitive_fields(config.model_dump()) safe_config = redact_sensitive_fields(config.model_dump())
logger.info(yaml.dump(safe_config, indent=2)) logger.info(yaml.dump(safe_config, indent=2, default_style=None))
app = FastAPI( app = FastAPI(
lifespan=lifespan, lifespan=lifespan,
@ -490,7 +494,7 @@ def main(args: argparse.Namespace | None = None):
else: else:
setup_logger(TelemetryAdapter(TelemetryConfig(), {})) setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
all_endpoints = get_all_api_endpoints() all_routes = get_all_api_routes()
if config.apis: if config.apis:
apis_to_serve = set(config.apis) apis_to_serve = set(config.apis)
@ -508,24 +512,29 @@ def main(args: argparse.Namespace | None = None):
for api_str in apis_to_serve: for api_str in apis_to_serve:
api = Api(api_str) api = Api(api_str)
endpoints = all_endpoints[api] routes = all_routes[api]
impl = impls[api] impl = impls[api]
for endpoint in endpoints: for route in routes:
if not hasattr(impl, endpoint.name): if not hasattr(impl, route.name):
# ideally this should be a typing violation already # ideally this should be a typing violation already
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!") raise ValueError(f"Could not find method {route.name} on {impl}!")
impl_method = getattr(impl, endpoint.name) impl_method = getattr(impl, route.name)
logger.debug(f"{endpoint.method.upper()} {endpoint.route}") # Filter out HEAD method since it's automatically handled by FastAPI for GET routes
available_methods = [m for m in route.methods if m != "HEAD"]
if not available_methods:
raise ValueError(f"No methods found for {route.name} on {impl}")
method = available_methods[0]
logger.debug(f"{method} {route.path}")
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields") warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
getattr(app, endpoint.method)(endpoint.route, response_model=None)( getattr(app, method.lower())(route.path, response_model=None)(
create_dynamic_typed_route( create_dynamic_typed_route(
impl_method, impl_method,
endpoint.method, method.lower(),
endpoint.route, route.path,
) )
) )
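
Aside (not part of the diff): the renamed route_handler is wrapped with functools.wraps and given a replaced __signature__ because FastAPI builds request validation and the OpenAPI schema from the handler's signature. A minimal, self-contained sketch of the pattern, with a hypothetical endpoint and implementation method:

import functools
import inspect

from fastapi import FastAPI, Request

app = FastAPI()


async def get_model(model_id: str) -> dict:
    # stand-in for a provider implementation method
    return {"model_id": model_id}


def create_dynamic_typed_route(func, method: str, path: str):
    @functools.wraps(func)  # preserves func.__name__ / __doc__ for docs and tracing
    async def route_handler(request: Request, **kwargs):
        return await func(**kwargs)

    # Advertise the implementation's parameters (plus Request) so FastAPI
    # validates inputs and documents the route as if func were registered directly.
    sig = inspect.signature(func)
    params = [inspect.Parameter("request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request)]
    params += list(sig.parameters.values())
    route_handler.__signature__ = sig.replace(parameters=params)
    return route_handler


app.get("/models/{model_id}", response_model=None)(
    create_dynamic_typed_route(get_model, "get", "/models/{model_id}")
)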


@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry" REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v8" KEY_VERSION = "v9"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
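
Aside: bumping KEY_VERSION changes the prefix under which registry objects are looked up, so entries written under the old v8 prefix are simply no longer read back. For example (the identifier below is hypothetical):

REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v9"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

print(KEY_FORMAT.format(type="model", identifier="llama3.2:3b"))
# distributions:registry:v9::model:llama3.2:3b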


@ -16,7 +16,7 @@ from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import Tool from llama_stack.apis.tools import ToolGroup
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
@ -74,10 +74,10 @@ class BenchmarksProtocolPrivate(Protocol):
async def register_benchmark(self, benchmark: Benchmark) -> None: ... async def register_benchmark(self, benchmark: Benchmark) -> None: ...
class ToolsProtocolPrivate(Protocol): class ToolGroupsProtocolPrivate(Protocol):
async def register_tool(self, tool: Tool) -> None: ... async def register_toolgroup(self, toolgroup: ToolGroup) -> None: ...
async def unregister_tool(self, tool_id: str) -> None: ... async def unregister_toolgroup(self, toolgroup_id: str) -> None: ...
@json_schema_type @json_schema_type
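
Aside: providers that used to implement ToolsProtocolPrivate now register at tool-group granularity. A minimal hypothetical implementation of the renamed protocol methods (not taken from this diff):

from llama_stack.apis.tools import ToolGroup


class ExampleToolRuntimeImpl:
    """Hypothetical provider satisfying ToolGroupsProtocolPrivate."""

    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
        # e.g. remember the group's identifier -> endpoint/config mapping
        pass

    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        # drop any cached state for this group
        return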


@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
import time
import uuid import uuid
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any, cast from typing import Any, cast
@ -29,10 +30,12 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStream, OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted, OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated, OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseOutputTextDelta,
OpenAIResponseOutput, OpenAIResponseOutput,
OpenAIResponseOutputMessageContent, OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText, OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall, OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall, OpenAIResponseOutputMessageWebSearchToolCall,
) )
from llama_stack.apis.inference.inference import ( from llama_stack.apis.inference.inference import (
@ -255,110 +258,14 @@ class OpenAIResponsesImpl:
""" """
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order) return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
async def create_openai_response( async def _process_response_choices(
self, self,
input: str | list[OpenAIResponseInput], chat_response: OpenAIChatCompletion,
model: str, ctx: ChatCompletionContext,
instructions: str | None = None, tools: list[OpenAIResponseInputTool] | None,
previous_response_id: str | None = None, ) -> list[OpenAIResponseOutput]:
store: bool | None = True, """Handle tool execution and response message creation."""
stream: bool | None = False,
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
):
output_messages: list[OpenAIResponseOutput] = [] output_messages: list[OpenAIResponseOutput] = []
stream = False if stream is None else stream
# Huge TODO: we need to run this in a loop, until morale improves
# Create context to run "chat completion"
input = await self._prepend_previous_response(input, previous_response_id)
messages = await _convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
chat_tools, mcp_tool_to_server, mcp_list_message = (
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
)
if mcp_list_message:
output_messages.append(mcp_list_message)
ctx = ChatCompletionContext(
model=model,
messages=messages,
tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
stream=stream,
temperature=temperature,
)
# Run inference
chat_response = await self.inference_api.openai_chat_completion(
model=model,
messages=messages,
tools=chat_tools,
stream=stream,
temperature=temperature,
)
# Collect output
if stream:
# TODO: refactor this into a separate method that handles streaming
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
# TODO: these chunk_ fields are hacky and only take the last chunk into account
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
async for chunk in chat_response:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# TODO: this only works for text content
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks, using their index as the aggregation key
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
response_tool_call.function.arguments += tool_call.function.arguments
else:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
# Ensure we don't have any empty type field in the tool call dict.
# The OpenAI client used by providers often returns a type=None here.
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Convert the dict of tool calls by index to a list of tool calls to pass back in our response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
)
chat_response = OpenAIChatCompletion(
id=chat_response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=chunk_finish_reason,
index=0,
)
],
created=chunk_created,
model=chunk_model,
)
else:
# dump and reload to map to our pydantic types
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
# Execute tool calls if any # Execute tool calls if any
for choice in chat_response.choices: for choice in chat_response.choices:
if choice.message.tool_calls and tools: if choice.message.tool_calls and tools:
@ -380,19 +287,13 @@ class OpenAIResponsesImpl:
else: else:
output_messages.append(await _convert_chat_choice_to_response_message(choice)) output_messages.append(await _convert_chat_choice_to_response_message(choice))
# Create response object return output_messages
response = OpenAIResponseObject(
created_at=chat_response.created,
id=f"resp-{uuid.uuid4()}",
model=model,
object="response",
status="completed",
output=output_messages,
)
logger.debug(f"OpenAI Responses response: {response}")
# Store response if requested async def _store_response(
if store: self,
response: OpenAIResponseObject,
input: str | list[OpenAIResponseInput],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}" new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str): if isinstance(input, str):
# synthesize a message from the input string # synthesize a message from the input string
@ -421,17 +322,233 @@ class OpenAIResponsesImpl:
input=input_items_data, input=input_items_data,
) )
async def create_openai_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
):
stream = False if stream is None else stream
output_messages: list[OpenAIResponseOutput] = []
# Input preprocessing
input = await self._prepend_previous_response(input, previous_response_id)
messages = await _convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
# Tool setup
chat_tools, mcp_tool_to_server, mcp_list_message = (
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
)
if mcp_list_message:
output_messages.append(mcp_list_message)
ctx = ChatCompletionContext(
model=model,
messages=messages,
tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
stream=stream,
temperature=temperature,
)
inference_result = await self.inference_api.openai_chat_completion(
model=model,
messages=messages,
tools=chat_tools,
stream=stream,
temperature=temperature,
)
if stream: if stream:
return self._create_streaming_response(
inference_result=inference_result,
ctx=ctx,
output_messages=output_messages,
input=input,
model=model,
store=store,
tools=tools,
)
else:
return await self._create_non_streaming_response(
inference_result=inference_result,
ctx=ctx,
output_messages=output_messages,
input=input,
model=model,
store=store,
tools=tools,
)
async def async_response() -> AsyncIterator[OpenAIResponseObjectStream]: async def _create_non_streaming_response(
# TODO: response created should actually get emitted much earlier in the process self,
yield OpenAIResponseObjectStreamResponseCreated(response=response) inference_result: Any,
yield OpenAIResponseObjectStreamResponseCompleted(response=response) ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
tools: list[OpenAIResponseInputTool] | None,
) -> OpenAIResponseObject:
chat_response = OpenAIChatCompletion(**inference_result.model_dump())
return async_response() # Process response choices (tool execution and message creation)
output_messages.extend(
await self._process_response_choices(
chat_response=chat_response,
ctx=ctx,
tools=tools,
)
)
response = OpenAIResponseObject(
created_at=chat_response.created,
id=f"resp-{uuid.uuid4()}",
model=model,
object="response",
status="completed",
output=output_messages,
)
logger.debug(f"OpenAI Responses response: {response}")
# Store response if requested
if store:
await self._store_response(
response=response,
input=input,
)
return response return response
async def _create_streaming_response(
self,
inference_result: Any,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
tools: list[OpenAIResponseInputTool] | None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Create initial response and emit response.created immediately
response_id = f"resp-{uuid.uuid4()}"
created_at = int(time.time())
initial_response = OpenAIResponseObject(
created_at=created_at,
id=response_id,
model=model,
object="response",
status="in_progress",
output=output_messages.copy(),
)
# Emit response.created immediately
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
# For streaming, inference_result is an async iterator of chunks
# Stream chunks and emit delta events as they arrive
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
sequence_number = 0
# Create a placeholder message item for delta events
message_item_id = f"msg_{uuid.uuid4()}"
async for chunk in inference_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=0,
delta=chunk_choice.delta.content,
item_id=message_item_id,
output_index=0,
sequence_number=sequence_number,
)
# Collect content for final response
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks, using their index as the aggregation key
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
response_tool_call.function.arguments += tool_call.function.arguments
else:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Convert collected chunks to complete response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
)
chat_response_obj = OpenAIChatCompletion(
id=chat_response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=chunk_finish_reason,
index=0,
)
],
created=chunk_created,
model=chunk_model,
)
# Process response choices (tool execution and message creation)
output_messages.extend(
await self._process_response_choices(
chat_response=chat_response_obj,
ctx=ctx,
tools=tools,
)
)
# Create final response
final_response = OpenAIResponseObject(
created_at=created_at,
id=response_id,
model=model,
object="response",
status="completed",
output=output_messages,
)
if store:
await self._store_response(
response=final_response,
input=input,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
async def _convert_response_tools_to_chat_tools( async def _convert_response_tools_to_chat_tools(
self, tools: list[OpenAIResponseInputTool] self, tools: list[OpenAIResponseInputTool]
) -> tuple[ ) -> tuple[
@ -441,7 +558,6 @@ class OpenAIResponsesImpl:
]: ]:
from llama_stack.apis.agents.openai_responses import ( from llama_stack.apis.agents.openai_responses import (
MCPListToolsTool, MCPListToolsTool,
OpenAIResponseOutputMessageMCPListTools,
) )
from llama_stack.apis.tools.tools import Tool from llama_stack.apis.tools.tools import Tool
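
Aside: with the streaming/non-streaming split above, stream=True now yields a response.created event, incremental output-text deltas, and a final response.completed event instead of only created/completed. A rough consumer-side sketch (the helper name and prompt are hypothetical; impl is an OpenAIResponsesImpl instance):

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseObjectStreamResponseCompleted,
    OpenAIResponseObjectStreamResponseOutputTextDelta,
)


async def collect_streamed_text(impl, model: str, prompt: str) -> str:
    # create_openai_response(stream=True) returns an async iterator of stream events
    stream = await impl.create_openai_response(input=prompt, model=model, stream=True)
    parts: list[str] = []
    async for event in stream:
        if isinstance(event, OpenAIResponseObjectStreamResponseOutputTextDelta):
            parts.append(event.delta)  # incremental text for the in-progress message
        elif isinstance(event, OpenAIResponseObjectStreamResponseCompleted):
            # event.response is the final OpenAIResponseObject (status="completed")
            break
    return "".join(parts)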


@ -75,6 +75,8 @@ class PromptGuardShield:
self.temperature = temperature self.temperature = temperature
self.threshold = threshold self.threshold = threshold
self.device = "cpu"
if torch.cuda.is_available():
self.device = "cuda" self.device = "cuda"
# load model and tokenizer # load model and tokenizer
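
Aside: the shield previously assumed CUDA; it now records a device and falls back to CPU when no GPU is available. Roughly the pattern in use (the checkpoint name below is illustrative, not taken from this diff):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "meta-llama/Prompt-Guard-86M"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)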


@ -25,14 +25,14 @@ from llama_stack.apis.tools import (
RAGQueryConfig, RAGQueryConfig,
RAGQueryResult, RAGQueryResult,
RAGToolRuntime, RAGToolRuntime,
Tool,
ToolDef, ToolDef,
ToolGroup,
ToolInvocationResult, ToolInvocationResult,
ToolParameter, ToolParameter,
ToolRuntime, ToolRuntime,
) )
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
content_from_doc, content_from_doc,
@ -49,7 +49,7 @@ def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length)) return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
def __init__( def __init__(
self, self,
config: RagToolRuntimeConfig, config: RagToolRuntimeConfig,
@ -66,10 +66,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
async def shutdown(self): async def shutdown(self):
pass pass
async def register_tool(self, tool: Tool) -> None: async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass pass
async def unregister_tool(self, tool_id: str) -> None: async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return return
async def insert( async def insert(


@ -19,10 +19,10 @@ def available_providers() -> list[ProviderSpec]:
api=Api.agents, api=Api.agents,
provider_type="inline::meta-reference", provider_type="inline::meta-reference",
pip_packages=[ pip_packages=[
"matplotlib", # "matplotlib",
"pillow", # "pillow",
"pandas", # "pandas",
"scikit-learn", # "scikit-learn",
] ]
+ kvstore_dependencies(), + kvstore_dependencies(),
module="llama_stack.providers.inline.agents.meta_reference", module="llama_stack.providers.inline.agents.meta_reference",


@ -13,7 +13,7 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec( InlineProviderSpec(
api=Api.eval, api=Api.eval,
provider_type="inline::meta-reference", provider_type="inline::meta-reference",
pip_packages=["tree_sitter", "pythainlp", "langdetect", "emoji", "nltk"], # pip_packages=["tree_sitter", "pythainlp", "langdetect", "emoji", "nltk"],
module="llama_stack.providers.inline.eval.meta_reference", module="llama_stack.providers.inline.eval.meta_reference",
config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig", config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig",
api_dependencies=[ api_dependencies=[


@ -20,16 +20,16 @@ def available_providers() -> list[ProviderSpec]:
api=Api.tool_runtime, api=Api.tool_runtime,
provider_type="inline::rag-runtime", provider_type="inline::rag-runtime",
pip_packages=[ pip_packages=[
"blobfile", # "blobfile",
"chardet", # "chardet",
"pypdf", # "pypdf",
"tqdm", # "tqdm",
"numpy", # "numpy",
"scikit-learn", # "scikit-learn",
"scipy", # "scipy",
"nltk", # "nltk",
"sentencepiece", # "sentencepiece",
"transformers", # "transformers",
], ],
module="llama_stack.providers.inline.tool_runtime.rag", module="llama_stack.providers.inline.tool_runtime.rag",
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig", config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",


@ -4,8 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from pathlib import Path
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, field_validator
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
@ -24,11 +25,27 @@ class VLLMInferenceAdapterConfig(BaseModel):
default="fake", default="fake",
description="The API token", description="The API token",
) )
tls_verify: bool = Field( tls_verify: bool | str = Field(
default=True, default=True,
description="Whether to verify TLS certificates", description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
) )
@field_validator("tls_verify")
@classmethod
def validate_tls_verify(cls, v):
if isinstance(v, str):
# Check if it's a boolean string
if v.lower() in ("true", "false"):
return v.lower() == "true"
# Otherwise, treat it as a cert path
cert_path = Path(v).expanduser().resolve()
if not cert_path.exists():
raise ValueError(f"TLS certificate file does not exist: {v}")
if not cert_path.is_file():
raise ValueError(f"TLS certificate path is not a file: {v}")
return v
return v
@classmethod @classmethod
def sample_run_config( def sample_run_config(
cls, cls,


@ -313,7 +313,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
return AsyncOpenAI( return AsyncOpenAI(
base_url=self.config.url, base_url=self.config.url,
api_key=self.config.api_token, api_key=self.config.api_token,
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False), http_client=httpx.AsyncClient(verify=self.config.tls_verify),
) )
async def completion( async def completion(
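
Aside: this simplification works because httpx's verify argument accepts either a boolean or a path to a CA bundle, so the (already validated) tls_verify value can be passed through unchanged. A minimal sketch with a hypothetical CA path:

import httpx

# From the run config this may be True, False, or a CA bundle path.
tls_verify: bool | str = "/etc/ssl/certs/my-ca.pem"  # hypothetical value
# verify=False disables certificate checks; a string is treated as a CA bundle path.
client = httpx.AsyncClient(verify=tls_verify)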


@ -12,19 +12,19 @@ import httpx
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse, ListToolDefsResponse,
Tool,
ToolDef, ToolDef,
ToolGroup,
ToolInvocationResult, ToolInvocationResult,
ToolParameter, ToolParameter,
ToolRuntime, ToolRuntime,
) )
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import BingSearchToolConfig from .config import BingSearchToolConfig
class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: BingSearchToolConfig): def __init__(self, config: BingSearchToolConfig):
self.config = config self.config = config
self.url = "https://api.bing.microsoft.com/v7.0/search" self.url = "https://api.bing.microsoft.com/v7.0/search"
@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
async def initialize(self): async def initialize(self):
pass pass
async def register_tool(self, tool: Tool) -> None: async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass pass
async def unregister_tool(self, tool_id: str) -> None: async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return return
def _get_api_key(self) -> str: def _get_api_key(self) -> str:


@ -11,30 +11,30 @@ import httpx
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse, ListToolDefsResponse,
Tool,
ToolDef, ToolDef,
ToolGroup,
ToolInvocationResult, ToolInvocationResult,
ToolParameter, ToolParameter,
ToolRuntime, ToolRuntime,
) )
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import BuiltinTool from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import BraveSearchToolConfig from .config import BraveSearchToolConfig
class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: BraveSearchToolConfig): def __init__(self, config: BraveSearchToolConfig):
self.config = config self.config = config
async def initialize(self): async def initialize(self):
pass pass
async def register_tool(self, tool: Tool) -> None: async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass pass
async def unregister_tool(self, tool_id: str) -> None: async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return return
def _get_api_key(self) -> str: def _get_api_key(self) -> str:


@ -10,8 +10,8 @@ from pydantic import BaseModel
class MCPProviderDataValidator(BaseModel): class MCPProviderDataValidator(BaseModel):
# mcp_endpoint => list of headers to send # mcp_endpoint => dict of headers to send
mcp_headers: dict[str, list[str]] | None = None mcp_headers: dict[str, dict[str, str]] | None = None
class MCPProviderConfig(BaseModel): class MCPProviderConfig(BaseModel):
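
Aside: provider data for MCP now maps each endpoint to a headers dict instead of a list of "Key: Value" strings, which is what makes convert_header_list_to_dict (deleted further down) unnecessary. For example (endpoint and token are hypothetical):

# before: {"http://localhost:8000/sse": ["Authorization: Bearer abc123"]}
# after:
provider_data = {
    "mcp_headers": {
        "http://localhost:8000/sse": {"Authorization": "Bearer abc123"},
    }
}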


@ -11,26 +11,33 @@ from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datatypes import Api from llama_stack.apis.datatypes import Api
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse, ListToolDefsResponse,
ToolGroup,
ToolInvocationResult, ToolInvocationResult,
ToolRuntime, ToolRuntime,
) )
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools
from .config import MCPProviderConfig from .config import MCPProviderConfig
logger = get_logger(__name__, category="tools") logger = get_logger(__name__, category="tools")
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]): def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
self.config = config self.config = config
async def initialize(self): async def initialize(self):
pass pass
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse: ) -> ListToolDefsResponse:
@ -62,5 +69,5 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
for uri, values in provider_data.mcp_headers.items(): for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri): if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue continue
headers.update(convert_header_list_to_dict(values)) headers.update(values)
return headers return headers


@ -12,29 +12,29 @@ import httpx
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse, ListToolDefsResponse,
Tool,
ToolDef, ToolDef,
ToolGroup,
ToolInvocationResult, ToolInvocationResult,
ToolParameter, ToolParameter,
ToolRuntime, ToolRuntime,
) )
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import TavilySearchToolConfig from .config import TavilySearchToolConfig
class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: TavilySearchToolConfig): def __init__(self, config: TavilySearchToolConfig):
self.config = config self.config = config
async def initialize(self): async def initialize(self):
pass pass
async def register_tool(self, tool: Tool) -> None: async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass pass
async def unregister_tool(self, tool_id: str) -> None: async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return return
def _get_api_key(self) -> str: def _get_api_key(self) -> str:


@ -12,19 +12,19 @@ import httpx
from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse, ListToolDefsResponse,
Tool,
ToolDef, ToolDef,
ToolGroup,
ToolInvocationResult, ToolInvocationResult,
ToolParameter, ToolParameter,
ToolRuntime, ToolRuntime,
) )
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import WolframAlphaToolConfig from .config import WolframAlphaToolConfig
class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: WolframAlphaToolConfig): def __init__(self, config: WolframAlphaToolConfig):
self.config = config self.config = config
self.url = "https://api.wolframalpha.com/v2/query" self.url = "https://api.wolframalpha.com/v2/query"
@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
async def initialize(self): async def initialize(self):
pass pass
async def register_tool(self, tool: Tool) -> None: async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass pass
async def unregister_tool(self, tool_id: str) -> None: async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return return
def _get_api_key(self) -> str: def _get_api_key(self) -> str:


@ -1402,9 +1402,8 @@ class OpenAIChatCompletionToLlamaStackMixin:
outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]], outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
): ):
id = f"chatcmpl-{uuid.uuid4()}" id = f"chatcmpl-{uuid.uuid4()}"
for outstanding_response in outstanding_responses: for i, outstanding_response in enumerate(outstanding_responses):
response = await outstanding_response response = await outstanding_response
i = 0
async for chunk in response: async for chunk in response:
event = chunk.event event = chunk.event
finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason) finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
@ -1459,7 +1458,6 @@ class OpenAIChatCompletionToLlamaStackMixin:
model=model, model=model,
object="chat.completion.chunk", object="chat.completion.chunk",
) )
i = i + 1
async def _process_non_stream_response( async def _process_non_stream_response(
self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]] self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]


@ -51,16 +51,6 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
raise raise
def convert_header_list_to_dict(header_list: list[str]) -> dict[str, str]:
headers = {}
for header in header_list:
parts = header.split(":")
if len(parts) == 2:
k, v = parts
headers[k.strip()] = v.strip()
return headers
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse: async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
tools = [] tools = []
async with sse_client_wrapper(endpoint, headers) as session: async with sse_client_wrapper(endpoint, headers) as session:


@ -1,855 +0,0 @@
{
"bedrock": [
"aiosqlite",
"autoevals",
"blobfile",
"boto3",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"cerebras": [
"aiosqlite",
"autoevals",
"blobfile",
"cerebras_cloud_sdk",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"ci-tests": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"dell": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"fireworks": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"groq": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"hf-endpoint": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"hf-serverless": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"llama_api": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"meta-reference-gpu": [
"accelerate",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fairscale",
"faiss-cpu",
"fastapi",
"fbgemm-gpu-genai==1.1.2",
"fire",
"httpx",
"langdetect",
"lm-format-enforcer",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentence-transformers",
"sentencepiece",
"sqlalchemy[asyncio]",
"torch",
"torchao==0.8.0",
"torchvision",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"zmq"
],
"nvidia": [
"aiohttp",
"aiosqlite",
"blobfile",
"chardet",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"uvicorn"
],
"ollama": [
"aiohttp",
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"ollama",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"peft",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"requests",
"sqlalchemy[asyncio]",
"tqdm",
"tree_sitter",
"trl",
"uvicorn"
],
"open-benchmark": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"sqlite-vec",
"together",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"passthrough": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"remote-vllm": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"sambanova": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"starter": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"tgi": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"together": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"together",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"verification": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"vllm-gpu": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"vllm",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"watsonx": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"ibm_watson_machine_learning",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
]
}


@ -25,23 +25,7 @@ distribution_spec:
- inline::rag-runtime - inline::rag-runtime
- remote::model-context-protocol - remote::model-context-protocol
- remote::wolfram-alpha - remote::wolfram-alpha
metadata_store: image_type: conda
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/inference_store.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: remote::ollama
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: remote::ollama
provider_model_id: all-minilm:latest
model_type: embedding
image_type: container
additional_pip_packages: additional_pip_packages:
- sqlalchemy[asyncio] - sqlalchemy[asyncio]
- blobfile


@ -13,8 +13,8 @@ from llama_stack.distribution.datatypes import (
ShieldInput, ShieldInput,
ToolGroupInput, ToolGroupInput,
) )
from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig #from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig #from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@ -32,7 +32,6 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [ "tool_runtime": [
"remote::brave-search", "remote::brave-search",
"remote::tavily-search", "remote::tavily-search",
"inline::rag-runtime",
"remote::model-context-protocol", "remote::model-context-protocol",
"remote::wolfram-alpha", "remote::wolfram-alpha",
], ],
@ -43,11 +42,11 @@ def get_distribution_template() -> DistributionTemplate:
provider_type="remote::ollama", provider_type="remote::ollama",
config=OllamaImplConfig.sample_run_config(), config=OllamaImplConfig.sample_run_config(),
) )
vector_io_provider_faiss = Provider( #vector_io_provider_faiss = Provider(
provider_id="faiss", # provider_id="faiss",
provider_type="inline::faiss", # provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"), # config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
) #)
inference_model = ModelInput( inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}", model_id="${env.INFERENCE_MODEL}",
provider_id="ollama", provider_id="ollama",
@ -70,10 +69,6 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::websearch", toolgroup_id="builtin::websearch",
provider_id="tavily-search", provider_id="tavily-search",
), ),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
ToolGroupInput( ToolGroupInput(
toolgroup_id="builtin::wolfram_alpha", toolgroup_id="builtin::wolfram_alpha",
provider_id="wolfram-alpha", provider_id="wolfram-alpha",


@ -24,6 +24,10 @@ providers:
type: sqlite type: sqlite
namespace: null namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
- provider_id: chromadb
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:http://host.docker.internal:8000}
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard
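
Aside: the ${env.CHROMADB_URL:http://host.docker.internal:8000} value follows the run-config convention of "environment variable, with a fallback default after the colon". A rough sketch of that resolution rule, for illustration only (not the stack's actual implementation):

import os
import re


def resolve_env_placeholder(value: str) -> str:
    # Matches ${env.VAR} or ${env.VAR:default}; the default may itself contain ':'
    match = re.fullmatch(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*)(?::(.*))?\}", value)
    if not match:
        return value
    var, default = match.group(1), match.group(2)
    return os.environ.get(var, default if default is not None else "")


print(resolve_env_placeholder("${env.CHROMADB_URL:http://host.docker.internal:8000}"))
# -> value of CHROMADB_URL if set, else http://host.docker.internal:8000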


@ -2,9 +2,9 @@
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { useParams } from "next/navigation"; import { useParams } from "next/navigation";
import LlamaStackClient from "llama-stack-client";
import { ChatCompletion } from "@/lib/types"; import { ChatCompletion } from "@/lib/types";
import { ChatCompletionDetailView } from "@/components/chat-completions/chat-completion-detail"; import { ChatCompletionDetailView } from "@/components/chat-completions/chat-completion-detail";
import { client } from "@/lib/client";
export default function ChatCompletionDetailPage() { export default function ChatCompletionDetailPage() {
const params = useParams(); const params = useParams();
@ -22,10 +22,6 @@ export default function ChatCompletionDetailPage() {
return; return;
} }
const client = new LlamaStackClient({
baseURL: process.env.NEXT_PUBLIC_LLAMA_STACK_BASE_URL,
});
const fetchCompletionDetail = async () => { const fetchCompletionDetail = async () => {
setIsLoading(true); setIsLoading(true);
setError(null); setError(null);


@ -1,45 +1,19 @@
"use client"; "use client";
import React from "react"; import React from "react";
import { usePathname, useParams } from "next/navigation"; import LogsLayout from "@/components/layout/logs-layout";
import {
PageBreadcrumb,
BreadcrumbSegment,
} from "@/components/layout/page-breadcrumb";
import { truncateText } from "@/lib/truncate-text";
export default function ChatCompletionsLayout({ export default function ChatCompletionsLayout({
children, children,
}: { }: {
children: React.ReactNode; children: React.ReactNode;
}) { }) {
const pathname = usePathname();
const params = useParams();
let segments: BreadcrumbSegment[] = [];
// Default for /logs/chat-completions
if (pathname === "/logs/chat-completions") {
segments = [{ label: "Chat Completions" }];
}
// For /logs/chat-completions/[id]
const idParam = params?.id;
if (idParam && typeof idParam === "string") {
segments = [
{ label: "Chat Completions", href: "/logs/chat-completions" },
{ label: `Details (${truncateText(idParam, 20)})` },
];
}
return ( return (
<div className="container mx-auto p-4"> <LogsLayout
<> sectionLabel="Chat Completions"
{segments.length > 0 && ( basePath="/logs/chat-completions"
<PageBreadcrumb segments={segments} className="mb-4" /> >
)}
{children} {children}
</> </LogsLayout>
</div>
); );
} }


@ -1,9 +1,9 @@
"use client"; "use client";
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import LlamaStackClient from "llama-stack-client";
import { ChatCompletion } from "@/lib/types"; import { ChatCompletion } from "@/lib/types";
import { ChatCompletionsTable } from "@/components/chat-completions/chat-completion-table"; import { ChatCompletionsTable } from "@/components/chat-completions/chat-completions-table";
import { client } from "@/lib/client";
export default function ChatCompletionsPage() { export default function ChatCompletionsPage() {
const [completions, setCompletions] = useState<ChatCompletion[]>([]); const [completions, setCompletions] = useState<ChatCompletion[]>([]);
@ -11,9 +11,6 @@ export default function ChatCompletionsPage() {
const [error, setError] = useState<Error | null>(null); const [error, setError] = useState<Error | null>(null);
useEffect(() => { useEffect(() => {
const client = new LlamaStackClient({
baseURL: process.env.NEXT_PUBLIC_LLAMA_STACK_BASE_URL,
});
const fetchCompletions = async () => { const fetchCompletions = async () => {
setIsLoading(true); setIsLoading(true);
setError(null); setError(null);
@ -21,7 +18,7 @@ export default function ChatCompletionsPage() {
const response = await client.chat.completions.list(); const response = await client.chat.completions.list();
const data = Array.isArray(response) const data = Array.isArray(response)
? response ? response
: (response as any).data; : (response as { data: ChatCompletion[] }).data;
if (Array.isArray(data)) { if (Array.isArray(data)) {
setCompletions(data); setCompletions(data);
@ -46,7 +43,7 @@ export default function ChatCompletionsPage() {
return ( return (
<ChatCompletionsTable <ChatCompletionsTable
completions={completions} data={completions}
isLoading={isLoading} isLoading={isLoading}
error={error} error={error}
/> />


@ -0,0 +1,125 @@
"use client";
import { useEffect, useState } from "react";
import { useParams } from "next/navigation";
import type { ResponseObject } from "llama-stack-client/resources/responses/responses";
import { OpenAIResponse, InputItemListResponse } from "@/lib/types";
import { ResponseDetailView } from "@/components/responses/responses-detail";
import { client } from "@/lib/client";
export default function ResponseDetailPage() {
const params = useParams();
const id = params.id as string;
const [responseDetail, setResponseDetail] = useState<OpenAIResponse | null>(
null,
);
const [inputItems, setInputItems] = useState<InputItemListResponse | null>(
null,
);
const [isLoading, setIsLoading] = useState<boolean>(true);
const [isLoadingInputItems, setIsLoadingInputItems] = useState<boolean>(true);
const [error, setError] = useState<Error | null>(null);
const [inputItemsError, setInputItemsError] = useState<Error | null>(null);
// Helper function to convert ResponseObject to OpenAIResponse
const convertResponseObject = (
responseData: ResponseObject,
): OpenAIResponse => {
return {
id: responseData.id,
created_at: responseData.created_at,
model: responseData.model,
object: responseData.object,
status: responseData.status,
output: responseData.output as OpenAIResponse["output"],
input: [], // ResponseObject doesn't include input; component uses inputItems prop instead
error: responseData.error,
parallel_tool_calls: responseData.parallel_tool_calls,
previous_response_id: responseData.previous_response_id,
temperature: responseData.temperature,
top_p: responseData.top_p,
truncation: responseData.truncation,
user: responseData.user,
};
};
useEffect(() => {
if (!id) {
setError(new Error("Response ID is missing."));
setIsLoading(false);
return;
}
const fetchResponseDetail = async () => {
setIsLoading(true);
setIsLoadingInputItems(true);
setError(null);
setInputItemsError(null);
setResponseDetail(null);
setInputItems(null);
try {
const [responseResult, inputItemsResult] = await Promise.allSettled([
client.responses.retrieve(id),
client.responses.inputItems.list(id, { order: "asc" }),
]);
// Handle response detail result
if (responseResult.status === "fulfilled") {
const convertedResponse = convertResponseObject(responseResult.value);
setResponseDetail(convertedResponse);
} else {
console.error(
`Error fetching response detail for ID ${id}:`,
responseResult.reason,
);
setError(
responseResult.reason instanceof Error
? responseResult.reason
: new Error("Failed to fetch response detail"),
);
}
// Handle input items result
if (inputItemsResult.status === "fulfilled") {
const inputItemsData =
inputItemsResult.value as unknown as InputItemListResponse;
setInputItems(inputItemsData);
} else {
console.error(
`Error fetching input items for response ID ${id}:`,
inputItemsResult.reason,
);
setInputItemsError(
inputItemsResult.reason instanceof Error
? inputItemsResult.reason
: new Error("Failed to fetch input items"),
);
}
} catch (err) {
console.error(`Unexpected error fetching data for ID ${id}:`, err);
setError(
err instanceof Error ? err : new Error("Unexpected error occurred"),
);
} finally {
setIsLoading(false);
setIsLoadingInputItems(false);
}
};
fetchResponseDetail();
}, [id]);
return (
<ResponseDetailView
response={responseDetail}
inputItems={inputItems}
isLoading={isLoading}
isLoadingInputItems={isLoadingInputItems}
error={error}
inputItemsError={inputItemsError}
id={id}
/>
);
}
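
The page deliberately uses Promise.allSettled so a failure in either request is reported on its own instead of rejecting the combined fetch. A minimal standalone sketch of the same pattern (endpoint paths and names are illustrative, not part of this commit):

// Sketch only: two independent fetches whose failures are recorded separately.
async function fetchDetailAndItems(id: string) {
  const [detail, items] = await Promise.allSettled([
    fetch(`/api/responses/${id}`).then((r) => r.json()), // hypothetical endpoint
    fetch(`/api/responses/${id}/input_items`).then((r) => r.json()), // hypothetical endpoint
  ]);
  return {
    detail: detail.status === "fulfilled" ? detail.value : null,
    detailError: detail.status === "rejected" ? detail.reason : null,
    items: items.status === "fulfilled" ? items.value : null,
    itemsError: items.status === "rejected" ? items.reason : null,
  };
}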

@ -0,0 +1,16 @@
"use client";
import React from "react";
import LogsLayout from "@/components/layout/logs-layout";
export default function ResponsesLayout({
children,
}: {
children: React.ReactNode;
}) {
return (
<LogsLayout sectionLabel="Responses" basePath="/logs/responses">
{children}
</LogsLayout>
);
}

@ -1,7 +1,66 @@
-export default function Responses() {
-  return (
-    <div>
-      <h1>Under Construction</h1>
-    </div>
-  );
-}
+"use client";
+
+import { useEffect, useState } from "react";
+import type { ResponseListResponse } from "llama-stack-client/resources/responses/responses";
+import { OpenAIResponse } from "@/lib/types";
+import { ResponsesTable } from "@/components/responses/responses-table";
+import { client } from "@/lib/client";
+
+export default function ResponsesPage() {
+  const [responses, setResponses] = useState<OpenAIResponse[]>([]);
+  const [isLoading, setIsLoading] = useState<boolean>(true);
+  const [error, setError] = useState<Error | null>(null);
+
+  // Helper function to convert ResponseListResponse.Data to OpenAIResponse
+  const convertResponseListData = (
+    responseData: ResponseListResponse.Data,
+  ): OpenAIResponse => {
+    return {
+      id: responseData.id,
+      created_at: responseData.created_at,
+      model: responseData.model,
+      object: responseData.object,
+      status: responseData.status,
+      output: responseData.output as OpenAIResponse["output"],
+      input: responseData.input as OpenAIResponse["input"],
+      error: responseData.error,
+      parallel_tool_calls: responseData.parallel_tool_calls,
+      previous_response_id: responseData.previous_response_id,
+      temperature: responseData.temperature,
+      top_p: responseData.top_p,
+      truncation: responseData.truncation,
+      user: responseData.user,
+    };
+  };
+
+  useEffect(() => {
+    const fetchResponses = async () => {
+      setIsLoading(true);
+      setError(null);
+      try {
+        const response = await client.responses.list();
+        const responseListData = response as ResponseListResponse;
+        const convertedResponses: OpenAIResponse[] = responseListData.data.map(
+          convertResponseListData,
+        );
+        setResponses(convertedResponses);
+      } catch (err) {
+        console.error("Error fetching responses:", err);
+        setError(
+          err instanceof Error ? err : new Error("Failed to fetch responses"),
+        );
+        setResponses([]);
+      } finally {
+        setIsLoading(false);
+      }
+    };
+
+    fetchResponses();
+  }, []);
+
+  return (
+    <ResponsesTable data={responses} isLoading={isLoading} error={error} />
+  );
+}

@ -75,7 +75,7 @@ describe("ChatCompletionDetailView", () => {
      />,
    );
    expect(
-      screen.getByText("No details found for completion ID: notfound-id."),
+      screen.getByText("No details found for ID: notfound-id."),
    ).toBeInTheDocument();
  });

@ -3,45 +3,14 @@
import { ChatMessage, ChatCompletion } from "@/lib/types"; import { ChatMessage, ChatCompletion } from "@/lib/types";
import { ChatMessageItem } from "@/components/chat-completions/chat-messasge-item"; import { ChatMessageItem } from "@/components/chat-completions/chat-messasge-item";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton"; import {
DetailLoadingView,
function ChatCompletionDetailLoadingView() { DetailErrorView,
return ( DetailNotFoundView,
<> DetailLayout,
<Skeleton className="h-8 w-3/4 mb-6" /> {/* Title Skeleton */} PropertiesCard,
<div className="flex flex-col md:flex-row gap-6"> PropertyItem,
<div className="flex-grow md:w-2/3 space-y-6"> } from "@/components/layout/detail-layout";
{[...Array(2)].map((_, i) => (
<Card key={`main-skeleton-card-${i}`}>
<CardHeader>
<CardTitle>
<Skeleton className="h-6 w-1/2" />
</CardTitle>
</CardHeader>
<CardContent className="space-y-2">
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-3/4" />
</CardContent>
</Card>
))}
</div>
<div className="md:w-1/3">
<div className="p-4 border rounded-lg shadow-sm bg-white space-y-3">
<Skeleton className="h-6 w-1/3 mb-3" />{" "}
{/* Properties Title Skeleton */}
{[...Array(5)].map((_, i) => (
<div key={`prop-skeleton-${i}`} className="space-y-1">
<Skeleton className="h-4 w-1/4" />
<Skeleton className="h-4 w-1/2" />
</div>
))}
</div>
</div>
</div>
</>
);
}
interface ChatCompletionDetailViewProps { interface ChatCompletionDetailViewProps {
completion: ChatCompletion | null; completion: ChatCompletion | null;
@ -56,39 +25,23 @@ export function ChatCompletionDetailView({
error, error,
id, id,
}: ChatCompletionDetailViewProps) { }: ChatCompletionDetailViewProps) {
const title = "Chat Completion Details";
if (error) { if (error) {
return ( return <DetailErrorView title={title} id={id} error={error} />;
<>
{/* We still want a title for consistency on error pages */}
<h1 className="text-2xl font-bold mb-6">Chat Completion Details</h1>
<p>
Error loading details for ID {id}: {error.message}
</p>
</>
);
} }
if (isLoading) { if (isLoading) {
return <ChatCompletionDetailLoadingView />; return <DetailLoadingView title={title} />;
} }
if (!completion) { if (!completion) {
// This state means: not loading, no error, but no completion data return <DetailNotFoundView title={title} id={id} />;
return (
<>
{/* We still want a title for consistency on not-found pages */}
<h1 className="text-2xl font-bold mb-6">Chat Completion Details</h1>
<p>No details found for completion ID: {id}.</p>
</>
);
} }
// If no error, not loading, and completion exists, render the details: // Main content cards
return ( const mainContent = (
<> <>
<h1 className="text-2xl font-bold mb-6">Chat Completion Details</h1>
<div className="flex flex-col md:flex-row gap-6">
<div className="flex-grow md:w-2/3 space-y-6">
<Card> <Card>
<CardHeader> <CardHeader>
<CardTitle>Input</CardTitle> <CardTitle>Input</CardTitle>
@ -98,13 +51,15 @@ export function ChatCompletionDetailView({
<ChatMessageItem key={`input-msg-${index}`} message={msg} /> <ChatMessageItem key={`input-msg-${index}`} message={msg} />
))} ))}
{completion.choices?.[0]?.message?.tool_calls && {completion.choices?.[0]?.message?.tool_calls &&
Array.isArray(completion.choices[0].message.tool_calls) &&
!completion.input_messages?.some( !completion.input_messages?.some(
(im) => (im) =>
im.role === "assistant" && im.role === "assistant" &&
im.tool_calls && im.tool_calls &&
Array.isArray(im.tool_calls) &&
im.tool_calls.length > 0, im.tool_calls.length > 0,
) && )
completion.choices[0].message.tool_calls.map( ? completion.choices[0].message.tool_calls.map(
(toolCall: any, index: number) => { (toolCall: any, index: number) => {
const assistantToolCallMessage: ChatMessage = { const assistantToolCallMessage: ChatMessage = {
role: "assistant", role: "assistant",
@ -118,7 +73,8 @@ export function ChatCompletionDetailView({
/> />
); );
}, },
)} )
: null}
</CardContent> </CardContent>
</Card> </Card>
@ -138,61 +94,52 @@ export function ChatCompletionDetailView({
)} )}
</CardContent> </CardContent>
</Card> </Card>
</div> </>
);
<div className="md:w-1/3"> // Properties sidebar
<Card> const sidebar = (
<CardHeader> <PropertiesCard>
<CardTitle>Properties</CardTitle> <PropertyItem
</CardHeader> label="Created"
<CardContent> value={new Date(completion.created * 1000).toLocaleString()}
<ul className="space-y-2 text-sm text-gray-600"> />
<li> <PropertyItem label="ID" value={completion.id} />
<strong>Created:</strong>{" "} <PropertyItem label="Model" value={completion.model} />
<span className="text-gray-900 font-medium"> <PropertyItem
{new Date(completion.created * 1000).toLocaleString()} label="Finish Reason"
</span> value={completion.choices?.[0]?.finish_reason || "N/A"}
</li> hasBorder
<li> />
<strong>ID:</strong>{" "} {(() => {
<span className="text-gray-900 font-medium"> const toolCalls = completion.choices?.[0]?.message?.tool_calls;
{completion.id} if (toolCalls && Array.isArray(toolCalls) && toolCalls.length > 0) {
</span> return (
</li> <PropertyItem
<li> label="Functions/Tools Called"
<strong>Model:</strong>{" "} value={
<span className="text-gray-900 font-medium"> <div>
{completion.model}
</span>
</li>
<li className="pt-1 mt-1 border-t border-gray-200">
<strong>Finish Reason:</strong>{" "}
<span className="text-gray-900 font-medium">
{completion.choices?.[0]?.finish_reason || "N/A"}
</span>
</li>
{completion.choices?.[0]?.message?.tool_calls &&
completion.choices[0].message.tool_calls.length > 0 && (
<li className="pt-1 mt-1 border-t border-gray-200">
<strong>Functions/Tools Called:</strong>
<ul className="list-disc list-inside pl-4 mt-1"> <ul className="list-disc list-inside pl-4 mt-1">
{completion.choices[0].message.tool_calls.map( {toolCalls.map((toolCall: any, index: number) => (
(toolCall: any, index: number) => (
<li key={index}> <li key={index}>
<span className="text-gray-900 font-medium"> <span className="text-gray-900 font-medium">
{toolCall.function?.name || "N/A"} {toolCall.function?.name || "N/A"}
</span> </span>
</li> </li>
), ))}
)}
</ul> </ul>
</li>
)}
</ul>
</CardContent>
</Card>
</div> </div>
</div> }
</> hasBorder
/>
);
}
return null;
})()}
</PropertiesCard>
);
return (
<DetailLayout title={title} mainContent={mainContent} sidebar={sidebar} />
); );
} }

@ -1,8 +1,8 @@
import React from "react"; import React from "react";
import { render, screen, fireEvent } from "@testing-library/react"; import { render, screen, fireEvent } from "@testing-library/react";
import "@testing-library/jest-dom"; import "@testing-library/jest-dom";
import { ChatCompletionsTable } from "./chat-completion-table"; import { ChatCompletionsTable } from "./chat-completions-table";
import { ChatCompletion } from "@/lib/types"; // Assuming this path is correct import { ChatCompletion } from "@/lib/types";
// Mock next/navigation // Mock next/navigation
const mockPush = jest.fn(); const mockPush = jest.fn();
@ -13,21 +13,25 @@ jest.mock("next/navigation", () => ({
})); }));
// Mock helper functions // Mock helper functions
// These are hoisted, so their mocks are available throughout the file
jest.mock("@/lib/truncate-text"); jest.mock("@/lib/truncate-text");
jest.mock("@/lib/format-tool-call"); jest.mock("@/lib/format-message-content");
// Import the mocked functions to set up default or specific implementations // Import the mocked functions to set up default or specific implementations
import { truncateText as originalTruncateText } from "@/lib/truncate-text"; import { truncateText as originalTruncateText } from "@/lib/truncate-text";
import { formatToolCallToString as originalFormatToolCallToString } from "@/lib/format-tool-call"; import {
extractTextFromContentPart as originalExtractTextFromContentPart,
extractDisplayableText as originalExtractDisplayableText,
} from "@/lib/format-message-content";
// Cast to jest.Mock for typings // Cast to jest.Mock for typings
const truncateText = originalTruncateText as jest.Mock; const truncateText = originalTruncateText as jest.Mock;
const formatToolCallToString = originalFormatToolCallToString as jest.Mock; const extractTextFromContentPart =
originalExtractTextFromContentPart as jest.Mock;
const extractDisplayableText = originalExtractDisplayableText as jest.Mock;
describe("ChatCompletionsTable", () => { describe("ChatCompletionsTable", () => {
const defaultProps = { const defaultProps = {
completions: [] as ChatCompletion[], data: [] as ChatCompletion[],
isLoading: false, isLoading: false,
error: null, error: null,
}; };
@ -36,28 +40,26 @@ describe("ChatCompletionsTable", () => {
// Reset all mocks before each test // Reset all mocks before each test
mockPush.mockClear(); mockPush.mockClear();
truncateText.mockClear(); truncateText.mockClear();
formatToolCallToString.mockClear(); extractTextFromContentPart.mockClear();
extractDisplayableText.mockClear();
// Default pass-through implementation for tests not focusing on truncation/formatting // Default pass-through implementations
truncateText.mockImplementation((text: string | undefined) => text); truncateText.mockImplementation((text: string | undefined) => text);
formatToolCallToString.mockImplementation((toolCall: any) => extractTextFromContentPart.mockImplementation((content: unknown) =>
toolCall && typeof toolCall === "object" && toolCall.name typeof content === "string" ? content : "extracted text",
? `[DefaultToolCall:${toolCall.name}]` );
: "[InvalidToolCall]", extractDisplayableText.mockImplementation(
(message: unknown) =>
(message as { content?: string })?.content || "extracted output",
); );
}); });
test("renders without crashing with default props", () => { test("renders without crashing with default props", () => {
render(<ChatCompletionsTable {...defaultProps} />); render(<ChatCompletionsTable {...defaultProps} />);
// Check for a unique element that should be present in the non-empty, non-loading, non-error state
// For now, as per Task 1, we will test the empty state message
expect(screen.getByText("No chat completions found.")).toBeInTheDocument(); expect(screen.getByText("No chat completions found.")).toBeInTheDocument();
}); });
test("click on a row navigates to the correct URL", () => { test("click on a row navigates to the correct URL", () => {
const { rerender } = render(<ChatCompletionsTable {...defaultProps} />);
// Simulate a scenario where a completion exists and is clicked
const mockCompletion: ChatCompletion = { const mockCompletion: ChatCompletion = {
id: "comp_123", id: "comp_123",
object: "chat.completion", object: "chat.completion",
@ -73,9 +75,12 @@ describe("ChatCompletionsTable", () => {
input_messages: [{ role: "user", content: "Test input" }], input_messages: [{ role: "user", content: "Test input" }],
}; };
rerender( // Set up mocks to return expected values
<ChatCompletionsTable {...defaultProps} completions={[mockCompletion]} />, extractTextFromContentPart.mockReturnValue("Test input");
); extractDisplayableText.mockReturnValue("Test output");
render(<ChatCompletionsTable {...defaultProps} data={[mockCompletion]} />);
const row = screen.getByText("Test input").closest("tr"); const row = screen.getByText("Test input").closest("tr");
if (row) { if (row) {
fireEvent.click(row); fireEvent.click(row);
@ -91,14 +96,13 @@ describe("ChatCompletionsTable", () => {
<ChatCompletionsTable {...defaultProps} isLoading={true} />, <ChatCompletionsTable {...defaultProps} isLoading={true} />,
); );
// The Skeleton component uses data-slot="skeleton"
const skeletonSelector = '[data-slot="skeleton"]';
// Check for skeleton in the table caption // Check for skeleton in the table caption
const tableCaption = container.querySelector("caption"); const tableCaption = container.querySelector("caption");
expect(tableCaption).toBeInTheDocument(); expect(tableCaption).toBeInTheDocument();
if (tableCaption) { if (tableCaption) {
const captionSkeleton = tableCaption.querySelector(skeletonSelector); const captionSkeleton = tableCaption.querySelector(
'[data-slot="skeleton"]',
);
expect(captionSkeleton).toBeInTheDocument(); expect(captionSkeleton).toBeInTheDocument();
} }
@ -107,16 +111,10 @@ describe("ChatCompletionsTable", () => {
expect(tableBody).toBeInTheDocument(); expect(tableBody).toBeInTheDocument();
if (tableBody) { if (tableBody) {
const bodySkeletons = tableBody.querySelectorAll( const bodySkeletons = tableBody.querySelectorAll(
`td ${skeletonSelector}`, '[data-slot="skeleton"]',
); );
expect(bodySkeletons.length).toBeGreaterThan(0); // Ensure at least one skeleton cell exists expect(bodySkeletons.length).toBeGreaterThan(0);
} }
// General check: ensure multiple skeleton elements are present in the table overall
const allSkeletonsInTable = container.querySelectorAll(
`table ${skeletonSelector}`,
);
expect(allSkeletonsInTable.length).toBeGreaterThan(3); // e.g., caption + at least one row of 3 cells, or just a few
}); });
}); });
@ -140,14 +138,14 @@ describe("ChatCompletionsTable", () => {
{...defaultProps} {...defaultProps}
error={{ name: "Error", message: "" }} error={{ name: "Error", message: "" }}
/>, />,
); // Error with empty message );
expect( expect(
screen.getByText("Error fetching data: An unknown error occurred"), screen.getByText("Error fetching data: An unknown error occurred"),
).toBeInTheDocument(); ).toBeInTheDocument();
}); });
test("renders default error message when error prop is an object without message", () => { test("renders default error message when error prop is an object without message", () => {
render(<ChatCompletionsTable {...defaultProps} error={{} as Error} />); // Empty error object render(<ChatCompletionsTable {...defaultProps} error={{} as Error} />);
expect( expect(
screen.getByText("Error fetching data: An unknown error occurred"), screen.getByText("Error fetching data: An unknown error occurred"),
).toBeInTheDocument(); ).toBeInTheDocument();
@ -155,14 +153,8 @@ describe("ChatCompletionsTable", () => {
}); });
describe("Empty State", () => { describe("Empty State", () => {
test('renders "No chat completions found." and no table when completions array is empty', () => { test('renders "No chat completions found." and no table when data array is empty', () => {
render( render(<ChatCompletionsTable data={[]} isLoading={false} error={null} />);
<ChatCompletionsTable
completions={[]}
isLoading={false}
error={null}
/>,
);
expect( expect(
screen.getByText("No chat completions found."), screen.getByText("No chat completions found."),
).toBeInTheDocument(); ).toBeInTheDocument();
@ -179,7 +171,7 @@ describe("ChatCompletionsTable", () => {
{ {
id: "comp_1", id: "comp_1",
object: "chat.completion", object: "chat.completion",
created: 1710000000, // Fixed timestamp for test created: 1710000000,
model: "llama-test-model", model: "llama-test-model",
choices: [ choices: [
{ {
@ -206,9 +198,22 @@ describe("ChatCompletionsTable", () => {
}, },
]; ];
// Set up mocks to return expected values
extractTextFromContentPart.mockImplementation((content: unknown) => {
if (content === "Test input") return "Test input";
if (content === "Another input") return "Another input";
return "extracted text";
});
extractDisplayableText.mockImplementation((message: unknown) => {
const msg = message as { content?: string };
if (msg?.content === "Test output") return "Test output";
if (msg?.content === "Another output") return "Another output";
return "extracted output";
});
render( render(
<ChatCompletionsTable <ChatCompletionsTable
completions={mockCompletions} data={mockCompletions}
isLoading={false} isLoading={false}
error={null} error={null}
/>, />,
@ -242,7 +247,7 @@ describe("ChatCompletionsTable", () => {
}); });
}); });
describe("Text Truncation and Tool Call Formatting", () => { describe("Text Truncation and Content Extraction", () => {
test("truncates long input and output text", () => { test("truncates long input and output text", () => {
// Specific mock implementation for this test // Specific mock implementation for this test
truncateText.mockImplementation( truncateText.mockImplementation(
@ -259,6 +264,10 @@ describe("ChatCompletionsTable", () => {
"This is a very long input message that should be truncated."; "This is a very long input message that should be truncated.";
const longOutput = const longOutput =
"This is a very long output message that should also be truncated."; "This is a very long output message that should also be truncated.";
extractTextFromContentPart.mockReturnValue(longInput);
extractDisplayableText.mockReturnValue(longOutput);
const mockCompletions = [ const mockCompletions = [
{ {
id: "comp_trunc", id: "comp_trunc",
@ -278,7 +287,7 @@ describe("ChatCompletionsTable", () => {
render( render(
<ChatCompletionsTable <ChatCompletionsTable
completions={mockCompletions} data={mockCompletions}
isLoading={false} isLoading={false}
error={null} error={null}
/>, />,
@ -289,52 +298,50 @@ describe("ChatCompletionsTable", () => {
longInput.slice(0, 10) + "...", longInput.slice(0, 10) + "...",
); );
expect(truncatedTexts.length).toBe(2); // one for input, one for output expect(truncatedTexts.length).toBe(2); // one for input, one for output
// Optionally, verify each one is in the document if getAllByText doesn't throw on not found
truncatedTexts.forEach((textElement) => truncatedTexts.forEach((textElement) =>
expect(textElement).toBeInTheDocument(), expect(textElement).toBeInTheDocument(),
); );
}); });
test("formats tool call output using formatToolCallToString", () => { test("uses content extraction functions correctly", () => {
// Specific mock implementation for this test const mockCompletion = {
formatToolCallToString.mockImplementation( id: "comp_extract",
(toolCall: any) => `[TOOL:${toolCall.name}]`,
);
// Ensure no truncation interferes for this specific test for clarity of tool call format
truncateText.mockImplementation((text: string | undefined) => text);
const toolCall = { name: "search", args: { query: "llama" } };
const mockCompletions = [
{
id: "comp_tool",
object: "chat.completion", object: "chat.completion",
created: 1710003000, created: 1710003000,
model: "llama-tool-model", model: "llama-extract-model",
choices: [ choices: [
{ {
index: 0, index: 0,
message: { message: { role: "assistant", content: "Extracted output" },
role: "assistant",
content: "Tool output", // Content that will be prepended
tool_calls: [toolCall],
},
finish_reason: "stop", finish_reason: "stop",
}, },
], ],
input_messages: [{ role: "user", content: "Tool input" }], input_messages: [{ role: "user", content: "Extracted input" }],
}, };
];
extractTextFromContentPart.mockReturnValue("Extracted input");
extractDisplayableText.mockReturnValue("Extracted output");
render( render(
<ChatCompletionsTable <ChatCompletionsTable
completions={mockCompletions} data={[mockCompletion]}
isLoading={false} isLoading={false}
error={null} error={null}
/>, />,
); );
// The component concatenates message.content and the formatted tool call // Verify the extraction functions were called
expect(screen.getByText("Tool output [TOOL:search]")).toBeInTheDocument(); expect(extractTextFromContentPart).toHaveBeenCalledWith(
"Extracted input",
);
expect(extractDisplayableText).toHaveBeenCalledWith({
role: "assistant",
content: "Extracted output",
});
// Verify the extracted content is displayed
expect(screen.getByText("Extracted input")).toBeInTheDocument();
expect(screen.getByText("Extracted output")).toBeInTheDocument();
}); });
}); });
}); });

@ -0,0 +1,43 @@
"use client";
import { ChatCompletion } from "@/lib/types";
import { LogsTable, LogTableRow } from "@/components/logs/logs-table";
import {
extractTextFromContentPart,
extractDisplayableText,
} from "@/lib/format-message-content";
interface ChatCompletionsTableProps {
data: ChatCompletion[];
isLoading: boolean;
error: Error | null;
}
function formatChatCompletionToRow(completion: ChatCompletion): LogTableRow {
return {
id: completion.id,
input: extractTextFromContentPart(completion.input_messages?.[0]?.content),
output: extractDisplayableText(completion.choices?.[0]?.message),
model: completion.model,
createdTime: new Date(completion.created * 1000).toLocaleString(),
detailPath: `/logs/chat-completions/${completion.id}`,
};
}
export function ChatCompletionsTable({
data,
isLoading,
error,
}: ChatCompletionsTableProps) {
const formattedData = data.map(formatChatCompletionToRow);
return (
<LogsTable
data={formattedData}
isLoading={isLoading}
error={error}
caption="A list of your recent chat completions."
emptyMessage="No chat completions found."
/>
);
}

@ -4,45 +4,10 @@ import { ChatMessage } from "@/lib/types";
import React from "react";
import { formatToolCallToString } from "@/lib/format-tool-call";
import { extractTextFromContentPart } from "@/lib/format-message-content";
-
-// Sub-component or helper for the common label + content structure
-const MessageBlock: React.FC<{
-  label: string;
-  labelDetail?: string;
-  content: React.ReactNode;
-}> = ({ label, labelDetail, content }) => {
-  return (
-    <div>
-      <p className="py-1 font-semibold text-gray-800 mb-1">
-        {label}
-        {labelDetail && (
-          <span className="text-xs text-gray-500 font-normal ml-1">
-            {labelDetail}
-          </span>
-        )}
-      </p>
-      <div className="py-1">{content}</div>
-    </div>
-  );
-};
-
-interface ToolCallBlockProps {
-  children: React.ReactNode;
-  className?: string;
-}
-
-const ToolCallBlock = ({ children, className }: ToolCallBlockProps) => {
-  // Common styling for both function call arguments and tool output blocks
-  // Let's use slate-50 background as it's good for code-like content.
-  const baseClassName =
-    "p-3 bg-slate-50 border border-slate-200 rounded-md text-sm";
-  return (
-    <div className={`${baseClassName} ${className || ""}`}>
-      <pre className="whitespace-pre-wrap text-xs">{children}</pre>
-    </div>
-  );
-};
+import {
+  MessageBlock,
+  ToolCallBlock,
+} from "@/components/ui/message-components";

interface ChatMessageItemProps {
  message: ChatMessage;
@ -65,7 +30,11 @@ export function ChatMessageItem({ message }: ChatMessageItemProps) {
      );
    case "assistant":
-      if (message.tool_calls && message.tool_calls.length > 0) {
+      if (
+        message.tool_calls &&
+        Array.isArray(message.tool_calls) &&
+        message.tool_calls.length > 0
+      ) {
        return (
          <>
            {message.tool_calls.map((toolCall: any, index: number) => {

@ -0,0 +1,141 @@
import React from "react";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";
export function DetailLoadingView({ title }: { title: string }) {
return (
<>
<Skeleton className="h-8 w-3/4 mb-6" /> {/* Title Skeleton */}
<div className="flex flex-col md:flex-row gap-6">
<div className="flex-grow md:w-2/3 space-y-6">
{[...Array(2)].map((_, i) => (
<Card key={`main-skeleton-card-${i}`}>
<CardHeader>
<CardTitle>
<Skeleton className="h-6 w-1/2" />
</CardTitle>
</CardHeader>
<CardContent className="space-y-2">
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-3/4" />
</CardContent>
</Card>
))}
</div>
<div className="md:w-1/3">
<div className="p-4 border rounded-lg shadow-sm bg-white space-y-3">
<Skeleton className="h-6 w-1/3 mb-3" />{" "}
{/* Properties Title Skeleton */}
{[...Array(5)].map((_, i) => (
<div key={`prop-skeleton-${i}`} className="space-y-1">
<Skeleton className="h-4 w-1/4" />
<Skeleton className="h-4 w-1/2" />
</div>
))}
</div>
</div>
</div>
</>
);
}
export function DetailErrorView({
title,
id,
error,
}: {
title: string;
id: string;
error: Error;
}) {
return (
<>
<h1 className="text-2xl font-bold mb-6">{title}</h1>
<p>
Error loading details for ID {id}: {error.message}
</p>
</>
);
}
export function DetailNotFoundView({
title,
id,
}: {
title: string;
id: string;
}) {
return (
<>
<h1 className="text-2xl font-bold mb-6">{title}</h1>
<p>No details found for ID: {id}.</p>
</>
);
}
export interface PropertyItemProps {
label: string;
value: React.ReactNode;
className?: string;
hasBorder?: boolean;
}
export function PropertyItem({
label,
value,
className = "",
hasBorder = false,
}: PropertyItemProps) {
return (
<li
className={`${hasBorder ? "pt-1 mt-1 border-t border-gray-200" : ""} ${className}`}
>
<strong>{label}:</strong>{" "}
{typeof value === "string" || typeof value === "number" ? (
<span className="text-gray-900 font-medium">{value}</span>
) : (
value
)}
</li>
);
}
export interface PropertiesCardProps {
children: React.ReactNode;
}
export function PropertiesCard({ children }: PropertiesCardProps) {
return (
<Card>
<CardHeader>
<CardTitle>Properties</CardTitle>
</CardHeader>
<CardContent>
<ul className="space-y-2 text-sm text-gray-600">{children}</ul>
</CardContent>
</Card>
);
}
export interface DetailLayoutProps {
title: string;
mainContent: React.ReactNode;
sidebar: React.ReactNode;
}
export function DetailLayout({
title,
mainContent,
sidebar,
}: DetailLayoutProps) {
return (
<>
<h1 className="text-2xl font-bold mb-6">{title}</h1>
<div className="flex flex-col md:flex-row gap-6">
<div className="flex-grow md:w-2/3 space-y-6">{mainContent}</div>
<div className="md:w-1/3">{sidebar}</div>
</div>
</>
);
}
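
A short usage sketch (an assumed example page, not part of this commit) showing how these exported primitives compose, mirroring the chat-completion detail view above:

import React from "react";
import {
  DetailLayout,
  PropertiesCard,
  PropertyItem,
} from "@/components/layout/detail-layout";

// Assumed example page; values shown are illustrative.
export function ExampleDetailPage() {
  const mainContent = <p>Main content cards go here.</p>;
  const sidebar = (
    <PropertiesCard>
      <PropertyItem label="ID" value="example_123" />
      <PropertyItem label="Status" value="completed" hasBorder />
    </PropertiesCard>
  );
  return (
    <DetailLayout
      title="Example Details"
      mainContent={mainContent}
      sidebar={sidebar}
    />
  );
}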

@ -0,0 +1,49 @@
"use client";
import React from "react";
import { usePathname, useParams } from "next/navigation";
import {
PageBreadcrumb,
BreadcrumbSegment,
} from "@/components/layout/page-breadcrumb";
import { truncateText } from "@/lib/truncate-text";
interface LogsLayoutProps {
children: React.ReactNode;
sectionLabel: string;
basePath: string;
}
export default function LogsLayout({
children,
sectionLabel,
basePath,
}: LogsLayoutProps) {
const pathname = usePathname();
const params = useParams();
let segments: BreadcrumbSegment[] = [];
if (pathname === basePath) {
segments = [{ label: sectionLabel }];
}
const idParam = params?.id;
if (idParam && typeof idParam === "string") {
segments = [
{ label: sectionLabel, href: basePath },
{ label: `Details (${truncateText(idParam, 20)})` },
];
}
return (
<div className="container mx-auto p-4">
<>
{segments.length > 0 && (
<PageBreadcrumb segments={segments} className="mb-4" />
)}
{children}
</>
</div>
);
}
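
For reference, with an illustrative pathname of "/logs/responses/resp_123" and basePath "/logs/responses", the segments built above come out as:

import { BreadcrumbSegment } from "@/components/layout/page-breadcrumb";
import { truncateText } from "@/lib/truncate-text";

// Illustrative values only; the component derives these from usePathname()/useParams().
const segments: BreadcrumbSegment[] = [
  { label: "Responses", href: "/logs/responses" },
  { label: `Details (${truncateText("resp_123", 20)})` },
];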

@ -0,0 +1,350 @@
import React from "react";
import { render, screen, fireEvent } from "@testing-library/react";
import "@testing-library/jest-dom";
import { LogsTable, LogTableRow } from "./logs-table";
// Mock next/navigation
const mockPush = jest.fn();
jest.mock("next/navigation", () => ({
useRouter: () => ({
push: mockPush,
}),
}));
// Mock helper functions
jest.mock("@/lib/truncate-text");
// Import the mocked functions
import { truncateText as originalTruncateText } from "@/lib/truncate-text";
// Cast to jest.Mock for typings
const truncateText = originalTruncateText as jest.Mock;
describe("LogsTable", () => {
const defaultProps = {
data: [] as LogTableRow[],
isLoading: false,
error: null,
caption: "Test table caption",
emptyMessage: "No data found",
};
beforeEach(() => {
// Reset all mocks before each test
mockPush.mockClear();
truncateText.mockClear();
// Default pass-through implementation
truncateText.mockImplementation((text: string | undefined) => text);
});
test("renders without crashing with default props", () => {
render(<LogsTable {...defaultProps} />);
expect(screen.getByText("No data found")).toBeInTheDocument();
});
test("click on a row navigates to the correct URL", () => {
const mockData: LogTableRow[] = [
{
id: "row_123",
input: "Test input",
output: "Test output",
model: "test-model",
createdTime: "2024-01-01 12:00:00",
detailPath: "/test/path/row_123",
},
];
render(<LogsTable {...defaultProps} data={mockData} />);
const row = screen.getByText("Test input").closest("tr");
if (row) {
fireEvent.click(row);
expect(mockPush).toHaveBeenCalledWith("/test/path/row_123");
} else {
throw new Error('Row with "Test input" not found for router mock test.');
}
});
describe("Loading State", () => {
test("renders skeleton UI when isLoading is true", () => {
const { container } = render(
<LogsTable {...defaultProps} isLoading={true} />,
);
// Check for skeleton in the table caption
const tableCaption = container.querySelector("caption");
expect(tableCaption).toBeInTheDocument();
if (tableCaption) {
const captionSkeleton = tableCaption.querySelector(
'[data-slot="skeleton"]',
);
expect(captionSkeleton).toBeInTheDocument();
}
// Check for skeletons in the table body cells
const tableBody = container.querySelector("tbody");
expect(tableBody).toBeInTheDocument();
if (tableBody) {
const bodySkeletons = tableBody.querySelectorAll(
'[data-slot="skeleton"]',
);
expect(bodySkeletons.length).toBeGreaterThan(0);
}
// Check that table headers are still rendered
expect(screen.getByText("Input")).toBeInTheDocument();
expect(screen.getByText("Output")).toBeInTheDocument();
expect(screen.getByText("Model")).toBeInTheDocument();
expect(screen.getByText("Created")).toBeInTheDocument();
});
test("renders correct number of skeleton rows", () => {
const { container } = render(
<LogsTable {...defaultProps} isLoading={true} />,
);
const skeletonRows = container.querySelectorAll("tbody tr");
expect(skeletonRows.length).toBe(3); // Should render 3 skeleton rows
});
});
describe("Error State", () => {
test("renders error message when error prop is provided", () => {
const errorMessage = "Network Error";
render(
<LogsTable
{...defaultProps}
error={{ name: "Error", message: errorMessage }}
/>,
);
expect(
screen.getByText(`Error fetching data: ${errorMessage}`),
).toBeInTheDocument();
});
test("renders default error message when error.message is not available", () => {
render(
<LogsTable {...defaultProps} error={{ name: "Error", message: "" }} />,
);
expect(
screen.getByText("Error fetching data: An unknown error occurred"),
).toBeInTheDocument();
});
test("renders default error message when error prop is an object without message", () => {
render(<LogsTable {...defaultProps} error={{} as Error} />);
expect(
screen.getByText("Error fetching data: An unknown error occurred"),
).toBeInTheDocument();
});
test("does not render table when in error state", () => {
render(
<LogsTable
{...defaultProps}
error={{ name: "Error", message: "Test error" }}
/>,
);
const table = screen.queryByRole("table");
expect(table).not.toBeInTheDocument();
});
});
describe("Empty State", () => {
test("renders custom empty message when data array is empty", () => {
render(
<LogsTable
{...defaultProps}
data={[]}
emptyMessage="Custom empty message"
/>,
);
expect(screen.getByText("Custom empty message")).toBeInTheDocument();
// Ensure that the table structure is NOT rendered in the empty state
const table = screen.queryByRole("table");
expect(table).not.toBeInTheDocument();
});
});
describe("Data Rendering", () => {
test("renders table caption, headers, and data correctly", () => {
const mockData: LogTableRow[] = [
{
id: "row_1",
input: "First input",
output: "First output",
model: "model-1",
createdTime: "2024-01-01 12:00:00",
detailPath: "/path/1",
},
{
id: "row_2",
input: "Second input",
output: "Second output",
model: "model-2",
createdTime: "2024-01-02 13:00:00",
detailPath: "/path/2",
},
];
render(
<LogsTable
{...defaultProps}
data={mockData}
caption="Custom table caption"
/>,
);
// Table caption
expect(screen.getByText("Custom table caption")).toBeInTheDocument();
// Table headers
expect(screen.getByText("Input")).toBeInTheDocument();
expect(screen.getByText("Output")).toBeInTheDocument();
expect(screen.getByText("Model")).toBeInTheDocument();
expect(screen.getByText("Created")).toBeInTheDocument();
// Data rows
expect(screen.getByText("First input")).toBeInTheDocument();
expect(screen.getByText("First output")).toBeInTheDocument();
expect(screen.getByText("model-1")).toBeInTheDocument();
expect(screen.getByText("2024-01-01 12:00:00")).toBeInTheDocument();
expect(screen.getByText("Second input")).toBeInTheDocument();
expect(screen.getByText("Second output")).toBeInTheDocument();
expect(screen.getByText("model-2")).toBeInTheDocument();
expect(screen.getByText("2024-01-02 13:00:00")).toBeInTheDocument();
});
test("applies correct CSS classes to table rows", () => {
const mockData: LogTableRow[] = [
{
id: "row_1",
input: "Test input",
output: "Test output",
model: "test-model",
createdTime: "2024-01-01 12:00:00",
detailPath: "/test/path",
},
];
render(<LogsTable {...defaultProps} data={mockData} />);
const row = screen.getByText("Test input").closest("tr");
expect(row).toHaveClass("cursor-pointer");
expect(row).toHaveClass("hover:bg-muted/50");
});
test("applies correct alignment to Created column", () => {
const mockData: LogTableRow[] = [
{
id: "row_1",
input: "Test input",
output: "Test output",
model: "test-model",
createdTime: "2024-01-01 12:00:00",
detailPath: "/test/path",
},
];
render(<LogsTable {...defaultProps} data={mockData} />);
const createdCell = screen.getByText("2024-01-01 12:00:00").closest("td");
expect(createdCell).toHaveClass("text-right");
});
});
describe("Text Truncation", () => {
test("truncates input and output text using truncateText function", () => {
// Mock truncateText to return truncated versions
truncateText.mockImplementation((text: string | undefined) => {
if (typeof text === "string" && text.length > 10) {
return text.slice(0, 10) + "...";
}
return text;
});
const longInput =
"This is a very long input text that should be truncated";
const longOutput =
"This is a very long output text that should be truncated";
const mockData: LogTableRow[] = [
{
id: "row_1",
input: longInput,
output: longOutput,
model: "test-model",
createdTime: "2024-01-01 12:00:00",
detailPath: "/test/path",
},
];
render(<LogsTable {...defaultProps} data={mockData} />);
// Verify truncateText was called
expect(truncateText).toHaveBeenCalledWith(longInput);
expect(truncateText).toHaveBeenCalledWith(longOutput);
// Verify truncated text is displayed
const truncatedTexts = screen.getAllByText("This is a ...");
expect(truncatedTexts).toHaveLength(2); // one for input, one for output
truncatedTexts.forEach((textElement) =>
expect(textElement).toBeInTheDocument(),
);
});
test("does not truncate model names", () => {
const mockData: LogTableRow[] = [
{
id: "row_1",
input: "Test input",
output: "Test output",
model: "very-long-model-name-that-should-not-be-truncated",
createdTime: "2024-01-01 12:00:00",
detailPath: "/test/path",
},
];
render(<LogsTable {...defaultProps} data={mockData} />);
// Model name should not be passed to truncateText
expect(truncateText).not.toHaveBeenCalledWith(
"very-long-model-name-that-should-not-be-truncated",
);
// Full model name should be displayed
expect(
screen.getByText("very-long-model-name-that-should-not-be-truncated"),
).toBeInTheDocument();
});
});
describe("Accessibility", () => {
test("table has proper role and structure", () => {
const mockData: LogTableRow[] = [
{
id: "row_1",
input: "Test input",
output: "Test output",
model: "test-model",
createdTime: "2024-01-01 12:00:00",
detailPath: "/test/path",
},
];
render(<LogsTable {...defaultProps} data={mockData} />);
const table = screen.getByRole("table");
expect(table).toBeInTheDocument();
const columnHeaders = screen.getAllByRole("columnheader");
expect(columnHeaders).toHaveLength(4);
const rows = screen.getAllByRole("row");
expect(rows).toHaveLength(2); // 1 header row + 1 data row
});
});
});

@ -1,12 +1,7 @@
"use client";

import { useRouter } from "next/navigation";
-import { ChatCompletion } from "@/lib/types";
import { truncateText } from "@/lib/truncate-text";
-import {
-  extractTextFromContentPart,
-  extractDisplayableText,
-} from "@/lib/format-message-content";
import {
  Table,
  TableBody,
@ -18,17 +13,31 @@
} from "@/components/ui/table";
import { Skeleton } from "@/components/ui/skeleton";

-interface ChatCompletionsTableProps {
-  completions: ChatCompletion[];
-  isLoading: boolean;
-  error: Error | null;
-}
+// Generic table row data interface
+export interface LogTableRow {
+  id: string;
+  input: string;
+  output: string;
+  model: string;
+  createdTime: string;
+  detailPath: string;
+}
+
+interface LogsTableProps {
+  data: LogTableRow[];
+  isLoading: boolean;
+  error: Error | null;
+  caption: string;
+  emptyMessage: string;
+}

-export function ChatCompletionsTable({
-  completions,
+export function LogsTable({
+  data,
   isLoading,
   error,
-}: ChatCompletionsTableProps) {
+  caption,
+  emptyMessage,
+}: LogsTableProps) {
  const router = useRouter();

  const tableHeader = (
@ -77,41 +86,25 @@
    );
  }

-  if (completions.length === 0) {
-    return <p>No chat completions found.</p>;
+  if (data.length === 0) {
+    return <p>{emptyMessage}</p>;
  }

  return (
    <Table>
-      <TableCaption>A list of your recent chat completions.</TableCaption>
+      <TableCaption>{caption}</TableCaption>
      {tableHeader}
      <TableBody>
-        {completions.map((completion) => (
+        {data.map((row) => (
          <TableRow
-            key={completion.id}
-            onClick={() =>
-              router.push(`/logs/chat-completions/${completion.id}`)
-            }
+            key={row.id}
+            onClick={() => router.push(row.detailPath)}
            className="cursor-pointer hover:bg-muted/50"
          >
-            <TableCell>
-              {truncateText(
-                extractTextFromContentPart(
-                  completion.input_messages?.[0]?.content,
-                ),
-              )}
-            </TableCell>
-            <TableCell>
-              {(() => {
-                const message = completion.choices?.[0]?.message;
-                const outputText = extractDisplayableText(message);
-                return truncateText(outputText);
-              })()}
-            </TableCell>
-            <TableCell>{completion.model}</TableCell>
-            <TableCell className="text-right">
-              {new Date(completion.created * 1000).toLocaleString()}
-            </TableCell>
+            <TableCell>{truncateText(row.input)}</TableCell>
+            <TableCell>{truncateText(row.output)}</TableCell>
+            <TableCell>{row.model}</TableCell>
+            <TableCell className="text-right">{row.createdTime}</TableCell>
          </TableRow>
        ))}
      </TableBody>
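
A hedged usage sketch of the now-generic table (component name and row values below are illustrative): callers pass pre-formatted rows plus their own caption and empty-state copy.

import { LogsTable, LogTableRow } from "@/components/logs/logs-table";

// Illustrative rows; in this commit they are produced by formatChatCompletionToRow.
const rows: LogTableRow[] = [
  {
    id: "row_1",
    input: "Hello",
    output: "Hi there!",
    model: "llama-test-model",
    createdTime: "2024-01-01 12:00:00",
    detailPath: "/logs/chat-completions/row_1",
  },
];

export function ExampleLogsView() {
  return (
    <LogsTable
      data={rows}
      isLoading={false}
      error={null}
      caption="A list of your recent chat completions."
      emptyMessage="No chat completions found."
    />
  );
}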

@ -0,0 +1,56 @@
import { useFunctionCallGrouping } from "../hooks/function-call-grouping";
import { ItemRenderer } from "../items/item-renderer";
import { GroupedFunctionCallItemComponent } from "../items/grouped-function-call-item";
import {
isFunctionCallItem,
isFunctionCallOutputItem,
AnyResponseItem,
} from "../utils/item-types";
interface GroupedItemsDisplayProps {
items: AnyResponseItem[];
keyPrefix: string;
defaultRole?: string;
}
export function GroupedItemsDisplay({
items,
keyPrefix,
defaultRole = "unknown",
}: GroupedItemsDisplayProps) {
const groupedItems = useFunctionCallGrouping(items);
return (
<>
{groupedItems.map((groupedItem) => {
// If this is a function call with an output, render the grouped component
if (
groupedItem.outputItem &&
isFunctionCallItem(groupedItem.item) &&
isFunctionCallOutputItem(groupedItem.outputItem)
) {
return (
<GroupedFunctionCallItemComponent
key={`${keyPrefix}-${groupedItem.index}`}
functionCall={groupedItem.item}
output={groupedItem.outputItem}
index={groupedItem.index}
keyPrefix={keyPrefix}
/>
);
}
// Otherwise, render the individual item
return (
<ItemRenderer
key={`${keyPrefix}-${groupedItem.index}`}
item={groupedItem.item}
index={groupedItem.index}
keyPrefix={keyPrefix}
defaultRole={defaultRole}
/>
);
})}
</>
);
}

@ -0,0 +1,92 @@
import { useMemo } from "react";
import {
isFunctionCallOutputItem,
AnyResponseItem,
FunctionCallOutputItem,
} from "../utils/item-types";
export interface GroupedItem {
item: AnyResponseItem;
index: number;
outputItem?: AnyResponseItem;
outputIndex?: number;
}
/**
* Hook to group function calls with their corresponding outputs
* @param items Array of items to group
* @returns Array of grouped items with their outputs
*/
export function useFunctionCallGrouping(
items: AnyResponseItem[],
): GroupedItem[] {
return useMemo(() => {
const groupedItems: GroupedItem[] = [];
const processedIndices = new Set<number>();
// Build a map of call_id to indices for function_call_output items
const callIdToIndices = new Map<string, number[]>();
for (let i = 0; i < items.length; i++) {
const item = items[i];
if (isFunctionCallOutputItem(item)) {
if (!callIdToIndices.has(item.call_id)) {
callIdToIndices.set(item.call_id, []);
}
callIdToIndices.get(item.call_id)!.push(i);
}
}
// Process items and group function calls with their outputs
for (let i = 0; i < items.length; i++) {
if (processedIndices.has(i)) {
continue;
}
const currentItem = items[i];
if (
currentItem.type === "function_call" &&
"name" in currentItem &&
"call_id" in currentItem
) {
const functionCallId = currentItem.call_id as string;
let outputIndex = -1;
let outputItem: FunctionCallOutputItem | null = null;
const relatedIndices = callIdToIndices.get(functionCallId) || [];
for (const idx of relatedIndices) {
const potentialOutput = items[idx];
outputIndex = idx;
outputItem = potentialOutput as FunctionCallOutputItem;
break;
}
if (outputItem && outputIndex !== -1) {
// Group function call with its function_call_output
groupedItems.push({
item: currentItem,
index: i,
outputItem,
outputIndex,
});
// Mark both items as processed
processedIndices.add(i);
processedIndices.add(outputIndex);
// Matching function call and output found, skip to next item
continue;
}
}
// No matching function_call_output; render this item on its own
groupedItems.push({
item: currentItem,
index: i,
});
processedIndices.add(i);
}
return groupedItems;
}, [items]);
}
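
A small worked example of the grouping behaviour; the item shapes are abbreviated and cast because the full definitions live in ../utils/item-types, outside this hunk:

import { useFunctionCallGrouping } from "../hooks/function-call-grouping";
import { AnyResponseItem } from "../utils/item-types";

// Shapes abbreviated and cast for illustration.
const sampleItems = [
  { type: "message", role: "assistant", content: "Looking that up..." },
  { type: "function_call", call_id: "c1", name: "get_weather", arguments: "{}" },
  { type: "function_call_output", call_id: "c1", output: "72F and sunny" },
] as unknown as AnyResponseItem[];

export function Example() {
  const grouped = useFunctionCallGrouping(sampleItems);
  // grouped[0] => { item: sampleItems[0], index: 0 }              (plain message)
  // grouped[1] => { item: sampleItems[1], index: 1,
  //                 outputItem: sampleItems[2], outputIndex: 2 }  (paired by call_id "c1")
  return <pre>{JSON.stringify(grouped, null, 2)}</pre>;
}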

@ -0,0 +1,29 @@
import {
MessageBlock,
ToolCallBlock,
} from "@/components/ui/message-components";
import { FunctionCallItem } from "../utils/item-types";
interface FunctionCallItemProps {
item: FunctionCallItem;
index: number;
keyPrefix: string;
}
export function FunctionCallItemComponent({
item,
index,
keyPrefix,
}: FunctionCallItemProps) {
const name = item.name || "unknown";
const args = item.arguments || "{}";
const formattedFunctionCall = `${name}(${args})`;
return (
<MessageBlock
key={`${keyPrefix}-${index}`}
label="Function Call"
content={<ToolCallBlock>{formattedFunctionCall}</ToolCallBlock>}
/>
);
}

@ -0,0 +1,37 @@
import {
MessageBlock,
ToolCallBlock,
} from "@/components/ui/message-components";
import { BaseItem } from "../utils/item-types";
interface GenericItemProps {
item: BaseItem;
index: number;
keyPrefix: string;
}
export function GenericItemComponent({
item,
index,
keyPrefix,
}: GenericItemProps) {
// Handle other types like function calls, tool outputs, etc.
const itemData = item as Record<string, unknown>;
const content = itemData.content
? typeof itemData.content === "string"
? itemData.content
: JSON.stringify(itemData.content, null, 2)
: JSON.stringify(itemData, null, 2);
const label = keyPrefix === "input" ? "Input" : "Output";
return (
<MessageBlock
key={`${keyPrefix}-${index}`}
label={label}
labelDetail={`(${itemData.type})`}
content={<ToolCallBlock>{content}</ToolCallBlock>}
/>
);
}

@ -0,0 +1,54 @@
import {
MessageBlock,
ToolCallBlock,
} from "@/components/ui/message-components";
import { FunctionCallItem, FunctionCallOutputItem } from "../utils/item-types";
interface GroupedFunctionCallItemProps {
functionCall: FunctionCallItem;
output: FunctionCallOutputItem;
index: number;
keyPrefix: string;
}
export function GroupedFunctionCallItemComponent({
functionCall,
output,
index,
keyPrefix,
}: GroupedFunctionCallItemProps) {
const name = functionCall.name || "unknown";
const args = functionCall.arguments || "{}";
// Extract the output content from function_call_output
let outputContent = "";
if (output.output) {
outputContent =
typeof output.output === "string"
? output.output
: JSON.stringify(output.output);
} else {
outputContent = JSON.stringify(output, null, 2);
}
const functionCallContent = (
<div>
<div className="mb-2">
<span className="text-sm text-gray-600">Arguments</span>
<ToolCallBlock>{`${name}(${args})`}</ToolCallBlock>
</div>
<div>
<span className="text-sm text-gray-600">Output</span>
<ToolCallBlock>{outputContent}</ToolCallBlock>
</div>
</div>
);
return (
<MessageBlock
key={`${keyPrefix}-${index}`}
label="Function Call"
content={functionCallContent}
/>
);
}

@ -0,0 +1,6 @@
export { MessageItemComponent } from "./message-item";
export { FunctionCallItemComponent } from "./function-call-item";
export { WebSearchItemComponent } from "./web-search-item";
export { GenericItemComponent } from "./generic-item";
export { GroupedFunctionCallItemComponent } from "./grouped-function-call-item";
export { ItemRenderer } from "./item-renderer";

@ -0,0 +1,60 @@
import {
isMessageItem,
isFunctionCallItem,
isWebSearchCallItem,
AnyResponseItem,
} from "../utils/item-types";
import { MessageItemComponent } from "./message-item";
import { FunctionCallItemComponent } from "./function-call-item";
import { WebSearchItemComponent } from "./web-search-item";
import { GenericItemComponent } from "./generic-item";
interface ItemRendererProps {
item: AnyResponseItem;
index: number;
keyPrefix: string;
defaultRole?: string;
}
export function ItemRenderer({
item,
index,
keyPrefix,
defaultRole = "unknown",
}: ItemRendererProps) {
if (isMessageItem(item)) {
return (
<MessageItemComponent
item={item}
index={index}
keyPrefix={keyPrefix}
defaultRole={defaultRole}
/>
);
}
if (isFunctionCallItem(item)) {
return (
<FunctionCallItemComponent
item={item}
index={index}
keyPrefix={keyPrefix}
/>
);
}
if (isWebSearchCallItem(item)) {
return (
<WebSearchItemComponent item={item} index={index} keyPrefix={keyPrefix} />
);
}
// Fallback to generic item for unknown types
return (
<GenericItemComponent
item={item as any}
index={index}
keyPrefix={keyPrefix}
/>
);
}

@ -0,0 +1,41 @@
import { MessageBlock } from "@/components/ui/message-components";
import { MessageItem } from "../utils/item-types";
interface MessageItemProps {
item: MessageItem;
index: number;
keyPrefix: string;
defaultRole?: string;
}
export function MessageItemComponent({
item,
index,
keyPrefix,
defaultRole = "unknown",
}: MessageItemProps) {
let content = "";
if (typeof item.content === "string") {
content = item.content;
} else if (Array.isArray(item.content)) {
content = item.content
.map((c) => {
return c.type === "input_text" || c.type === "output_text"
? c.text
: JSON.stringify(c);
})
.join(" ");
}
const role = item.role || defaultRole;
const label = role.charAt(0).toUpperCase() + role.slice(1);
return (
<MessageBlock
key={`${keyPrefix}-${index}`}
label={label}
content={content}
/>
);
}
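
For reference, a parts-style content array (shape illustrative) is flattened to one space-joined string before rendering:

// Illustrative content parts as handled by the Array.isArray branch above.
const parts = [
  { type: "input_text", text: "What is" },
  { type: "output_text", text: "the weather?" },
];
// MessageItemComponent renders this content as: "What is the weather?"
// Any non-text part would be JSON.stringify'ed into the string instead.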

@ -0,0 +1,28 @@
import {
MessageBlock,
ToolCallBlock,
} from "@/components/ui/message-components";
import { WebSearchCallItem } from "../utils/item-types";
interface WebSearchItemProps {
item: WebSearchCallItem;
index: number;
keyPrefix: string;
}
export function WebSearchItemComponent({
item,
index,
keyPrefix,
}: WebSearchItemProps) {
const formattedWebSearch = `web_search_call(status: ${item.status})`;
return (
<MessageBlock
key={`${keyPrefix}-${index}`}
label="Function Call"
labelDetail="(Web Search)"
content={<ToolCallBlock>{formattedWebSearch}</ToolCallBlock>}
/>
);
}

@ -0,0 +1,777 @@
import React from "react";
import { render, screen } from "@testing-library/react";
import "@testing-library/jest-dom";
import { ResponseDetailView } from "./responses-detail";
import { OpenAIResponse, InputItemListResponse } from "@/lib/types";
describe("ResponseDetailView", () => {
const defaultProps = {
response: null,
inputItems: null,
isLoading: false,
isLoadingInputItems: false,
error: null,
inputItemsError: null,
id: "test_id",
};
describe("Loading State", () => {
test("renders loading skeleton when isLoading is true", () => {
const { container } = render(
<ResponseDetailView {...defaultProps} isLoading={true} />,
);
// Check for skeleton elements
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
expect(skeletons.length).toBeGreaterThan(0);
// The title is replaced by a skeleton when loading, so we shouldn't expect the text
});
});
describe("Error State", () => {
test("renders error message when error prop is provided", () => {
const errorMessage = "Network Error";
render(
<ResponseDetailView
{...defaultProps}
error={{ name: "Error", message: errorMessage }}
/>,
);
expect(screen.getByText("Responses Details")).toBeInTheDocument();
// The error message is split across elements, so we check for parts
expect(
screen.getByText(/Error loading details for ID/),
).toBeInTheDocument();
expect(screen.getByText(/test_id/)).toBeInTheDocument();
expect(screen.getByText(/Network Error/)).toBeInTheDocument();
});
test("renders default error message when error.message is not available", () => {
render(
<ResponseDetailView
{...defaultProps}
error={{ name: "Error", message: "" }}
/>,
);
expect(
screen.getByText(/Error loading details for ID/),
).toBeInTheDocument();
expect(screen.getByText(/test_id/)).toBeInTheDocument();
});
});
describe("Not Found State", () => {
test("renders not found message when response is null and not loading/error", () => {
render(<ResponseDetailView {...defaultProps} response={null} />);
expect(screen.getByText("Responses Details")).toBeInTheDocument();
// The message is split across elements
expect(screen.getByText(/No details found for ID:/)).toBeInTheDocument();
expect(screen.getByText(/test_id/)).toBeInTheDocument();
});
});
describe("Response Data Rendering", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "llama-test-model",
status: "completed",
output: [
{
type: "message",
role: "assistant",
content: "Test response output",
},
],
input: [
{
type: "message",
role: "user",
content: "Test input message",
},
],
temperature: 0.7,
top_p: 0.9,
parallel_tool_calls: true,
previous_response_id: "prev_resp_456",
};
test("renders response data with input and output sections", () => {
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
// Check main sections
expect(screen.getByText("Responses Details")).toBeInTheDocument();
expect(screen.getByText("Input")).toBeInTheDocument();
expect(screen.getByText("Output")).toBeInTheDocument();
// Check input content
expect(screen.getByText("Test input message")).toBeInTheDocument();
expect(screen.getByText("User")).toBeInTheDocument();
// Check output content
expect(screen.getByText("Test response output")).toBeInTheDocument();
expect(screen.getByText("Assistant")).toBeInTheDocument();
});
test("renders properties sidebar with all response metadata", () => {
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
// Check properties - use regex to handle text split across elements
expect(screen.getByText(/Created/)).toBeInTheDocument();
expect(
screen.getByText(new Date(1710000000 * 1000).toLocaleString()),
).toBeInTheDocument();
// Check for the specific ID label (not Previous Response ID)
expect(
screen.getByText((content, element) => {
return element?.tagName === "STRONG" && content === "ID:";
}),
).toBeInTheDocument();
expect(screen.getByText("resp_123")).toBeInTheDocument();
expect(screen.getByText(/Model/)).toBeInTheDocument();
expect(screen.getByText("llama-test-model")).toBeInTheDocument();
expect(screen.getByText(/Status/)).toBeInTheDocument();
expect(screen.getByText("completed")).toBeInTheDocument();
expect(screen.getByText(/Temperature/)).toBeInTheDocument();
expect(screen.getByText("0.7")).toBeInTheDocument();
expect(screen.getByText(/Top P/)).toBeInTheDocument();
expect(screen.getByText("0.9")).toBeInTheDocument();
expect(screen.getByText(/Parallel Tool Calls/)).toBeInTheDocument();
expect(screen.getByText("Yes")).toBeInTheDocument();
expect(screen.getByText(/Previous Response ID/)).toBeInTheDocument();
expect(screen.getByText("prev_resp_456")).toBeInTheDocument();
});
test("handles optional properties correctly", () => {
const minimalResponse: OpenAIResponse = {
id: "resp_minimal",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [],
};
render(
<ResponseDetailView {...defaultProps} response={minimalResponse} />,
);
// Should show required properties
expect(screen.getByText("resp_minimal")).toBeInTheDocument();
expect(screen.getByText("test-model")).toBeInTheDocument();
expect(screen.getByText("completed")).toBeInTheDocument();
// Should not show optional properties
expect(screen.queryByText("Temperature")).not.toBeInTheDocument();
expect(screen.queryByText("Top P")).not.toBeInTheDocument();
expect(screen.queryByText("Parallel Tool Calls")).not.toBeInTheDocument();
expect(
screen.queryByText("Previous Response ID"),
).not.toBeInTheDocument();
});
test("renders error information when response has error", () => {
const errorResponse: OpenAIResponse = {
...mockResponse,
error: {
code: "invalid_request",
message: "The request was invalid",
},
};
render(<ResponseDetailView {...defaultProps} response={errorResponse} />);
// The error is shown in the properties sidebar, not as a separate "Error" label
expect(
screen.getByText("invalid_request: The request was invalid"),
).toBeInTheDocument();
});
});
describe("Input Items Handling", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [{ type: "message", role: "assistant", content: "output" }],
input: [{ type: "message", role: "user", content: "fallback input" }],
};
test("shows loading state for input items", () => {
render(
<ResponseDetailView
{...defaultProps}
response={mockResponse}
isLoadingInputItems={true}
/>,
);
// Check for skeleton loading in input items section
const { container } = render(
<ResponseDetailView
{...defaultProps}
response={mockResponse}
isLoadingInputItems={true}
/>,
);
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
expect(skeletons.length).toBeGreaterThan(0);
});
test("shows error message for input items with fallback", () => {
render(
<ResponseDetailView
{...defaultProps}
response={mockResponse}
inputItemsError={{
name: "Error",
message: "Failed to load input items",
}}
/>,
);
expect(
screen.getByText(
"Error loading input items: Failed to load input items",
),
).toBeInTheDocument();
expect(
screen.getByText("Falling back to response input data."),
).toBeInTheDocument();
// Should still show fallback input data
expect(screen.getByText("fallback input")).toBeInTheDocument();
});
test("uses input items data when available", () => {
const mockInputItems: InputItemListResponse = {
object: "list",
data: [
{
type: "message",
role: "user",
content: "input from items API",
},
],
};
render(
<ResponseDetailView
{...defaultProps}
response={mockResponse}
inputItems={mockInputItems}
/>,
);
// Should show input items data, not response.input
expect(screen.getByText("input from items API")).toBeInTheDocument();
expect(screen.queryByText("fallback input")).not.toBeInTheDocument();
});
test("falls back to response.input when input items is empty", () => {
const emptyInputItems: InputItemListResponse = {
object: "list",
data: [],
};
render(
<ResponseDetailView
{...defaultProps}
response={mockResponse}
inputItems={emptyInputItems}
/>,
);
// Should show fallback input data
expect(screen.getByText("fallback input")).toBeInTheDocument();
});
test("shows no input message when no data available", () => {
const responseWithoutInput: OpenAIResponse = {
...mockResponse,
input: [],
};
render(
<ResponseDetailView
{...defaultProps}
response={responseWithoutInput}
inputItems={null}
/>,
);
expect(screen.getByText("No input data available.")).toBeInTheDocument();
});
});
describe("Input Display Components", () => {
test("renders string content input correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [
{
type: "message",
role: "user",
content: "Simple string input",
},
],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(screen.getByText("Simple string input")).toBeInTheDocument();
expect(screen.getByText("User")).toBeInTheDocument();
});
test("renders array content input correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [
{
type: "message",
role: "user",
content: [
{ type: "input_text", text: "First part" },
{ type: "output_text", text: "Second part" },
],
},
],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(screen.getByText("First part Second part")).toBeInTheDocument();
expect(screen.getByText("User")).toBeInTheDocument();
});
test("renders non-message input types correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [
{
type: "function_call",
content: "function call content",
},
],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(screen.getByText("function call content")).toBeInTheDocument();
// Use getAllByText to find the specific "Input" with the type detail
const inputElements = screen.getAllByText("Input");
expect(inputElements.length).toBeGreaterThan(0);
expect(screen.getByText("(function_call)")).toBeInTheDocument();
});
test("handles input with object content", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [
{
type: "custom_type",
content: JSON.stringify({ key: "value", nested: { data: "test" } }),
},
],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
// Should show JSON stringified content (without quotes around keys in the rendered output)
expect(screen.getByText(/key.*value/)).toBeInTheDocument();
// Use getAllByText to find the specific "Input" with the type detail
const inputElements = screen.getAllByText("Input");
expect(inputElements.length).toBeGreaterThan(0);
expect(screen.getByText("(custom_type)")).toBeInTheDocument();
});
test("renders function call input correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [
{
type: "function_call",
id: "call_456",
status: "completed",
name: "input_function",
arguments: '{"param": "value"}',
},
],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(
screen.getByText('input_function({"param": "value"})'),
).toBeInTheDocument();
expect(screen.getByText("Function Call")).toBeInTheDocument();
});
test("renders web search call input correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [
{
type: "web_search_call",
id: "search_789",
status: "completed",
},
],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(
screen.getByText("web_search_call(status: completed)"),
).toBeInTheDocument();
expect(screen.getByText("Function Call")).toBeInTheDocument();
expect(screen.getByText("(Web Search)")).toBeInTheDocument();
});
});
describe("Output Display Components", () => {
test("renders message output with string content", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "message",
role: "assistant",
content: "Simple string output",
},
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(screen.getByText("Simple string output")).toBeInTheDocument();
expect(screen.getByText("Assistant")).toBeInTheDocument();
});
test("renders message output with array content", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "message",
role: "assistant",
content: [
{ type: "output_text", text: "First output" },
{ type: "input_text", text: "Second output" },
],
},
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(
screen.getByText("First output Second output"),
).toBeInTheDocument();
expect(screen.getByText("Assistant")).toBeInTheDocument();
});
test("renders function call output correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "function_call",
id: "call_123",
status: "completed",
name: "search_function",
arguments: '{"query": "test"}',
},
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(
screen.getByText('search_function({"query": "test"})'),
).toBeInTheDocument();
expect(screen.getByText("Function Call")).toBeInTheDocument();
});
test("renders function call output without arguments", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "function_call",
id: "call_123",
status: "completed",
name: "simple_function",
},
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(screen.getByText("simple_function({})")).toBeInTheDocument();
expect(screen.getByText(/Function Call/)).toBeInTheDocument();
});
test("renders web search call output correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "web_search_call",
id: "search_123",
status: "completed",
},
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(
screen.getByText("web_search_call(status: completed)"),
).toBeInTheDocument();
expect(screen.getByText(/Function Call/)).toBeInTheDocument();
expect(screen.getByText("(Web Search)")).toBeInTheDocument();
});
test("renders unknown output types with JSON fallback", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "unknown_type",
custom_field: "custom_value",
data: { nested: "object" },
} as any,
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
// Should show JSON stringified content
expect(
screen.getByText(/custom_field.*custom_value/),
).toBeInTheDocument();
expect(screen.getByText("(unknown_type)")).toBeInTheDocument();
});
test("shows no output message when output array is empty", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(screen.getByText("No output data available.")).toBeInTheDocument();
});
test("groups function call with its output correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "function_call",
id: "call_123",
status: "completed",
name: "get_weather",
arguments: '{"city": "Tokyo"}',
},
{
type: "message",
role: "assistant",
call_id: "call_123",
content: "sunny and warm",
} as any, // Using any to bypass the type restriction for this test
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
// Should show the function call and message as separate items (not grouped)
expect(screen.getByText("Function Call")).toBeInTheDocument();
expect(
screen.getByText('get_weather({"city": "Tokyo"})'),
).toBeInTheDocument();
expect(screen.getByText("Assistant")).toBeInTheDocument();
expect(screen.getByText("sunny and warm")).toBeInTheDocument();
// Should NOT have the grouped "Arguments" and "Output" labels
expect(screen.queryByText("Arguments")).not.toBeInTheDocument();
});
test("groups function call with function_call_output correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "function_call",
call_id: "call_123",
status: "completed",
name: "get_weather",
arguments: '{"city": "Tokyo"}',
},
{
type: "function_call_output",
id: "fc_68364957013081...",
status: "completed",
call_id: "call_123",
output: "sunny and warm",
} as any, // Using any to bypass the type restriction for this test
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
// Should show the function call grouped with its clean output
expect(screen.getByText("Function Call")).toBeInTheDocument();
expect(screen.getByText("Arguments")).toBeInTheDocument();
expect(
screen.getByText('get_weather({"city": "Tokyo"})'),
).toBeInTheDocument();
// Use getAllByText since there are multiple "Output" elements (card title and output label)
const outputElements = screen.getAllByText("Output");
expect(outputElements.length).toBeGreaterThan(0);
expect(screen.getByText("sunny and warm")).toBeInTheDocument();
});
});
describe("Edge Cases and Error Handling", () => {
test("handles missing role in message input", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [
{
type: "message",
content: "Message without role",
},
],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(screen.getByText("Message without role")).toBeInTheDocument();
expect(screen.getByText("Unknown")).toBeInTheDocument(); // Default role
});
test("handles missing name in function call output", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "function_call",
id: "call_123",
status: "completed",
},
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
// When name is missing, it falls back to JSON.stringify of the entire output
const functionCallElements = screen.getAllByText(/function_call/);
expect(functionCallElements.length).toBeGreaterThan(0);
expect(screen.getByText(/call_123/)).toBeInTheDocument();
});
});
});

View file

@ -0,0 +1,171 @@
"use client";
import { OpenAIResponse, InputItemListResponse } from "@/lib/types";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";
import {
DetailLoadingView,
DetailErrorView,
DetailNotFoundView,
DetailLayout,
PropertiesCard,
PropertyItem,
} from "@/components/layout/detail-layout";
import { GroupedItemsDisplay } from "./grouping/grouped-items-display";
interface ResponseDetailViewProps {
response: OpenAIResponse | null;
inputItems: InputItemListResponse | null;
isLoading: boolean;
isLoadingInputItems: boolean;
error: Error | null;
inputItemsError: Error | null;
id: string;
}
export function ResponseDetailView({
response,
inputItems,
isLoading,
isLoadingInputItems,
error,
inputItemsError,
id,
}: ResponseDetailViewProps) {
const title = "Responses Details";
if (error) {
return <DetailErrorView title={title} id={id} error={error} />;
}
if (isLoading) {
return <DetailLoadingView title={title} />;
}
if (!response) {
return <DetailNotFoundView title={title} id={id} />;
}
// Main content cards
const mainContent = (
<>
<Card>
<CardHeader>
<CardTitle>Input</CardTitle>
</CardHeader>
<CardContent>
{/* Show loading state for input items */}
{isLoadingInputItems ? (
<div className="space-y-2">
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-3/4" />
<Skeleton className="h-4 w-1/2" />
</div>
) : inputItemsError ? (
<div className="text-red-500 text-sm">
Error loading input items: {inputItemsError.message}
<br />
<span className="text-gray-500 text-xs">
Falling back to response input data.
</span>
</div>
) : null}
{/* Display input items if available, otherwise fall back to response.input */}
{(() => {
const dataToDisplay =
inputItems?.data && inputItems.data.length > 0
? inputItems.data
: response.input;
if (dataToDisplay && dataToDisplay.length > 0) {
return (
<GroupedItemsDisplay
items={dataToDisplay}
keyPrefix="input"
defaultRole="unknown"
/>
);
} else {
return (
<p className="text-gray-500 italic text-sm">
No input data available.
</p>
);
}
})()}
</CardContent>
</Card>
<Card>
<CardHeader>
<CardTitle>Output</CardTitle>
</CardHeader>
<CardContent>
{response.output?.length > 0 ? (
<GroupedItemsDisplay
items={response.output}
keyPrefix="output"
defaultRole="assistant"
/>
) : (
<p className="text-gray-500 italic text-sm">
No output data available.
</p>
)}
</CardContent>
</Card>
</>
);
// Properties sidebar
const sidebar = (
<PropertiesCard>
<PropertyItem
label="Created"
value={new Date(response.created_at * 1000).toLocaleString()}
/>
<PropertyItem label="ID" value={response.id} />
<PropertyItem label="Model" value={response.model} />
<PropertyItem label="Status" value={response.status} hasBorder />
{response.temperature && (
<PropertyItem
label="Temperature"
value={response.temperature}
hasBorder
/>
)}
{response.top_p && <PropertyItem label="Top P" value={response.top_p} />}
{response.parallel_tool_calls && (
<PropertyItem
label="Parallel Tool Calls"
value={response.parallel_tool_calls ? "Yes" : "No"}
/>
)}
{response.previous_response_id && (
<PropertyItem
label="Previous Response ID"
value={
<span className="text-xs">{response.previous_response_id}</span>
}
hasBorder
/>
)}
{response.error && (
<PropertyItem
label="Error"
value={
<span className="text-red-900 font-medium">
{response.error.code}: {response.error.message}
</span>
}
className="pt-1 mt-1 border-t border-red-200"
/>
)}
</PropertiesCard>
);
return (
<DetailLayout title={title} mainContent={mainContent} sidebar={sidebar} />
);
}
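For reference, a minimal usage sketch of the component above. The page wrapper, the sample response, and the import path are illustrative assumptions, not part of this commit; a real page would fetch the response and input items via the client and pass the loading/error flags through.

// Sketch only: rendering ResponseDetailView with an already-fetched response.
// The sample object, page component, and import path are placeholders.
import { ResponseDetailView } from "./response-detail"; // path assumed
import type { OpenAIResponse } from "@/lib/types";

const sample: OpenAIResponse = {
  id: "resp_example",
  object: "response",
  created_at: 1710000000,
  model: "example-model",
  status: "completed",
  output: [{ type: "message", role: "assistant", content: "Hello" }],
  input: [{ type: "message", role: "user", content: "Hi" }],
};

export default function ExamplePage() {
  return (
    <ResponseDetailView
      response={sample}
      inputItems={null}
      isLoading={false}
      isLoadingInputItems={false}
      error={null}
      inputItemsError={null}
      id={sample.id}
    />
  );
}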

View file

@ -0,0 +1,537 @@
import React from "react";
import { render, screen, fireEvent } from "@testing-library/react";
import "@testing-library/jest-dom";
import { ResponsesTable } from "./responses-table";
import { OpenAIResponse } from "@/lib/types";
// Mock next/navigation
const mockPush = jest.fn();
jest.mock("next/navigation", () => ({
useRouter: () => ({
push: mockPush,
}),
}));
// Mock helper functions
jest.mock("@/lib/truncate-text");
// Import the mocked functions
import { truncateText as originalTruncateText } from "@/lib/truncate-text";
// Cast to jest.Mock for typings
const truncateText = originalTruncateText as jest.Mock;
describe("ResponsesTable", () => {
const defaultProps = {
data: [] as OpenAIResponse[],
isLoading: false,
error: null,
};
beforeEach(() => {
// Reset all mocks before each test
mockPush.mockClear();
truncateText.mockClear();
// Default pass-through implementation
truncateText.mockImplementation((text: string | undefined) => text);
});
test("renders without crashing with default props", () => {
render(<ResponsesTable {...defaultProps} />);
expect(screen.getByText("No responses found.")).toBeInTheDocument();
});
test("click on a row navigates to the correct URL", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: Math.floor(Date.now() / 1000),
model: "llama-test-model",
status: "completed",
output: [
{
type: "message",
role: "assistant",
content: "Test output",
},
],
input: [
{
type: "message",
role: "user",
content: "Test input",
},
],
};
render(<ResponsesTable {...defaultProps} data={[mockResponse]} />);
const row = screen.getByText("Test input").closest("tr");
if (row) {
fireEvent.click(row);
expect(mockPush).toHaveBeenCalledWith("/logs/responses/resp_123");
} else {
throw new Error('Row with "Test input" not found for router mock test.');
}
});
describe("Loading State", () => {
test("renders skeleton UI when isLoading is true", () => {
const { container } = render(
<ResponsesTable {...defaultProps} isLoading={true} />,
);
// Check for skeleton in the table caption
const tableCaption = container.querySelector("caption");
expect(tableCaption).toBeInTheDocument();
if (tableCaption) {
const captionSkeleton = tableCaption.querySelector(
'[data-slot="skeleton"]',
);
expect(captionSkeleton).toBeInTheDocument();
}
// Check for skeletons in the table body cells
const tableBody = container.querySelector("tbody");
expect(tableBody).toBeInTheDocument();
if (tableBody) {
const bodySkeletons = tableBody.querySelectorAll(
'[data-slot="skeleton"]',
);
expect(bodySkeletons.length).toBeGreaterThan(0);
}
});
});
describe("Error State", () => {
test("renders error message when error prop is provided", () => {
const errorMessage = "Network Error";
render(
<ResponsesTable
{...defaultProps}
error={{ name: "Error", message: errorMessage }}
/>,
);
expect(
screen.getByText(`Error fetching data: ${errorMessage}`),
).toBeInTheDocument();
});
test("renders default error message when error.message is not available", () => {
render(
<ResponsesTable
{...defaultProps}
error={{ name: "Error", message: "" }}
/>,
);
expect(
screen.getByText("Error fetching data: An unknown error occurred"),
).toBeInTheDocument();
});
test("renders default error message when error prop is an object without message", () => {
render(<ResponsesTable {...defaultProps} error={{} as Error} />);
expect(
screen.getByText("Error fetching data: An unknown error occurred"),
).toBeInTheDocument();
});
});
describe("Empty State", () => {
test('renders "No responses found." and no table when data array is empty', () => {
render(<ResponsesTable data={[]} isLoading={false} error={null} />);
expect(screen.getByText("No responses found.")).toBeInTheDocument();
// Ensure that the table structure is NOT rendered in the empty state
const table = screen.queryByRole("table");
expect(table).not.toBeInTheDocument();
});
});
describe("Data Rendering", () => {
test("renders table caption, headers, and response data correctly", () => {
const mockResponses = [
{
id: "resp_1",
object: "response" as const,
created_at: 1710000000,
model: "llama-test-model",
status: "completed",
output: [
{
type: "message" as const,
role: "assistant" as const,
content: "Test output",
},
],
input: [
{
type: "message",
role: "user",
content: "Test input",
},
],
},
{
id: "resp_2",
object: "response" as const,
created_at: 1710001000,
model: "llama-another-model",
status: "completed",
output: [
{
type: "message" as const,
role: "assistant" as const,
content: "Another output",
},
],
input: [
{
type: "message",
role: "user",
content: "Another input",
},
],
},
];
render(
<ResponsesTable data={mockResponses} isLoading={false} error={null} />,
);
// Table caption
expect(
screen.getByText("A list of your recent responses."),
).toBeInTheDocument();
// Table headers
expect(screen.getByText("Input")).toBeInTheDocument();
expect(screen.getByText("Output")).toBeInTheDocument();
expect(screen.getByText("Model")).toBeInTheDocument();
expect(screen.getByText("Created")).toBeInTheDocument();
// Data rows
expect(screen.getByText("Test input")).toBeInTheDocument();
expect(screen.getByText("Test output")).toBeInTheDocument();
expect(screen.getByText("llama-test-model")).toBeInTheDocument();
expect(
screen.getByText(new Date(1710000000 * 1000).toLocaleString()),
).toBeInTheDocument();
expect(screen.getByText("Another input")).toBeInTheDocument();
expect(screen.getByText("Another output")).toBeInTheDocument();
expect(screen.getByText("llama-another-model")).toBeInTheDocument();
expect(
screen.getByText(new Date(1710001000 * 1000).toLocaleString()),
).toBeInTheDocument();
});
});
describe("Input Text Extraction", () => {
test("extracts text from string content", () => {
const mockResponse: OpenAIResponse = {
id: "resp_string",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [{ type: "message", role: "assistant", content: "output" }],
input: [
{
type: "message",
role: "user",
content: "Simple string input",
},
],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
expect(screen.getByText("Simple string input")).toBeInTheDocument();
});
test("extracts text from array content with input_text type", () => {
const mockResponse: OpenAIResponse = {
id: "resp_array",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [{ type: "message", role: "assistant", content: "output" }],
input: [
{
type: "message",
role: "user",
content: [
{ type: "input_text", text: "Array input text" },
{ type: "input_text", text: "Should not be used" },
],
},
],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
expect(screen.getByText("Array input text")).toBeInTheDocument();
});
test("returns empty string when no message input found", () => {
const mockResponse: OpenAIResponse = {
id: "resp_no_input",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [{ type: "message", role: "assistant", content: "output" }],
input: [
{
type: "other_type",
content: "Not a message",
},
],
};
const { container } = render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
// Find the input cell (first cell in the data row) and verify it's empty
const inputCell = container.querySelector("tbody tr td:first-child");
expect(inputCell).toBeInTheDocument();
expect(inputCell).toHaveTextContent("");
});
});
describe("Output Text Extraction", () => {
test("extracts text from string message content", () => {
const mockResponse: OpenAIResponse = {
id: "resp_string_output",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "message",
role: "assistant",
content: "Simple string output",
},
],
input: [{ type: "message", content: "input" }],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
expect(screen.getByText("Simple string output")).toBeInTheDocument();
});
test("extracts text from array message content with output_text type", () => {
const mockResponse: OpenAIResponse = {
id: "resp_array_output",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "message",
role: "assistant",
content: [
{ type: "output_text", text: "Array output text" },
{ type: "output_text", text: "Should not be used" },
],
},
],
input: [{ type: "message", content: "input" }],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
expect(screen.getByText("Array output text")).toBeInTheDocument();
});
test("formats function call output", () => {
const mockResponse: OpenAIResponse = {
id: "resp_function_call",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "function_call",
id: "call_123",
status: "completed",
name: "search_function",
arguments: '{"query": "test"}',
},
],
input: [{ type: "message", content: "input" }],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
expect(
screen.getByText('search_function({"query": "test"})'),
).toBeInTheDocument();
});
test("formats function call output without arguments", () => {
const mockResponse: OpenAIResponse = {
id: "resp_function_no_args",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "function_call",
id: "call_123",
status: "completed",
name: "simple_function",
},
],
input: [{ type: "message", content: "input" }],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
expect(screen.getByText("simple_function({})")).toBeInTheDocument();
});
test("formats web search call output", () => {
const mockResponse: OpenAIResponse = {
id: "resp_web_search",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "web_search_call",
id: "search_123",
status: "completed",
},
],
input: [{ type: "message", content: "input" }],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
expect(
screen.getByText("web_search_call(status: completed)"),
).toBeInTheDocument();
});
test("falls back to JSON.stringify for unknown tool call types", () => {
const mockResponse: OpenAIResponse = {
id: "resp_unknown_tool",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "unknown_call",
id: "unknown_123",
status: "completed",
custom_field: "custom_value",
} as any,
],
input: [{ type: "message", content: "input" }],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
// Should contain the JSON stringified version
expect(screen.getByText(/unknown_call/)).toBeInTheDocument();
});
test("falls back to JSON.stringify for entire output when no message or tool call found", () => {
const mockResponse: OpenAIResponse = {
id: "resp_fallback",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "unknown_type",
data: "some data",
} as any,
],
input: [{ type: "message", content: "input" }],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
// Should contain the JSON stringified version of the output array
expect(screen.getByText(/unknown_type/)).toBeInTheDocument();
});
});
describe("Text Truncation", () => {
test("truncates long input and output text", () => {
// Specific mock implementation for this test
truncateText.mockImplementation(
(text: string | undefined, maxLength?: number) => {
const defaultTestMaxLength = 10;
const effectiveMaxLength = maxLength ?? defaultTestMaxLength;
return typeof text === "string" && text.length > effectiveMaxLength
? text.slice(0, effectiveMaxLength) + "..."
: text;
},
);
const longInput =
"This is a very long input message that should be truncated.";
const longOutput =
"This is a very long output message that should also be truncated.";
const mockResponse: OpenAIResponse = {
id: "resp_trunc",
object: "response",
created_at: 1710002000,
model: "llama-trunc-model",
status: "completed",
output: [
{
type: "message",
role: "assistant",
content: longOutput,
},
],
input: [
{
type: "message",
role: "user",
content: longInput,
},
],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
// The truncated text should be present for both input and output
const truncatedTexts = screen.getAllByText(
longInput.slice(0, 10) + "...",
);
expect(truncatedTexts.length).toBe(2); // one for input, one for output
truncatedTexts.forEach((textElement) =>
expect(textElement).toBeInTheDocument(),
);
});
});
});

View file

@ -0,0 +1,117 @@
"use client";
import {
OpenAIResponse,
ResponseInput,
ResponseInputMessageContent,
} from "@/lib/types";
import { LogsTable, LogTableRow } from "@/components/logs/logs-table";
import {
isMessageInput,
isMessageItem,
isFunctionCallItem,
isWebSearchCallItem,
MessageItem,
FunctionCallItem,
WebSearchCallItem,
} from "./utils/item-types";
interface ResponsesTableProps {
data: OpenAIResponse[];
isLoading: boolean;
error: Error | null;
}
function getInputText(response: OpenAIResponse): string {
const firstInput = response.input.find(isMessageInput);
if (firstInput) {
return extractContentFromItem(firstInput);
}
return "";
}
function getOutputText(response: OpenAIResponse): string {
const firstMessage = response.output.find((item) =>
isMessageItem(item as any),
);
if (firstMessage) {
const content = extractContentFromItem(firstMessage as MessageItem);
if (content) {
return content;
}
}
const functionCall = response.output.find((item) =>
isFunctionCallItem(item as any),
);
if (functionCall) {
return formatFunctionCall(functionCall as FunctionCallItem);
}
const webSearchCall = response.output.find((item) =>
isWebSearchCallItem(item as any),
);
if (webSearchCall) {
return formatWebSearchCall(webSearchCall as WebSearchCallItem);
}
return JSON.stringify(response.output);
}
function extractContentFromItem(item: {
content?: string | ResponseInputMessageContent[];
}): string {
if (!item.content) {
return "";
}
if (typeof item.content === "string") {
return item.content;
} else if (Array.isArray(item.content)) {
const textContent = item.content.find(
(c: ResponseInputMessageContent) =>
c.type === "input_text" || c.type === "output_text",
);
return textContent?.text || "";
}
return "";
}
function formatFunctionCall(functionCall: FunctionCallItem): string {
const args = functionCall.arguments || "{}";
const name = functionCall.name || "unknown";
return `${name}(${args})`;
}
function formatWebSearchCall(webSearchCall: WebSearchCallItem): string {
return `web_search_call(status: ${webSearchCall.status})`;
}
function formatResponseToRow(response: OpenAIResponse): LogTableRow {
return {
id: response.id,
input: getInputText(response),
output: getOutputText(response),
model: response.model,
createdTime: new Date(response.created_at * 1000).toLocaleString(),
detailPath: `/logs/responses/${response.id}`,
};
}
export function ResponsesTable({
data,
isLoading,
error,
}: ResponsesTableProps) {
const formattedData = data.map(formatResponseToRow);
return (
<LogsTable
data={formattedData}
isLoading={isLoading}
error={error}
caption="A list of your recent responses."
emptyMessage="No responses found."
/>
);
}
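As a reading aid, this is the row shape each response is mapped to before being handed to LogsTable. The sample values are made up; the field meanings mirror formatResponseToRow above, which is module-private.

// Sketch only: the LogTableRow produced for one response (sample data is illustrative).
import type { LogTableRow } from "@/components/logs/logs-table";

const exampleRow: LogTableRow = {
  id: "resp_42",
  input: "Hello", // text extracted from the first message input
  output: "Hi there", // first assistant message, or a formatted tool call as fallback
  model: "example-model",
  createdTime: new Date(1710000000 * 1000).toLocaleString(),
  detailPath: "/logs/responses/resp_42", // clicking the row navigates here
};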

View file

@ -0,0 +1,61 @@
/**
* Type guards for different item types in responses
*/
import type {
ResponseInput,
ResponseOutput,
ResponseMessage,
ResponseToolCall,
} from "@/lib/types";
export interface BaseItem {
type: string;
[key: string]: unknown;
}
export type MessageItem = ResponseMessage;
export type FunctionCallItem = ResponseToolCall & { type: "function_call" };
export type WebSearchCallItem = ResponseToolCall & { type: "web_search_call" };
export type FunctionCallOutputItem = BaseItem & {
type: "function_call_output";
call_id: string;
output?: string | object;
};
export type AnyResponseItem =
| ResponseInput
| ResponseOutput
| FunctionCallOutputItem;
export function isMessageInput(
item: ResponseInput,
): item is ResponseInput & { type: "message" } {
return item.type === "message";
}
export function isMessageItem(item: AnyResponseItem): item is MessageItem {
return item.type === "message" && "content" in item;
}
export function isFunctionCallItem(
item: AnyResponseItem,
): item is FunctionCallItem {
return item.type === "function_call" && "name" in item;
}
export function isWebSearchCallItem(
item: AnyResponseItem,
): item is WebSearchCallItem {
return item.type === "web_search_call";
}
export function isFunctionCallOutputItem(
item: AnyResponseItem,
): item is FunctionCallOutputItem {
return (
item.type === "function_call_output" &&
"call_id" in item &&
typeof (item as any).call_id === "string"
);
}
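A short sketch of the narrowing pattern these guards enable: once a guard returns true, the type-specific fields (name, arguments, status, content) can be read safely. The describeItem helper and the import path are illustrative assumptions, not part of this commit.

// Hypothetical helper showing how the guards narrow AnyResponseItem.
import {
  isMessageItem,
  isFunctionCallItem,
  isWebSearchCallItem,
  type AnyResponseItem,
} from "./utils/item-types"; // path assumed

function describeItem(item: AnyResponseItem): string {
  if (isFunctionCallItem(item)) {
    // `name` and `arguments` are only safe to read after narrowing.
    return `${item.name ?? "unknown"}(${item.arguments ?? "{}"})`;
  }
  if (isWebSearchCallItem(item)) {
    return `web_search_call(status: ${item.status})`;
  }
  if (isMessageItem(item)) {
    return typeof item.content === "string" ? item.content : "[structured content]";
  }
  // Unknown item types fall back to their JSON representation.
  return JSON.stringify(item);
}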

View file

@ -0,0 +1,49 @@
import React from "react";
export interface MessageBlockProps {
label: string;
labelDetail?: string;
content: React.ReactNode;
className?: string;
contentClassName?: string;
}
export const MessageBlock: React.FC<MessageBlockProps> = ({
label,
labelDetail,
content,
className = "",
contentClassName = "",
}) => {
return (
<div className={`mb-4 ${className}`}>
<p className="py-1 font-semibold text-gray-800 mb-1">
{label}
{labelDetail && (
<span className="text-xs text-gray-500 font-normal ml-1">
{labelDetail}
</span>
)}
</p>
<div className={`py-1 whitespace-pre-wrap ${contentClassName}`}>
{content}
</div>
</div>
);
};
export interface ToolCallBlockProps {
children: React.ReactNode;
className?: string;
}
export const ToolCallBlock = ({ children, className }: ToolCallBlockProps) => {
const baseClassName =
"p-3 bg-slate-50 border border-slate-200 rounded-md text-sm";
return (
<div className={`${baseClassName} ${className || ""}`}>
<pre className="whitespace-pre-wrap text-xs">{children}</pre>
</div>
);
};
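A minimal composition sketch: a MessageBlock whose content is a ToolCallBlock, which is roughly how a grouped tool call could be rendered. The wrapper component, sample text, and import path are assumptions for illustration only.

// Illustrative only: composing the two building blocks above.
import React from "react";
import { MessageBlock, ToolCallBlock } from "./message-components"; // path assumed

export function ExampleFunctionCall() {
  return (
    <MessageBlock
      label="Function Call"
      labelDetail="(Web Search)"
      content={
        <ToolCallBlock>{"web_search_call(status: completed)"}</ToolCallBlock>
      }
    />
  );
}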

View file

@ -0,0 +1,12 @@
import LlamaStackClient from "llama-stack-client";
import OpenAI from "openai";
export const client =
process.env.NEXT_PUBLIC_USE_OPENAI_CLIENT === "true" // useful for testing
? new OpenAI({
apiKey: process.env.NEXT_PUBLIC_OPENAI_API_KEY,
dangerouslyAllowBrowser: true,
})
: new LlamaStackClient({
baseURL: process.env.NEXT_PUBLIC_LLAMA_STACK_BASE_URL,
});

View file

@ -43,10 +43,14 @@ export function extractDisplayableText(
return ""; return "";
} }
let textPart = extractTextFromContentPart(message.content); const textPart = extractTextFromContentPart(message.content);
let toolCallPart = ""; let toolCallPart = "";
if (message.tool_calls && message.tool_calls.length > 0) { if (
message.tool_calls &&
Array.isArray(message.tool_calls) &&
message.tool_calls.length > 0
) {
// For summary, usually the first tool call is sufficient // For summary, usually the first tool call is sufficient
toolCallPart = formatToolCallToString(message.tool_calls[0]); toolCallPart = formatToolCallToString(message.tool_calls[0]);
} }

View file

@ -18,20 +18,20 @@ export interface ImageUrlContentBlock {
export type ChatMessageContentPart = export type ChatMessageContentPart =
| TextContentBlock | TextContentBlock
| ImageUrlContentBlock | ImageUrlContentBlock
| { type: string; [key: string]: any }; // Fallback for other potential types | { type: string; [key: string]: unknown }; // Fallback for other potential types
export interface ChatMessage { export interface ChatMessage {
role: string; role: string;
content: string | ChatMessageContentPart[]; // Updated content type content: string | ChatMessageContentPart[]; // Updated content type
name?: string | null; name?: string | null;
tool_calls?: any | null; // This could also be refined to a more specific ToolCall[] type tool_calls?: unknown | null; // This could also be refined to a more specific ToolCall[] type
} }
export interface Choice { export interface Choice {
message: ChatMessage; message: ChatMessage;
finish_reason: string; finish_reason: string;
index: number; index: number;
logprobs?: any | null; logprobs?: unknown | null;
} }
export interface ChatCompletion { export interface ChatCompletion {
@ -42,3 +42,62 @@ export interface ChatCompletion {
model: string; model: string;
input_messages: ChatMessage[]; input_messages: ChatMessage[];
} }
// Response types for OpenAI Responses API
export interface ResponseInputMessageContent {
text?: string;
type: "input_text" | "input_image" | "output_text";
image_url?: string;
detail?: "low" | "high" | "auto";
}
export interface ResponseMessage {
content: string | ResponseInputMessageContent[];
role: "system" | "developer" | "user" | "assistant";
type: "message";
id?: string;
status?: string;
}
export interface ResponseToolCall {
id: string;
status: string;
type: "web_search_call" | "function_call";
arguments?: string;
call_id?: string;
name?: string;
}
export type ResponseOutput = ResponseMessage | ResponseToolCall;
export interface ResponseInput {
type: string;
content?: string | ResponseInputMessageContent[];
role?: string;
[key: string]: unknown; // Flexible for various input types
}
export interface OpenAIResponse {
id: string;
created_at: number;
model: string;
object: "response";
status: string;
output: ResponseOutput[];
input: ResponseInput[];
error?: {
code: string;
message: string;
};
parallel_tool_calls?: boolean;
previous_response_id?: string;
temperature?: number;
top_p?: number;
truncation?: string;
user?: string;
}
export interface InputItemListResponse {
data: ResponseInput[];
object: "list";
}
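For clarity, only the fields typed as required above are needed to satisfy OpenAIResponse; temperature, top_p, parallel_tool_calls, previous_response_id, and error are optional, and the detail sidebar renders them conditionally. A minimal sketch (sample values are illustrative):

// Sketch only: the smallest object that satisfies the OpenAIResponse interface.
import type { OpenAIResponse } from "@/lib/types";

const minimal: OpenAIResponse = {
  id: "resp_minimal",
  object: "response",
  created_at: 1710000000,
  model: "example-model",
  status: "completed",
  output: [],
  input: [],
};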

View file

@ -19,6 +19,7 @@
"lucide-react": "^0.510.0", "lucide-react": "^0.510.0",
"next": "15.3.2", "next": "15.3.2",
"next-themes": "^0.4.6", "next-themes": "^0.4.6",
"openai": "^4.103.0",
"react": "^19.0.0", "react": "^19.0.0",
"react-dom": "^19.0.0", "react-dom": "^19.0.0",
"tailwind-merge": "^3.3.0" "tailwind-merge": "^3.3.0"
@ -9092,7 +9093,7 @@
}, },
"node_modules/llama-stack-client": { "node_modules/llama-stack-client": {
"version": "0.0.1-alpha.0", "version": "0.0.1-alpha.0",
"resolved": "git+ssh://git@github.com/stainless-sdks/llama-stack-node.git#efa814980d44b3b2c92944377a086915137b2134", "resolved": "git+ssh://git@github.com/stainless-sdks/llama-stack-node.git#5d34d229fb53b6dad02da0f19f4b310b529c6b15",
"license": "Apache-2.0", "license": "Apache-2.0",
"dependencies": { "dependencies": {
"@types/node": "^18.11.18", "@types/node": "^18.11.18",
@ -9804,6 +9805,51 @@
"url": "https://github.com/sponsors/sindresorhus" "url": "https://github.com/sponsors/sindresorhus"
} }
}, },
"node_modules/openai": {
"version": "4.103.0",
"resolved": "https://registry.npmjs.org/openai/-/openai-4.103.0.tgz",
"integrity": "sha512-eWcz9kdurkGOFDtd5ySS5y251H2uBgq9+1a2lTBnjMMzlexJ40Am5t6Mu76SSE87VvitPa0dkIAp75F+dZVC0g==",
"license": "Apache-2.0",
"dependencies": {
"@types/node": "^18.11.18",
"@types/node-fetch": "^2.6.4",
"abort-controller": "^3.0.0",
"agentkeepalive": "^4.2.1",
"form-data-encoder": "1.7.2",
"formdata-node": "^4.3.2",
"node-fetch": "^2.6.7"
},
"bin": {
"openai": "bin/cli"
},
"peerDependencies": {
"ws": "^8.18.0",
"zod": "^3.23.8"
},
"peerDependenciesMeta": {
"ws": {
"optional": true
},
"zod": {
"optional": true
}
}
},
"node_modules/openai/node_modules/@types/node": {
"version": "18.19.103",
"resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.103.tgz",
"integrity": "sha512-hHTHp+sEz6SxFsp+SA+Tqrua3AbmlAw+Y//aEwdHrdZkYVRWdvWD3y5uPZ0flYOkgskaFWqZ/YGFm3FaFQ0pRw==",
"license": "MIT",
"dependencies": {
"undici-types": "~5.26.4"
}
},
"node_modules/openai/node_modules/undici-types": {
"version": "5.26.5",
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
"integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==",
"license": "MIT"
},
"node_modules/optionator": { "node_modules/optionator": {
"version": "0.9.4", "version": "0.9.4",
"resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz", "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz",
@ -12223,7 +12269,7 @@
"version": "8.18.2", "version": "8.18.2",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.18.2.tgz", "resolved": "https://registry.npmjs.org/ws/-/ws-8.18.2.tgz",
"integrity": "sha512-DMricUmwGZUVr++AEAe2uiVM7UoO9MAVZMDu05UQOaUII0lp+zOzLLU4Xqh/JvTqklB1T4uELaaPBKyjE1r4fQ==", "integrity": "sha512-DMricUmwGZUVr++AEAe2uiVM7UoO9MAVZMDu05UQOaUII0lp+zOzLLU4Xqh/JvTqklB1T4uELaaPBKyjE1r4fQ==",
"dev": true, "devOptional": true,
"license": "MIT", "license": "MIT",
"engines": { "engines": {
"node": ">=10.0.0" "node": ">=10.0.0"
@ -12334,7 +12380,7 @@
"version": "3.24.4", "version": "3.24.4",
"resolved": "https://registry.npmjs.org/zod/-/zod-3.24.4.tgz", "resolved": "https://registry.npmjs.org/zod/-/zod-3.24.4.tgz",
"integrity": "sha512-OdqJE9UDRPwWsrHjLN2F8bPxvwJBK22EHLWtanu0LSYr5YqzsaaW3RMgmjwr8Rypg5k+meEJdSPXJZXE/yqOMg==", "integrity": "sha512-OdqJE9UDRPwWsrHjLN2F8bPxvwJBK22EHLWtanu0LSYr5YqzsaaW3RMgmjwr8Rypg5k+meEJdSPXJZXE/yqOMg==",
"dev": true, "devOptional": true,
"license": "MIT", "license": "MIT",
"funding": { "funding": {
"url": "https://github.com/sponsors/colinhacks" "url": "https://github.com/sponsors/colinhacks"

View file

@ -19,7 +19,7 @@
"@radix-ui/react-tooltip": "^1.2.6", "@radix-ui/react-tooltip": "^1.2.6",
"class-variance-authority": "^0.7.1", "class-variance-authority": "^0.7.1",
"clsx": "^2.1.1", "clsx": "^2.1.1",
"llama-stack-client": "github:stainless-sdks/llama-stack-node#ehhuang/dev", "llama-stack-client": "0.2.8",
"lucide-react": "^0.510.0", "lucide-react": "^0.510.0",
"next": "15.3.2", "next": "15.3.2",
"next-themes": "^0.4.6", "next-themes": "^0.4.6",

View file

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "llama_stack" name = "llama_stack"
version = "0.2.7" version = "0.2.8"
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }] authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
description = "Llama Stack" description = "Llama Stack"
readme = "README.md" readme = "README.md"
@ -21,13 +21,13 @@ classifiers = [
"Topic :: Scientific/Engineering :: Information Analysis", "Topic :: Scientific/Engineering :: Information Analysis",
] ]
dependencies = [ dependencies = [
"blobfile", "aiohttp",
"fire", "fire",
"httpx", "httpx",
"huggingface-hub", "huggingface-hub",
"jinja2>=3.1.6", "jinja2>=3.1.6",
"jsonschema", "jsonschema",
"llama-stack-client>=0.2.7", "llama-stack-client>=0.2.8",
"openai>=1.66", "openai>=1.66",
"prompt-toolkit", "prompt-toolkit",
"python-dotenv", "python-dotenv",
@ -36,6 +36,7 @@ dependencies = [
"requests", "requests",
"rich", "rich",
"setuptools", "setuptools",
"starlette",
"termcolor", "termcolor",
"tiktoken", "tiktoken",
"pillow", "pillow",
@ -43,6 +44,14 @@ dependencies = [
] ]
[project.optional-dependencies] [project.optional-dependencies]
ui = [
"streamlit",
"pandas",
"llama-stack-client>=0.2.8",
"streamlit-option-menu",
]
[dependency-groups]
dev = [ dev = [
"pytest", "pytest",
"pytest-timeout", "pytest-timeout",
@ -73,10 +82,11 @@ unit = [
"opentelemetry-exporter-otlp-proto-http", "opentelemetry-exporter-otlp-proto-http",
"sqlalchemy", "sqlalchemy",
"sqlalchemy[asyncio]>=2.0.41", "sqlalchemy[asyncio]>=2.0.41",
"blobfile",
] ]
# These are the core dependencies required for running integration tests. They are shared across all # These are the core dependencies required for running integration tests. They are shared across all
# providers. If a provider requires additional dependencies, please add them to your environment # providers. If a provider requires additional dependencies, please add them to your environment
# separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra # separately. If you are using "uv" to execute your tests, you can use the "--group" flag to specify extra
# dependencies. # dependencies.
test = [ test = [
"openai", "openai",
@ -112,12 +122,6 @@ docs = [
"sphinxcontrib.openapi", "sphinxcontrib.openapi",
] ]
codegen = ["rich", "pydantic", "jinja2>=3.1.6"] codegen = ["rich", "pydantic", "jinja2>=3.1.6"]
ui = [
"streamlit",
"pandas",
"llama-stack-client>=0.2.7",
"streamlit-option-menu",
]
[project.urls] [project.urls]
Homepage = "https://github.com/meta-llama/llama-stack" Homepage = "https://github.com/meta-llama/llama-stack"
@ -138,7 +142,6 @@ explicit = true
[tool.uv.sources] [tool.uv.sources]
torch = [{ index = "pytorch-cpu" }] torch = [{ index = "pytorch-cpu" }]
torchvision = [{ index = "pytorch-cpu" }] torchvision = [{ index = "pytorch-cpu" }]
llama-stack = { workspace = true }
[tool.ruff] [tool.ruff]
line-length = 120 line-length = 120
@ -332,10 +335,5 @@ init_forbid_extra = true
init_typed = true init_typed = true
warn_required_dynamic_aliases = true warn_required_dynamic_aliases = true
[dependency-groups]
dev = [
"llama-stack",
]
[tool.ruff.lint.pep8-naming] [tool.ruff.lint.pep8-naming]
classmethod-decorators = ["classmethod", "pydantic.field_validator"] classmethod-decorators = ["classmethod", "pydantic.field_validator"]

View file

@ -1,63 +1,203 @@
# This file was autogenerated by uv via the following command: # This file was autogenerated by uv via the following command:
# uv export --frozen --no-hashes --no-emit-project --output-file=requirements.txt # uv export --frozen --no-hashes --no-emit-project --no-default-groups --output-file=requirements.txt
aiohappyeyeballs==2.5.0
# via aiohttp
aiohttp==3.11.13
# via llama-stack
aiosignal==1.3.2
# via aiohttp
annotated-types==0.7.0 annotated-types==0.7.0
# via pydantic
anyio==4.8.0 anyio==4.8.0
# via
# httpx
# llama-stack-client
# openai
# starlette
async-timeout==5.0.1 ; python_full_version < '3.11'
# via aiohttp
attrs==25.1.0 attrs==25.1.0
blobfile==3.0.0 # via
# aiohttp
# jsonschema
# referencing
certifi==2025.1.31 certifi==2025.1.31
# via
# httpcore
# httpx
# requests
charset-normalizer==3.4.1 charset-normalizer==3.4.1
# via requests
click==8.1.8 click==8.1.8
# via llama-stack-client
colorama==0.4.6 ; sys_platform == 'win32' colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
distro==1.9.0 distro==1.9.0
# via
# llama-stack-client
# openai
ecdsa==0.19.1 ecdsa==0.19.1
# via python-jose
exceptiongroup==1.2.2 ; python_full_version < '3.11' exceptiongroup==1.2.2 ; python_full_version < '3.11'
# via anyio
filelock==3.17.0 filelock==3.17.0
# via huggingface-hub
fire==0.7.0 fire==0.7.0
# via llama-stack
frozenlist==1.5.0
# via
# aiohttp
# aiosignal
fsspec==2024.12.0 fsspec==2024.12.0
# via huggingface-hub
h11==0.16.0 h11==0.16.0
# via
# httpcore
# llama-stack
httpcore==1.0.9 httpcore==1.0.9
# via httpx
httpx==0.28.1 httpx==0.28.1
# via
# llama-stack
# llama-stack-client
# openai
huggingface-hub==0.29.0 huggingface-hub==0.29.0
# via llama-stack
idna==3.10 idna==3.10
# via
# anyio
# httpx
# requests
# yarl
jinja2==3.1.6 jinja2==3.1.6
# via llama-stack
jiter==0.8.2 jiter==0.8.2
# via openai
jsonschema==4.23.0 jsonschema==4.23.0
# via llama-stack
jsonschema-specifications==2024.10.1 jsonschema-specifications==2024.10.1
llama-stack-client==0.2.7 # via jsonschema
lxml==5.3.1 llama-stack-client==0.2.8
# via llama-stack
markdown-it-py==3.0.0 markdown-it-py==3.0.0
# via rich
markupsafe==3.0.2 markupsafe==3.0.2
# via jinja2
mdurl==0.1.2 mdurl==0.1.2
# via markdown-it-py
multidict==6.1.0
# via
# aiohttp
# yarl
numpy==2.2.3 numpy==2.2.3
# via pandas
openai==1.71.0 openai==1.71.0
# via llama-stack
packaging==24.2 packaging==24.2
# via huggingface-hub
pandas==2.2.3 pandas==2.2.3
# via llama-stack-client
pillow==11.1.0 pillow==11.1.0
# via llama-stack
prompt-toolkit==3.0.50 prompt-toolkit==3.0.50
# via
# llama-stack
# llama-stack-client
propcache==0.3.0
# via
# aiohttp
# yarl
pyaml==25.1.0 pyaml==25.1.0
# via llama-stack-client
pyasn1==0.4.8 pyasn1==0.4.8
pycryptodomex==3.21.0 # via
# python-jose
# rsa
pydantic==2.10.6 pydantic==2.10.6
# via
# llama-stack
# llama-stack-client
# openai
pydantic-core==2.27.2 pydantic-core==2.27.2
# via pydantic
pygments==2.19.1 pygments==2.19.1
# via rich
python-dateutil==2.9.0.post0 python-dateutil==2.9.0.post0
# via pandas
python-dotenv==1.0.1 python-dotenv==1.0.1
# via llama-stack
python-jose==3.4.0 python-jose==3.4.0
# via llama-stack
pytz==2025.1 pytz==2025.1
# via pandas
pyyaml==6.0.2 pyyaml==6.0.2
# via
# huggingface-hub
# pyaml
referencing==0.36.2 referencing==0.36.2
# via
# jsonschema
# jsonschema-specifications
regex==2024.11.6 regex==2024.11.6
# via tiktoken
requests==2.32.3 requests==2.32.3
# via
# huggingface-hub
# llama-stack
# tiktoken
rich==13.9.4 rich==13.9.4
# via
# llama-stack
# llama-stack-client
rpds-py==0.22.3 rpds-py==0.22.3
# via
# jsonschema
# referencing
rsa==4.9 rsa==4.9
# via python-jose
setuptools==80.8.0 setuptools==80.8.0
# via llama-stack
six==1.17.0 six==1.17.0
# via
# ecdsa
# python-dateutil
sniffio==1.3.1 sniffio==1.3.1
# via
# anyio
# llama-stack-client
# openai
starlette==0.45.3
# via llama-stack
termcolor==2.5.0 termcolor==2.5.0
# via
# fire
# llama-stack
# llama-stack-client
tiktoken==0.9.0 tiktoken==0.9.0
# via llama-stack
tqdm==4.67.1 tqdm==4.67.1
# via
# huggingface-hub
# llama-stack-client
# openai
typing-extensions==4.12.2 typing-extensions==4.12.2
# via
# anyio
# huggingface-hub
# llama-stack-client
# multidict
# openai
# pydantic
# pydantic-core
# referencing
# rich
tzdata==2025.1 tzdata==2025.1
# via pandas
urllib3==2.3.0 urllib3==2.3.0
# via requests
wcwidth==0.2.13 wcwidth==0.2.13
# via prompt-toolkit
yarl==1.18.3
# via aiohttp

View file

@ -7,7 +7,6 @@
import concurrent.futures import concurrent.futures
import importlib import importlib
import json
import subprocess import subprocess
import sys import sys
from collections.abc import Iterable from collections.abc import Iterable
@ -108,21 +107,6 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[
return None, [] return None, []
def generate_dependencies_file(change_tracker: ChangedPathTracker):
templates_dir = REPO_ROOT / "llama_stack" / "templates"
distribution_deps = {}
for template_dir in find_template_dirs(templates_dir):
name, deps = collect_template_dependencies(template_dir)
if name:
distribution_deps[name] = deps
deps_file = REPO_ROOT / "llama_stack" / "templates" / "dependencies.json"
change_tracker.add_paths(deps_file)
with open(deps_file, "w") as f:
f.write(json.dumps(distribution_deps, indent=2) + "\n")
def main(): def main():
templates_dir = REPO_ROOT / "llama_stack" / "templates" templates_dir = REPO_ROOT / "llama_stack" / "templates"
change_tracker = ChangedPathTracker() change_tracker = ChangedPathTracker()
@ -143,8 +127,6 @@ def main():
list(executor.map(process_func, template_dirs)) list(executor.map(process_func, template_dirs))
progress.update(task, advance=len(template_dirs)) progress.update(task, advance=len(template_dirs))
generate_dependencies_file(change_tracker)
if check_for_changes(change_tracker): if check_for_changes(change_tracker):
print( print(
"Distribution template changes detected. Please commit the changes.", "Distribution template changes detected. Please commit the changes.",

View file

@ -10,10 +10,10 @@ PYTHON_VERSION=${PYTHON_VERSION:-3.10}
command -v uv >/dev/null 2>&1 || { echo >&2 "uv is required but it's not installed. Exiting."; exit 1; } command -v uv >/dev/null 2>&1 || { echo >&2 "uv is required but it's not installed. Exiting."; exit 1; }
uv python find $PYTHON_VERSION uv python find "$PYTHON_VERSION"
FOUND_PYTHON=$? FOUND_PYTHON=$?
if [ $FOUND_PYTHON -ne 0 ]; then if [ $FOUND_PYTHON -ne 0 ]; then
uv python install $PYTHON_VERSION uv python install "$PYTHON_VERSION"
fi fi
uv run --python $PYTHON_VERSION --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest --asyncio-mode=auto -s -v tests/unit/ $@ uv run --python "$PYTHON_VERSION" --with-editable . --group unit pytest --asyncio-mode=auto -s -v tests/unit/ $@

View file

@ -6,7 +6,6 @@ dependencies = [
"aiohttp", "aiohttp",
"aiosqlite", "aiosqlite",
"autoevals", "autoevals",
"blobfile",
"chardet", "chardet",
"chromadb-client", "chromadb-client",
"datasets", "datasets",

View file

@ -41,7 +41,6 @@ def openai_client(client_with_models):
], ],
], ],
) )
@pytest.mark.skip(reason="Very flaky, sometimes there is a message not a function call, standard tool calling issues")
def test_responses_store(openai_client, client_with_models, text_model_id, stream, tools): def test_responses_store(openai_client, client_with_models, text_model_id, stream, tools):
if isinstance(client_with_models, LlamaStackAsLibraryClient): if isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("OpenAI responses are not supported when testing with library client yet.") pytest.skip("OpenAI responses are not supported when testing with library client yet.")
@ -68,13 +67,15 @@ def test_responses_store(openai_client, client_with_models, text_model_id, strea
for chunk in response: for chunk in response:
if response_id is None: if response_id is None:
response_id = chunk.response.id response_id = chunk.response.id
if not tools:
if chunk.type == "response.completed": if chunk.type == "response.completed":
response_id = chunk.response.id response_id = chunk.response.id
output_type = chunk.response.output[0].type
if output_type == "message":
content = chunk.response.output[0].content[0].text content = chunk.response.output[0].content[0].text
else: else:
response_id = response.id response_id = response.id
if not tools: output_type = response.output[0].type
if output_type == "message":
content = response.output[0].content[0].text content = response.output[0].content[0].text
# list responses - use the underlying HTTP client for endpoints not in SDK # list responses - use the underlying HTTP client for endpoints not in SDK
@ -87,9 +88,8 @@ def test_responses_store(openai_client, client_with_models, text_model_id, strea
retrieved_response = client.responses.retrieve(response_id) retrieved_response = client.responses.retrieve(response_id)
assert retrieved_response.id == response_id assert retrieved_response.id == response_id
assert retrieved_response.model == text_model_id assert retrieved_response.model == text_model_id
if tools: assert retrieved_response.output[0].type == output_type, retrieved_response
assert retrieved_response.output[0].type == "function_call" if output_type == "message":
else:
assert retrieved_response.output[0].content[0].text == content assert retrieved_response.output[0].content[0].text == content

View file

@ -224,6 +224,43 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
assert expected.lower() in "".join(streamed_content) assert expected.lower() in "".join(streamed_content)
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:streaming_01",
"inference:chat_completion:streaming_02",
],
)
def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
provider = provider_from_model(client_with_models, text_model_id)
if provider.provider_type == "remote::ollama":
pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.")
tc = TestCase(test_case)
question = tc["question"]
expected = tc["expected"]
response = compat_client.chat.completions.create(
model=text_model_id,
messages=[{"role": "user", "content": question}],
stream=True,
timeout=120, # Increase timeout to 2 minutes for large conversation history
n=2,
)
streamed_content = {}
for chunk in response:
for choice in chunk.choices:
if choice.delta.content:
streamed_content[choice.index] = (
streamed_content.get(choice.index, "") + choice.delta.content.lower().strip()
)
assert len(streamed_content) == 2
for i, content in streamed_content.items():
assert expected.lower() in content, f"Choice {i}: Expected {expected.lower()} in {content}"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"stream", "stream",
[ [
@ -253,6 +290,7 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
for chunk in response: for chunk in response:
if response_id is None: if response_id is None:
response_id = chunk.id response_id = chunk.id
if chunk.choices[0].delta.content:
content += chunk.choices[0].delta.content content += chunk.choices[0].delta.content
else: else:
response_id = response.id response_id = response.id
@ -263,8 +301,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
retrieved_response = client.chat.completions.retrieve(response_id) retrieved_response = client.chat.completions.retrieve(response_id)
assert retrieved_response.id == response_id assert retrieved_response.id == response_id
assert retrieved_response.input_messages[0]["content"] == message assert retrieved_response.input_messages[0]["content"] == message, retrieved_response
assert retrieved_response.choices[0].message.content == content assert retrieved_response.choices[0].message.content == content, retrieved_response
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -274,7 +312,6 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
False, False,
], ],
) )
@pytest.mark.skip(reason="Very flaky, tool calling really wacky on CI")
def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream): def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
client = openai_client client = openai_client
@ -312,7 +349,9 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
for chunk in response: for chunk in response:
if response_id is None: if response_id is None:
response_id = chunk.id response_id = chunk.id
content += chunk.choices[0].delta.content if delta := chunk.choices[0].delta:
if delta.content:
content += delta.content
else: else:
response_id = response.id response_id = response.id
content = response.choices[0].message.content content = response.choices[0].message.content
@ -323,5 +362,11 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
retrieved_response = client.chat.completions.retrieve(response_id) retrieved_response = client.chat.completions.retrieve(response_id)
assert retrieved_response.id == response_id assert retrieved_response.id == response_id
assert retrieved_response.input_messages[0]["content"] == message assert retrieved_response.input_messages[0]["content"] == message
assert retrieved_response.choices[0].message.tool_calls[0].function.name == "get_weather" tool_calls = retrieved_response.choices[0].message.tool_calls
assert retrieved_response.choices[0].message.tool_calls[0].function.arguments == '{"city":"Tokyo"}' # sometimes model doesn't ouptut tool calls, but we still want to test that the tool was called
if tool_calls:
assert len(tool_calls) == 1
assert tool_calls[0].function.name == "get_weather"
assert "tokyo" in tool_calls[0].function.arguments.lower()
else:
assert retrieved_response.choices[0].message.content == content
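The new `n=2` streaming test leans on a property of OpenAI-compatible streaming: chunks for all requested completions arrive interleaved on one stream, and `choice.index` identifies which completion each delta belongs to. A minimal sketch of that accumulation pattern with the `openai` client, assuming a reachable OpenAI-compatible endpoint (the base URL, API key, and model below are placeholders):

```python
from openai import OpenAI

# Placeholder endpoint and model; any OpenAI-compatible server that supports n > 1 works.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

stream = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Name a planet."}],
    n=2,
    stream=True,
)

# Accumulate deltas per choice index; deltas for both completions arrive interleaved.
by_choice: dict[int, str] = {}
for chunk in stream:
    for choice in chunk.choices:
        if choice.delta and choice.delta.content:
            by_choice[choice.index] = by_choice.get(choice.index, "") + choice.delta.content

for index, text in sorted(by_choice.items()):
    print(f"choice {index}: {text}")
```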

View file

@@ -25,10 +25,12 @@ def test_web_search_tool(llama_stack_client, sample_search_query):
     if "TAVILY_SEARCH_API_KEY" not in os.environ:
         pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")

+    tools = llama_stack_client.tool_runtime.list_tools()
+    assert any(tool.identifier == "web_search" for tool in tools)
+
     response = llama_stack_client.tool_runtime.invoke_tool(
         tool_name="web_search", kwargs={"query": sample_search_query}
     )

     # Verify the response
     assert response.content is not None
     assert len(response.content) > 0
@@ -49,11 +51,12 @@ def test_wolfram_alpha_tool(llama_stack_client, sample_wolfram_alpha_query):
     if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
         pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")

+    tools = llama_stack_client.tool_runtime.list_tools()
+    assert any(tool.identifier == "wolfram_alpha" for tool in tools)
+
     response = llama_stack_client.tool_runtime.invoke_tool(
         tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
     )
-    print(response.content)
     assert response.content is not None
     assert len(response.content) > 0
     assert isinstance(response.content, str)

View file

@@ -31,8 +31,7 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
     test_toolgroup_id = MCP_TOOLGROUP_ID
     uri = mcp_server["server_url"]

-    # registering itself should fail since it requires listing tools
-    with pytest.raises(Exception, match="Unauthorized"):
-        llama_stack_client.toolgroups.register(
-            toolgroup_id=test_toolgroup_id,
-            provider_id="model-context-protocol",
+    # registering should not raise an error anymore even if you don't specify the auth token
+    llama_stack_client.toolgroups.register(
+        toolgroup_id=test_toolgroup_id,
+        provider_id="model-context-protocol",
@@ -41,27 +40,18 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
     provider_data = {
         "mcp_headers": {
-            uri: [
-                f"Authorization: Bearer {AUTH_TOKEN}",
-            ],
+            uri: {
+                "Authorization": f"Bearer {AUTH_TOKEN}",
+            },
         },
     }
     auth_headers = {
         "X-LlamaStack-Provider-Data": json.dumps(provider_data),
     }

-    try:
-        llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers)
-    except Exception as e:
-        # An error is OK since the toolgroup may not exist
-        print(f"Error unregistering toolgroup: {e}")
-
-    llama_stack_client.toolgroups.register(
-        toolgroup_id=test_toolgroup_id,
-        provider_id="model-context-protocol",
-        mcp_endpoint=dict(uri=uri),
-        extra_headers=auth_headers,
-    )
+    with pytest.raises(Exception, match="Unauthorized"):
+        llama_stack_client.tools.list()
+
     response = llama_stack_client.tools.list(
         toolgroup_id=test_toolgroup_id,
         extra_headers=auth_headers,
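The `mcp_headers` entry in the per-request provider data is now keyed by MCP server URL and holds a plain header dictionary, rather than a list of pre-formatted `"Name: value"` strings. A short sketch of assembling that request header under the new shape (the server URL and token below are placeholders):

```python
import json

MCP_SERVER_URL = "http://localhost:8000/sse"  # placeholder
AUTH_TOKEN = "test-token"  # placeholder

# Per-URL header dict; the tool runtime forwards these headers to the matching MCP server.
provider_data = {
    "mcp_headers": {
        MCP_SERVER_URL: {
            "Authorization": f"Bearer {AUTH_TOKEN}",
        },
    },
}

# The whole structure travels as a single JSON-encoded request header.
extra_headers = {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
```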

View file

@@ -51,7 +51,5 @@ def test_register_and_unregister_toolgroup(llama_stack_client):
     with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
         llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)

-    # Verify tools are also unregistered
-    unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
-    assert isinstance(unregister_tools_list_response, list)
-    assert not unregister_tools_list_response
+    with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
+        llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)

View file

@@ -15,7 +15,7 @@ from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
 from llama_stack.apis.datatypes import Api
 from llama_stack.apis.models.models import Model, ModelType
 from llama_stack.apis.shields.shields import Shield
-from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter
+from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
 from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
 from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
 from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
@@ -101,11 +101,11 @@ class ToolGroupsImpl(Impl):
     def __init__(self):
         super().__init__(Api.tool_runtime)

-    async def register_tool(self, tool):
-        return tool
+    async def register_toolgroup(self, toolgroup: ToolGroup):
+        return toolgroup

-    async def unregister_tool(self, tool_name: str):
-        return tool_name
+    async def unregister_toolgroup(self, toolgroup_id: str):
+        return toolgroup_id

     async def list_runtime_tools(self, toolgroup_id, mcp_endpoint):
         return ListToolDefsResponse(

View file

@@ -232,9 +232,17 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_impl, mock_inference_api):
     # Check that we got the content from our mocked tool execution result
     chunks = [chunk async for chunk in result]
-    assert len(chunks) > 0
-    assert chunks[0].response.output[0].type == "function_call"
-    assert chunks[0].response.output[0].name == "get_weather"
+    assert len(chunks) == 2  # Should have response.created and response.completed
+
+    # Check response.created event (should have empty output)
+    assert chunks[0].type == "response.created"
+    assert len(chunks[0].response.output) == 0
+
+    # Check response.completed event (should have the tool call)
+    assert chunks[1].type == "response.completed"
+    assert len(chunks[1].response.output) == 1
+    assert chunks[1].response.output[0].type == "function_call"
+    assert chunks[1].response.output[0].name == "get_weather"


 @pytest.mark.asyncio
@@ -620,3 +628,69 @@ async def test_responses_store_list_input_items_logic():
     result = await responses_store.list_response_input_items("resp_123", limit=0, order=Order.asc)
     assert result.object == "list"
     assert len(result.data) == 0  # Should return no items
+
+
+@pytest.mark.asyncio
+async def test_store_response_uses_rehydrated_input_with_previous_response(
+    openai_responses_impl, mock_responses_store, mock_inference_api
+):
+    """Test that _store_response uses the full re-hydrated input (including previous responses)
+    rather than just the original input when previous_response_id is provided."""
+    # Setup - Create a previous response that should be included in the stored input
+    previous_response = OpenAIResponseObjectWithInput(
+        id="resp-previous-123",
+        object="response",
+        created_at=1234567890,
+        model="meta-llama/Llama-3.1-8B-Instruct",
+        status="completed",
+        input=[
+            OpenAIResponseMessage(
+                id="msg-prev-user", role="user", content=[OpenAIResponseInputMessageContentText(text="What is 2+2?")]
+            )
+        ],
+        output=[
+            OpenAIResponseMessage(
+                id="msg-prev-assistant",
+                role="assistant",
+                content=[OpenAIResponseOutputMessageContentOutputText(text="2+2 equals 4.")],
+            )
+        ],
+    )
+    mock_responses_store.get_response_object.return_value = previous_response
+
+    current_input = "Now what is 3+3?"
+    model = "meta-llama/Llama-3.1-8B-Instruct"
+    mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
+    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
+
+    # Execute - Create response with previous_response_id
+    result = await openai_responses_impl.create_openai_response(
+        input=current_input,
+        model=model,
+        previous_response_id="resp-previous-123",
+        store=True,
+    )
+
+    store_call_args = mock_responses_store.store_response_object.call_args
+    stored_input = store_call_args.kwargs["input"]
+
+    # Verify that the stored input contains the full re-hydrated conversation:
+    # 1. Previous user message
+    # 2. Previous assistant response
+    # 3. Current user message
+    assert len(stored_input) == 3
+    assert stored_input[0].role == "user"
+    assert stored_input[0].content[0].text == "What is 2+2?"
+    assert stored_input[1].role == "assistant"
+    assert stored_input[1].content[0].text == "2+2 equals 4."
+    assert stored_input[2].role == "user"
+    assert stored_input[2].content == "Now what is 3+3?"
+
+    # Verify the response itself is correct
+    assert result.model == model
+    assert result.status == "completed"
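The new unit test pins down the re-hydration rule for chained responses: when `previous_response_id` is supplied, the input stored for the new response is the previous response's input, then its output, then the new user turn. A tiny sketch of just that rule, using plain dicts instead of the server's Pydantic models:

```python
def rehydrate_input(previous_input, previous_output, new_input):
    """Stored input for a chained response: prior turns followed by the new user turn (sketch)."""
    return [*previous_input, *previous_output, {"role": "user", "content": new_input}]


stored = rehydrate_input(
    previous_input=[{"role": "user", "content": "What is 2+2?"}],
    previous_output=[{"role": "assistant", "content": "2+2 equals 4."}],
    new_input="Now what is 3+3?",
)
assert len(stored) == 3
assert stored[-1] == {"role": "user", "content": "Now what is 3+3?"}
```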

View file

@@ -27,7 +27,7 @@ export TOGETHER_API_KEY=<your_together_api_key>
 ```
 then run
 ```bash
-uv run --with-editable ".[dev]" python tests/verifications/generate_report.py --run-tests
+uv run python tests/verifications/generate_report.py --run-tests
 ```

 ## Running Tests

View file

@@ -10,17 +10,17 @@ from tests.verifications.openai_api.fixtures.fixtures import _load_all_verification_configs

 def pytest_generate_tests(metafunc):
     """Dynamically parametrize tests based on the selected provider and config."""
     if "model" in metafunc.fixturenames:
+        model = metafunc.config.getoption("model")
+        if model:
+            metafunc.parametrize("model", [model])
+            return
+
         provider = metafunc.config.getoption("provider")
         if not provider:
             print("Warning: --provider not specified. Skipping model parametrization.")
             metafunc.parametrize("model", [])
             return

-        model = metafunc.config.getoption("model")
-        if model:
-            metafunc.parametrize("model", [model])
-            return
-
         try:
             config_data = _load_all_verification_configs()
         except (OSError, FileNotFoundError) as e:

View file

@@ -77,11 +77,12 @@ test_response_image:
                     image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
       output: "llama"

+# the models are really poor at tool calling after seeing images :/
 test_response_multi_turn_image:
   test_name: test_response_multi_turn_image
   test_params:
     case:
-      - case_id: "llama_image_search"
+      - case_id: "llama_image_understanding"
         turns:
           - input:
               - role: user
@@ -91,7 +92,5 @@ test_response_multi_turn_image:
                   - type: input_image
                     image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
             output: "llama"
-          - input: "Search the web using the search tool for the animal from the previous response. Your search query should be a single phrase that includes the animal's name and the words 'maverick', 'scout' and 'llm'"
-            tools:
-              - type: web_search
-            output: "model"
+          - input: "What country do you find this animal primarily in? What continent?"
+            output: "peru"

Some files were not shown because too many files have changed in this diff.