feat: small ollama package

Raghotham Murthy 2025-05-28 21:13:48 -07:00
commit 2d5d05a2b4
103 changed files with 7262 additions and 7422 deletions


@ -1,10 +1,8 @@
# What does this PR do?
[Provide a short summary of what this PR does and why. Link to relevant issues if applicable.]
<!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. -->
[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])
<!-- If resolving an issue, uncomment and update the line below -->
<!-- Closes #[issue-number] -->
## Test Plan
[Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*]
[//]: # (## Documentation)
<!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* -->


@ -13,7 +13,7 @@ runs:
- name: Install dependencies
shell: bash
run: |
uv sync --all-extras
uv sync --all-groups
uv pip install ollama faiss-cpu
# always test against the latest version of the client
# TODO: this is not necessarily a good idea. we need to test against both published and latest


@ -53,7 +53,7 @@ repos:
- black==24.3.0
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.6.3
rev: 0.7.8
hooks:
- id: uv-lock
- id: uv-export
@ -61,6 +61,7 @@ repos:
"--frozen",
"--no-hashes",
"--no-emit-project",
"--no-default-groups",
"--output-file=requirements.txt"
]
@ -88,8 +89,8 @@ repos:
- id: distro-codegen
name: Distribution Template Codegen
additional_dependencies:
- uv==0.6.0
entry: uv run --extra codegen ./scripts/distro_codegen.py
- uv==0.7.8
entry: uv run --group codegen ./scripts/distro_codegen.py
language: python
pass_filenames: false
require_serial: true
@ -97,8 +98,8 @@ repos:
- id: openapi-codegen
name: API Spec Codegen
additional_dependencies:
- uv==0.6.2
entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
- uv==0.7.8
entry: sh -c 'uv run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
language: python
pass_filenames: false
require_serial: true


@ -5,28 +5,21 @@
# Required
version: 2
# Build documentation in the "docs/" directory with Sphinx
sphinx:
configuration: docs/source/conf.py
# Set the OS, Python version and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.12"
# You can also specify other tool versions:
# nodejs: "19"
# rust: "1.64"
# golang: "1.19"
# Build documentation in the "docs/" directory with Sphinx
sphinx:
configuration: docs/source/conf.py
# Optionally build your docs in additional formats such as PDF and ePub
# formats:
# - pdf
# - epub
# Optional but recommended, declare the Python requirements required
# to build your documentation
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
python:
jobs:
pre_create_environment:
- asdf plugin add uv
- asdf install uv latest
- asdf global uv latest
create_environment:
- uv venv "${READTHEDOCS_VIRTUALENV_PATH}"
install:
- requirements: docs/requirements.txt
- UV_PROJECT_ENVIRONMENT="${READTHEDOCS_VIRTUALENV_PATH}" uv sync --frozen --group docs


@ -168,10 +168,10 @@ If you are making changes to the documentation at [https://llama-stack.readthedo
```bash
# This rebuilds the documentation pages.
uv run --with ".[docs]" make -C docs/ html
uv run --group docs make -C docs/ html
# This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
uv run --with ".[docs]" sphinx-autobuild docs/source docs/build/html --write-all
uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all
```
### Update API Documentation
@ -179,7 +179,7 @@ uv run --with ".[docs]" sphinx-autobuild docs/source docs/build/html --write-all
If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following command:
```bash
uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh
uv run ./docs/openapi_generator/run_openapi_generator.sh
```
The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.


@ -1,5 +1,4 @@
include pyproject.toml
include llama_stack/templates/dependencies.json
include llama_stack/models/llama/llama3/tokenizer.model
include llama_stack/models/llama/llama4/tokenizer.model
include llama_stack/distribution/*.sh


@ -107,26 +107,29 @@ By reducing friction and complexity, Llama Stack empowers developers to focus on
### API Providers
Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack.
| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|
| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ |
| SambaNova | Hosted | | ✅ | | ✅ | |
| Cerebras | Hosted | | ✅ | | | |
| Fireworks | Hosted | ✅ | ✅ | ✅ | | |
| AWS Bedrock | Hosted | | ✅ | | ✅ | |
| Together | Hosted | ✅ | ✅ | | ✅ | |
| Groq | Hosted | | ✅ | | | |
| Ollama | Single Node | | ✅ | | | |
| TGI | Hosted and Single Node | | ✅ | | | |
| NVIDIA NIM | Hosted and Single Node | | ✅ | | | |
| Chroma | Single Node | | | ✅ | | |
| PG Vector | Single Node | | | ✅ | | |
| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
| vLLM | Hosted and Single Node | | ✅ | | | |
| OpenAI | Hosted | | ✅ | | | |
| Anthropic | Hosted | | ✅ | | | |
| Gemini | Hosted | | ✅ | | | |
| watsonx | Hosted | | ✅ | | | |
| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | **Post Training** |
|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|:-----------------:|
| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ | |
| SambaNova | Hosted | | ✅ | | ✅ | | |
| Cerebras | Hosted | | ✅ | | | | |
| Fireworks | Hosted | ✅ | ✅ | ✅ | | | |
| AWS Bedrock | Hosted | | ✅ | | ✅ | | |
| Together | Hosted | ✅ | ✅ | | ✅ | | |
| Groq | Hosted | | ✅ | | | | |
| Ollama | Single Node | | ✅ | | | | |
| TGI | Hosted and Single Node | | ✅ | | | | |
| NVIDIA NIM | Hosted and Single Node | | ✅ | | | | |
| Chroma | Single Node | | | ✅ | | | |
| PG Vector | Single Node | | | ✅ | | | |
| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | | |
| vLLM | Hosted and Single Node | | ✅ | | | | |
| OpenAI | Hosted | | ✅ | | | | |
| Anthropic | Hosted | | ✅ | | | | |
| Gemini | Hosted | | ✅ | | | | |
| watsonx | Hosted | | ✅ | | | | |
| HuggingFace | Single Node | | | | | | ✅ |
| TorchTune | Single Node | | | | | | ✅ |
| NVIDIA NEMO | Hosted | | | | | | ✅ |
### Distributions


@ -7540,6 +7540,9 @@
{
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated"
},
{
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta"
},
{
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
}
@ -7548,6 +7551,7 @@
"propertyName": "type",
"mapping": {
"response.created": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated",
"response.output_text.delta": "#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta",
"response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
}
}
@ -7590,6 +7594,41 @@
],
"title": "OpenAIResponseObjectStreamResponseCreated"
},
"OpenAIResponseObjectStreamResponseOutputTextDelta": {
"type": "object",
"properties": {
"content_index": {
"type": "integer"
},
"delta": {
"type": "string"
},
"item_id": {
"type": "string"
},
"output_index": {
"type": "integer"
},
"sequence_number": {
"type": "integer"
},
"type": {
"type": "string",
"const": "response.output_text.delta",
"default": "response.output_text.delta"
}
},
"additionalProperties": false,
"required": [
"content_index",
"delta",
"item_id",
"output_index",
"sequence_number",
"type"
],
"title": "OpenAIResponseObjectStreamResponseOutputTextDelta"
},
"CreateUploadSessionRequest": {
"type": "object",
"properties": {
@ -9555,9 +9594,6 @@
"toolgroup_id": {
"type": "string"
},
"tool_host": {
"$ref": "#/components/schemas/ToolHost"
},
"description": {
"type": "string"
},
@ -9599,21 +9635,11 @@
"provider_id",
"type",
"toolgroup_id",
"tool_host",
"description",
"parameters"
],
"title": "Tool"
},
"ToolHost": {
"type": "string",
"enum": [
"distribution",
"client",
"model_context_protocol"
],
"title": "ToolHost"
},
"ToolGroup": {
"type": "object",
"properties": {


@ -5294,11 +5294,13 @@ components:
OpenAIResponseObjectStream:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta'
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
discriminator:
propertyName: type
mapping:
response.created: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
response.output_text.delta: '#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta'
response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
"OpenAIResponseObjectStreamResponseCompleted":
type: object
@ -5330,6 +5332,33 @@ components:
- type
title: >-
OpenAIResponseObjectStreamResponseCreated
"OpenAIResponseObjectStreamResponseOutputTextDelta":
type: object
properties:
content_index:
type: integer
delta:
type: string
item_id:
type: string
output_index:
type: integer
sequence_number:
type: integer
type:
type: string
const: response.output_text.delta
default: response.output_text.delta
additionalProperties: false
required:
- content_index
- delta
- item_id
- output_index
- sequence_number
- type
title: >-
OpenAIResponseObjectStreamResponseOutputTextDelta
CreateUploadSessionRequest:
type: object
properties:
@ -6713,8 +6742,6 @@ components:
default: tool
toolgroup_id:
type: string
tool_host:
$ref: '#/components/schemas/ToolHost'
description:
type: string
parameters:
@ -6737,17 +6764,9 @@ components:
- provider_id
- type
- toolgroup_id
- tool_host
- description
- parameters
title: Tool
ToolHost:
type: string
enum:
- distribution
- client
- model_context_protocol
title: ToolHost
ToolGroup:
type: object
properties:

File diff suppressed because it is too large.


@ -6,7 +6,7 @@ Here's a collection of comprehensive guides, examples, and resources for buildin
From the llama-stack root directory, run the following command to render the docs locally:
```bash
uv run --with ".[docs]" sphinx-autobuild docs/source docs/build/html --write-all
uv run --group docs sphinx-autobuild docs/source docs/build/html --write-all
```
You can open up the docs in your browser at http://localhost:8000


@ -30,6 +30,18 @@ Runs inference with an LLM.
## Post Training
Fine-tunes a model.
#### Post Training Providers
The following providers are available for Post Training:
```{toctree}
:maxdepth: 1
external
post_training/huggingface
post_training/torchtune
post_training/nvidia_nemo
```
## Safety
Applies safety policies to the output at a Systems (not only model) level.


@ -0,0 +1,122 @@
---
orphan: true
---
# HuggingFace SFTTrainer
[HuggingFace SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) is an inline post training provider for Llama Stack. It allows you to run supervised fine-tuning on a variety of models using many datasets.
## Features
- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU support, CPU support, and MPS support (macOS Metal Performance Shaders)
## Usage
To use the HF SFTTrainer in your Llama Stack project, follow these steps:
1. Configure your Llama Stack project to use this provider.
2. Kick off an SFT job using the Llama Stack post_training API.
## Setup
You can access the HuggingFace trainer via the `ollama` distribution:
```bash
llama stack build --template ollama --image-type venv
llama stack run --image-type venv ~/.llama/distributions/ollama/ollama-run.yaml
```
## Run Training
You can access the provider and the `supervised_fine_tune` method via the post_training API:
```python
import time
import uuid
from llama_stack_client.types import (
post_training_supervised_fine_tune_params,
algorithm_config_param,
)
def create_http_client():
from llama_stack_client import LlamaStackClient
return LlamaStackClient(base_url="http://localhost:8321")
client = create_http_client()
# Example Dataset
client.datasets.register(
purpose="post-training/messages",
source={
"type": "uri",
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
dataset_id="simpleqa",
)
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
batch_size=32,
data_format="instruct",
dataset_id="simpleqa",
shuffle=True,
),
gradient_accumulation_steps=1,
max_steps_per_epoch=0,
max_validation_steps=1,
n_epochs=4,
)
algorithm_config = algorithm_config_param.LoraFinetuningConfig( # this config is also currently mandatory but should not be
alpha=1,
apply_lora_to_mlp=True,
apply_lora_to_output=False,
lora_attn_modules=["q_proj"],
rank=1,
type="LoRA",
)
job_uuid = f"test-job{uuid.uuid4()}"
# Example Model
training_model = "ibm-granite/granite-3.3-8b-instruct"
start_time = time.time()
response = client.post_training.supervised_fine_tune(
job_uuid=job_uuid,
logger_config={},
model=training_model,
hyperparam_search_config={},
training_config=training_config,
algorithm_config=algorithm_config,
checkpoint_dir="output",
)
print("Job: ", job_uuid)
# Wait for the job to complete!
while True:
status = client.post_training.job.status(job_uuid=job_uuid)
if not status:
print("Job not found")
break
print(status)
if status.status == "completed":
break
print("Waiting for job to complete...")
time.sleep(5)
end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")
print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```


@ -0,0 +1,163 @@
---
orphan: true
---
# NVIDIA NEMO
[NVIDIA NEMO](https://developer.nvidia.com/nemo-framework) is a remote post training provider for Llama Stack. It provides enterprise-grade fine-tuning capabilities through NVIDIA's NeMo Customizer service.
## Features
- Enterprise-grade fine-tuning capabilities
- Support for LoRA and SFT fine-tuning
- Integration with NVIDIA's NeMo Customizer service
- Support for various NVIDIA-optimized models
- Efficient training with NVIDIA hardware acceleration
## Usage
To use NVIDIA NEMO in your Llama Stack project, follow these steps:
1. Configure your Llama Stack project to use this provider.
2. Set up your NVIDIA API credentials.
3. Kick off a fine-tuning job using the Llama Stack post_training API.
## Setup
You'll need to set the following environment variables:
```bash
export NVIDIA_API_KEY="your-api-key"
export NVIDIA_DATASET_NAMESPACE="default"
export NVIDIA_CUSTOMIZER_URL="your-customizer-url"
export NVIDIA_PROJECT_ID="your-project-id"
export NVIDIA_OUTPUT_MODEL_DIR="your-output-model-dir"
```
## Run Training
You can access the provider and the `supervised_fine_tune` method via the post_training API:
```python
import time
import uuid
from llama_stack_client.types import (
post_training_supervised_fine_tune_params,
algorithm_config_param,
)
def create_http_client():
from llama_stack_client import LlamaStackClient
return LlamaStackClient(base_url="http://localhost:8321")
client = create_http_client()
# Example Dataset
client.datasets.register(
purpose="post-training/messages",
source={
"type": "uri",
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
dataset_id="simpleqa",
)
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
batch_size=8, # Default batch size for NEMO
data_format="instruct",
dataset_id="simpleqa",
shuffle=True,
),
n_epochs=50, # Default epochs for NEMO
optimizer_config=post_training_supervised_fine_tune_params.TrainingConfigOptimizerConfig(
lr=0.0001, # Default learning rate
weight_decay=0.01, # NEMO-specific parameter
),
# NEMO-specific parameters
log_every_n_steps=None,
val_check_interval=0.25,
sequence_packing_enabled=False,
hidden_dropout=None,
attention_dropout=None,
ffn_dropout=None,
)
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
alpha=16, # Default alpha for NEMO
type="LoRA",
)
job_uuid = f"test-job{uuid.uuid4()}"
# Example Model - must be a supported NEMO model
training_model = "meta/llama-3.1-8b-instruct"
start_time = time.time()
response = client.post_training.supervised_fine_tune(
job_uuid=job_uuid,
logger_config={},
model=training_model,
hyperparam_search_config={},
training_config=training_config,
algorithm_config=algorithm_config,
checkpoint_dir="output",
)
print("Job: ", job_uuid)
# Wait for the job to complete!
while True:
status = client.post_training.job.status(job_uuid=job_uuid)
if not status:
print("Job not found")
break
print(status)
if status.status == "completed":
break
print("Waiting for job to complete...")
time.sleep(5)
end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")
print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```
## Supported Models
Currently supports the following models:
- meta/llama-3.1-8b-instruct
- meta/llama-3.2-1b-instruct
## Supported Parameters
### TrainingConfig
- n_epochs (default: 50)
- data_config
- optimizer_config
- log_every_n_steps
- val_check_interval (default: 0.25)
- sequence_packing_enabled (default: False)
- hidden_dropout (0.0-1.0)
- attention_dropout (0.0-1.0)
- ffn_dropout (0.0-1.0)
### DataConfig
- dataset_id
- batch_size (default: 8)
### OptimizerConfig
- lr (default: 0.0001)
- weight_decay (default: 0.01)
### LoRA Config
- alpha (default: 16)
- type (must be "LoRA")
Note: Some parameters from the standard Llama Stack API are not supported and will be ignored with a warning.
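For reference, here is a minimal sketch that pulls the defaults listed above into a single configuration. It reuses only names from the example earlier on this page (the `simpleqa` dataset and the `llama_stack_client` parameter types) and is illustrative rather than a complete `supervised_fine_tune` call; see the full example above for that.
```python
from llama_stack_client.types import (
    post_training_supervised_fine_tune_params,
    algorithm_config_param,
)

# NEMO-oriented defaults taken from the parameter list above
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
    data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
        dataset_id="simpleqa",   # dataset registered in the example above
        batch_size=8,            # DataConfig default
        data_format="instruct",
        shuffle=True,
    ),
    n_epochs=50,                 # TrainingConfig default
    optimizer_config=post_training_supervised_fine_tune_params.TrainingConfigOptimizerConfig(
        lr=0.0001,               # OptimizerConfig default
        weight_decay=0.01,       # NEMO-specific parameter
    ),
    val_check_interval=0.25,     # NEMO-specific default
    sequence_packing_enabled=False,
)

algorithm_config = algorithm_config_param.LoraFinetuningConfig(
    alpha=16,                    # LoRA default for NEMO
    type="LoRA",
)
```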


@ -0,0 +1,125 @@
---
orphan: true
---
# TorchTune
[TorchTune](https://github.com/pytorch/torchtune) is an inline post training provider for Llama Stack. It provides a simple and efficient way to fine-tune language models using PyTorch.
## Features
- Simple access through the post_training API
- Fully integrated with Llama Stack
- GPU support and single-device capabilities
- Support for LoRA
## Usage
To use TorchTune in your Llama Stack project, follow these steps:
1. Configure your Llama Stack project to use this provider.
2. Kick off a fine-tuning job using the Llama Stack post_training API.
## Setup
You can access the TorchTune trainer by writing your own YAML config pointing to the provider:
```yaml
post_training:
- provider_id: torchtune
provider_type: inline::torchtune
config: {}
```
You can then build and run your own stack with this provider.
## Run Training
You can access the provider and the `supervised_fine_tune` method via the post_training API:
```python
import time
import uuid
from llama_stack_client.types import (
post_training_supervised_fine_tune_params,
algorithm_config_param,
)
def create_http_client():
from llama_stack_client import LlamaStackClient
return LlamaStackClient(base_url="http://localhost:8321")
client = create_http_client()
# Example Dataset
client.datasets.register(
purpose="post-training/messages",
source={
"type": "uri",
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
dataset_id="simpleqa",
)
training_config = post_training_supervised_fine_tune_params.TrainingConfig(
data_config=post_training_supervised_fine_tune_params.TrainingConfigDataConfig(
batch_size=32,
data_format="instruct",
dataset_id="simpleqa",
shuffle=True,
),
gradient_accumulation_steps=1,
max_steps_per_epoch=0,
max_validation_steps=1,
n_epochs=4,
)
algorithm_config = algorithm_config_param.LoraFinetuningConfig(
alpha=1,
apply_lora_to_mlp=True,
apply_lora_to_output=False,
lora_attn_modules=["q_proj"],
rank=1,
type="LoRA",
)
job_uuid = f"test-job{uuid.uuid4()}"
# Example Model
training_model = "meta-llama/Llama-2-7b-hf"
start_time = time.time()
response = client.post_training.supervised_fine_tune(
job_uuid=job_uuid,
logger_config={},
model=training_model,
hyperparam_search_config={},
training_config=training_config,
algorithm_config=algorithm_config,
checkpoint_dir="output",
)
print("Job: ", job_uuid)
# Wait for the job to complete!
while True:
status = client.post_training.job.status(job_uuid=job_uuid)
if not status:
print("Job not found")
break
print(status)
if status.status == "completed":
break
print("Waiting for job to complete...")
time.sleep(5)
end_time = time.time()
print("Job completed in", end_time - start_time, "seconds!")
print("Artifacts:")
print(client.post_training.job.artifacts(job_uuid=job_uuid))
```


@ -149,6 +149,16 @@ class OpenAIResponseObjectStreamResponseCreated(BaseModel):
type: Literal["response.created"] = "response.created"
@json_schema_type
class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
content_index: int
delta: str
item_id: str
output_index: int
sequence_number: int
type: Literal["response.output_text.delta"] = "response.output_text.delta"
@json_schema_type
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
response: OpenAIResponseObject
@ -156,7 +166,9 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
OpenAIResponseObjectStream = Annotated[
OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated
| OpenAIResponseObjectStreamResponseOutputTextDelta
| OpenAIResponseObjectStreamResponseCompleted,
Field(discriminator="type"),
]
register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")


@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any
from pydantic import BaseModel
from llama_stack.apis.common.content_types import URL
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class RestAPIMethod(Enum):
GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
@json_schema_type
class RestAPIExecutionConfig(BaseModel):
url: URL
method: RestAPIMethod
params: dict[str, Any] | None = None
headers: dict[str, Any] | None = None
body: dict[str, Any] | None = None


@ -27,18 +27,10 @@ class ToolParameter(BaseModel):
default: Any | None = None
@json_schema_type
class ToolHost(Enum):
distribution = "distribution"
client = "client"
model_context_protocol = "model_context_protocol"
@json_schema_type
class Tool(Resource):
type: Literal[ResourceType.tool] = ResourceType.tool
toolgroup_id: str
tool_host: ToolHost
description: str
parameters: list[ToolParameter]
metadata: dict[str, Any] | None = None


@ -267,8 +267,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
if args.run:
config_dict = yaml.safe_load(run_config.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
if not os.path.exists(config.external_providers_dir):
os.makedirs(config.external_providers_dir, exist_ok=True)
if config.external_providers_dir and not config.external_providers_dir.exists():
config.external_providers_dir.mkdir(exist_ok=True)
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
run_command(run_args)


@ -125,7 +125,6 @@ RUN apt-get update && apt-get install -y \
curl wget telnet git\
procps psmisc lsof \
traceroute \
bubblewrap \
gcc \
&& rm -rf /var/lib/apt/lists/*


@ -16,7 +16,7 @@ from llama_stack.apis.inspect import (
VersionInfo,
)
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.server.routes import get_all_api_routes
from llama_stack.providers.datatypes import HealthStatus
@ -42,15 +42,15 @@ class DistributionInspectImpl(Inspect):
run_config: StackRunConfig = self.config.run_config
ret = []
all_endpoints = get_all_api_endpoints()
all_endpoints = get_all_api_routes()
for api, endpoints in all_endpoints.items():
# Always include provider and inspect APIs, filter others based on run config
if api.value in ["providers", "inspect"]:
ret.extend(
[
RouteInfo(
route=e.route,
method=e.method,
route=e.path,
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
)
for e in endpoints
@ -62,8 +62,8 @@ class DistributionInspectImpl(Inspect):
ret.extend(
[
RouteInfo(
route=e.route,
method=e.method,
route=e.path,
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[p.provider_type for p in providers],
)
for e in endpoints


@ -37,10 +37,7 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context,
)
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
)
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
from llama_stack.distribution.stack import (
construct_stack,
get_stack_run_config_from_template,
@ -208,7 +205,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
async def initialize(self) -> bool:
try:
self.endpoint_impls = None
self.route_impls = None
self.impls = await construct_stack(self.config, self.custom_provider_registry)
except ModuleNotFoundError as _e:
cprint(_e.msg, color="red", file=sys.stderr)
@ -254,7 +251,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2))
self.endpoint_impls = initialize_endpoint_impls(self.impls)
self.route_impls = initialize_route_impls(self.impls)
return True
async def request(
@ -265,7 +262,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
stream=False,
stream_cls=None,
):
if not self.endpoint_impls:
if not self.route_impls:
raise ValueError("Client not initialized")
# Create headers with provider data if available
@ -296,11 +293,14 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
cast_to: Any,
options: Any,
):
if self.route_impls is None:
raise ValueError("Client not initialized")
path = options.url
body = options.params or {}
body |= options.json_data or {}
matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
await start_trace(route, {"__location__": "library_client"})
@ -342,10 +342,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
options: Any,
stream_cls: Any,
):
if self.route_impls is None:
raise ValueError("Client not initialized")
path = options.url
body = options.params or {}
body |= options.json_data or {}
func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
func, path_params, route = find_matching_route(options.method, path, self.route_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
@ -397,7 +400,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not body:
return {}
func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls)
if self.route_impls is None:
raise ValueError("Client not initialized")
func, _, _ = find_matching_route(method, path, self.route_impls)
sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature


@ -47,7 +47,7 @@ from llama_stack.providers.datatypes import (
RemoteProviderSpec,
ScoringFunctionsProtocolPrivate,
ShieldsProtocolPrivate,
ToolsProtocolPrivate,
ToolGroupsProtocolPrivate,
VectorDBsProtocolPrivate,
)
@ -93,7 +93,7 @@ def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
def additional_protocols_map() -> dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),


@ -11,7 +11,7 @@ from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.tools import (
ListToolDefsResponse,
ListToolsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
@ -19,7 +19,8 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
logger = get_logger(name=__name__, category="core")
@ -28,7 +29,7 @@ class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
self,
routing_table: RoutingTable,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
@ -59,7 +60,7 @@ class ToolRuntimeRouter(ToolRuntime):
def __init__(
self,
routing_table: RoutingTable,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
@ -86,6 +87,6 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
) -> ListToolsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
return await self.routing_table.list_tools(tool_group_id)


@ -46,7 +46,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
elif api == Api.eval:
return await p.register_benchmark(obj)
elif api == Api.tool_runtime:
return await p.register_tool(obj)
return await p.register_toolgroup(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
@ -60,7 +60,7 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.tool_runtime:
return await p.unregister_tool(obj.identifier)
return await p.unregister_toolgroup(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
@ -136,7 +136,7 @@ class CommonRoutingTableImpl(RoutingTable):
elif isinstance(self, BenchmarksRoutingTable):
return ("Eval", "benchmark")
elif isinstance(self, ToolGroupsRoutingTable):
return ("Tools", "tool")
return ("ToolGroups", "tool_group")
else:
raise ValueError("Unknown routing table type")


@ -7,11 +7,8 @@
from typing import Any
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups, ToolHost
from llama_stack.distribution.datatypes import (
ToolGroupWithACL,
ToolWithACL,
)
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
from llama_stack.distribution.datatypes import ToolGroupWithACL
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
@ -19,12 +16,70 @@ from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
# handle the funny case like "builtin::rag/knowledge_search"
parts = toolgroup_name_with_maybe_tool_name.split("/")
if len(parts) == 2:
return parts[0]
else:
return None
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
tools = await self.get_all_with_type("tool")
toolgroups_to_tools: dict[str, list[Tool]] = {}
tool_to_toolgroup: dict[str, str] = {}
# overridden
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
toolgroup_id = parse_toolgroup_from_toolgroup_name_pair(routing_key)
if toolgroup_id:
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
return ListToolsResponse(data=tools)
routing_key = toolgroup_id
if routing_key in self.tool_to_toolgroup:
routing_key = self.tool_to_toolgroup[routing_key]
return super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
if toolgroup_id:
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
toolgroup_id = group_id
toolgroups = [await self.get_tool_group(toolgroup_id)]
else:
toolgroups = await self.get_all_with_type("tool_group")
all_tools = []
for toolgroup in toolgroups:
if toolgroup.identifier not in self.toolgroups_to_tools:
await self._index_tools(toolgroup)
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
return ListToolsResponse(data=all_tools)
async def _index_tools(self, toolgroup: ToolGroup):
provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
# TODO: kill this Tool vs ToolDef distinction
tooldefs = tooldefs_response.data
tools = []
for t in tooldefs:
tools.append(
Tool(
identifier=t.name,
toolgroup_id=toolgroup.identifier,
description=t.description or "",
parameters=t.parameters or [],
metadata=t.metadata,
provider_id=toolgroup.provider_id,
)
)
self.toolgroups_to_tools[toolgroup.identifier] = tools
for tool in tools:
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
@ -36,7 +91,13 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return tool_group
async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name)
if tool_name in self.tool_to_toolgroup:
toolgroup_id = self.tool_to_toolgroup[tool_name]
tools = self.toolgroups_to_tools[toolgroup_id]
for tool in tools:
if tool.identifier == tool_name:
return tool
raise ValueError(f"Tool '{tool_name}' not found")
async def register_tool_group(
self,
@ -45,53 +106,26 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
tools = []
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
for tool_def in tool_defs.data:
tools.append(
ToolWithACL(
identifier=tool_def.name,
toolgroup_id=toolgroup_id,
description=tool_def.description or "",
parameters=tool_def.parameters or [],
provider_id=provider_id,
provider_resource_id=tool_def.name,
metadata=tool_def.metadata,
tool_host=tool_host,
)
)
for tool in tools:
existing_tool = await self.get_tool(tool.identifier)
# Compare existing and new object if one exists
if existing_tool:
existing_dict = existing_tool.model_dump()
new_dict = tool.model_dump()
if existing_dict != new_dict:
raise ValueError(
f"Object {tool.identifier} already exists in registry. Please use a different identifier."
)
await self.register_object(tool)
await self.dist_registry.register(
ToolGroupWithACL(
toolgroup = ToolGroupWithACL(
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=toolgroup_id,
mcp_endpoint=mcp_endpoint,
args=args,
)
)
await self.register_object(toolgroup)
# ideally, indexing of the tools should not be necessary because anyone using
# the tools should first list the tools and then use them. but there are assumptions
# baked in some of the code and tests right now.
if not toolgroup.mcp_endpoint:
await self._index_tools(toolgroup)
return toolgroup
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group {toolgroup_id} not found")
tools = await self.list_tools(toolgroup_id)
for tool in getattr(tools, "data", []):
await self.unregister_object(tool)
await self.unregister_object(tool_group)
async def shutdown(self) -> None:


@ -6,20 +6,23 @@
import inspect
import re
from collections.abc import Callable
from typing import Any
from pydantic import BaseModel
from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.resolver import api_protocol_map
from llama_stack.providers.datatypes import Api
class ApiEndpoint(BaseModel):
route: str
method: str
name: str
descriptive_name: str | None = None
EndpointFunc = Callable[..., Any]
PathParams = dict[str, str]
RouteInfo = tuple[EndpointFunc, str]
PathImpl = dict[str, RouteInfo]
RouteImpls = dict[str, PathImpl]
RouteMatch = tuple[EndpointFunc, PathParams, str]
def toolgroup_protocol_map():
@ -28,13 +31,13 @@ def toolgroup_protocol_map():
}
def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
def get_all_api_routes() -> dict[Api, list[Route]]:
apis = {}
protocols = api_protocol_map()
toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items():
endpoints = []
routes = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
# HACK ALERT
@ -51,26 +54,28 @@ def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
if not hasattr(method, "__webmethod__"):
continue
webmethod = method.__webmethod__
route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
if webmethod.method == "GET":
method = "get"
elif webmethod.method == "DELETE":
method = "delete"
# The __webmethod__ attribute is dynamically added by the @webmethod decorator
# mypy doesn't know about this dynamic attribute, so we ignore the attr-defined error
webmethod = method.__webmethod__ # type: ignore[attr-defined]
path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
if webmethod.method == hdrs.METH_GET:
http_method = hdrs.METH_GET
elif webmethod.method == hdrs.METH_DELETE:
http_method = hdrs.METH_DELETE
else:
method = "post"
endpoints.append(
ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name)
)
http_method = hdrs.METH_POST
routes.append(
Route(path=path, methods=[http_method], name=name, endpoint=None)
) # setting endpoint to None since we don't use a Router object
apis[api] = endpoints
apis[api] = routes
return apis
def initialize_endpoint_impls(impls):
endpoints = get_all_api_endpoints()
endpoint_impls = {}
def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
routes = get_all_api_routes()
route_impls: RouteImpls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
@ -83,29 +88,34 @@ def initialize_endpoint_impls(impls):
return f"^{pattern}$"
for api, api_endpoints in endpoints.items():
for api, api_routes in routes.items():
if api not in impls:
continue
for endpoint in api_endpoints:
for route in api_routes:
impl = impls[api]
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (
func = getattr(impl, route.name)
# Get the first (and typically only) method from the set, filtering out HEAD
available_methods = [m for m in route.methods if m != "HEAD"]
if not available_methods:
continue # Skip if only HEAD method is available
method = available_methods[0].lower()
if method not in route_impls:
route_impls[method] = {}
route_impls[method][_convert_path_to_regex(route.path)] = (
func,
endpoint.descriptive_name or endpoint.route,
route.path,
)
return endpoint_impls
return route_impls
def find_matching_endpoint(method, path, endpoint_impls):
def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch:
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
endpoint_impls: A dictionary of endpoint implementations
route_impls: A dictionary of endpoint implementations
Returns:
A tuple of (endpoint_function, path_params, descriptive_name)
@ -113,7 +123,7 @@ def find_matching_endpoint(method, path, endpoint_impls):
Raises:
ValueError: If no matching endpoint is found
"""
impls = endpoint_impls.get(method.lower())
impls = route_impls.get(method.lower())
if not impls:
raise ValueError(f"No endpoint found for {path}")


@ -6,6 +6,7 @@
import argparse
import asyncio
import functools
import inspect
import json
import os
@ -13,6 +14,7 @@ import ssl
import sys
import traceback
import warnings
from collections.abc import Callable
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
@ -20,6 +22,7 @@ from typing import Annotated, Any
import rich.pretty
import yaml
from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
@ -35,9 +38,10 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context,
)
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
from llama_stack.distribution.server.routes import (
find_matching_route,
get_all_api_routes,
initialize_route_impls,
)
from llama_stack.distribution.stack import (
construct_stack,
@ -60,7 +64,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
)
from .auth import AuthenticationMiddleware
from .endpoints import get_all_api_endpoints
from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -209,8 +212,9 @@ async def log_request_pre_validation(request: Request):
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs):
def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
@functools.wraps(func)
async def route_handler(request: Request, **kwargs):
# Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {})
@ -250,9 +254,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
for param in new_params[1:]
]
endpoint.__signature__ = sig.replace(parameters=new_params)
route_handler.__signature__ = sig.replace(parameters=new_params)
return endpoint
return route_handler
class TracingMiddleware:
@ -274,14 +278,14 @@ class TracingMiddleware:
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
return await self.app(scope, receive, send)
if not hasattr(self, "endpoint_impls"):
self.endpoint_impls = initialize_endpoint_impls(self.impls)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
try:
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
_, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
return await self.app(scope, receive, send)
trace_attributes = {"__location__": "server", "raw_path": path}
@ -423,7 +427,7 @@ def main(args: argparse.Namespace | None = None):
logger.info("Run configuration:")
safe_config = redact_sensitive_fields(config.model_dump())
logger.info(yaml.dump(safe_config, indent=2))
logger.info(yaml.dump(safe_config, indent=2, default_style=None))
app = FastAPI(
lifespan=lifespan,
@ -490,7 +494,7 @@ def main(args: argparse.Namespace | None = None):
else:
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
all_endpoints = get_all_api_endpoints()
all_routes = get_all_api_routes()
if config.apis:
apis_to_serve = set(config.apis)
@ -508,24 +512,29 @@ def main(args: argparse.Namespace | None = None):
for api_str in apis_to_serve:
api = Api(api_str)
endpoints = all_endpoints[api]
routes = all_routes[api]
impl = impls[api]
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
for route in routes:
if not hasattr(impl, route.name):
# ideally this should be a typing violation already
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
raise ValueError(f"Could not find method {route.name} on {impl}!")
impl_method = getattr(impl, endpoint.name)
logger.debug(f"{endpoint.method.upper()} {endpoint.route}")
impl_method = getattr(impl, route.name)
# Filter out HEAD method since it's automatically handled by FastAPI for GET routes
available_methods = [m for m in route.methods if m != "HEAD"]
if not available_methods:
raise ValueError(f"No methods found for {route.name} on {impl}")
method = available_methods[0]
logger.debug(f"{method} {route.path}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
getattr(app, method.lower())(route.path, response_model=None)(
create_dynamic_typed_route(
impl_method,
endpoint.method,
endpoint.route,
method.lower(),
route.path,
)
)


@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v8"
KEY_VERSION = "v9"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"


@ -16,7 +16,7 @@ from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import Tool
from llama_stack.apis.tools import ToolGroup
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.schema_utils import json_schema_type
@ -74,10 +74,10 @@ class BenchmarksProtocolPrivate(Protocol):
async def register_benchmark(self, benchmark: Benchmark) -> None: ...
class ToolsProtocolPrivate(Protocol):
async def register_tool(self, tool: Tool) -> None: ...
class ToolGroupsProtocolPrivate(Protocol):
async def register_toolgroup(self, toolgroup: ToolGroup) -> None: ...
async def unregister_tool(self, tool_id: str) -> None: ...
async def unregister_toolgroup(self, toolgroup_id: str) -> None: ...
@json_schema_type


@ -5,6 +5,7 @@
# the root directory of this source tree.
import json
import time
import uuid
from collections.abc import AsyncIterator
from typing import Any, cast
@ -29,10 +30,12 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseOutputTextDelta,
OpenAIResponseOutput,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.inference.inference import (
@ -255,110 +258,14 @@ class OpenAIResponsesImpl:
"""
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
async def create_openai_response(
async def _process_response_choices(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
):
chat_response: OpenAIChatCompletion,
ctx: ChatCompletionContext,
tools: list[OpenAIResponseInputTool] | None,
) -> list[OpenAIResponseOutput]:
"""Handle tool execution and response message creation."""
output_messages: list[OpenAIResponseOutput] = []
stream = False if stream is None else stream
# Huge TODO: we need to run this in a loop, until morale improves
# Create context to run "chat completion"
input = await self._prepend_previous_response(input, previous_response_id)
messages = await _convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
chat_tools, mcp_tool_to_server, mcp_list_message = (
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
)
if mcp_list_message:
output_messages.append(mcp_list_message)
ctx = ChatCompletionContext(
model=model,
messages=messages,
tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
stream=stream,
temperature=temperature,
)
# Run inference
chat_response = await self.inference_api.openai_chat_completion(
model=model,
messages=messages,
tools=chat_tools,
stream=stream,
temperature=temperature,
)
# Collect output
if stream:
# TODO: refactor this into a separate method that handles streaming
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
# TODO: these chunk_ fields are hacky and only take the last chunk into account
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
async for chunk in chat_response:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# TODO: this only works for text content
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks, using their index as the aggregation key
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
response_tool_call.function.arguments += tool_call.function.arguments
else:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
# Ensure we don't have any empty type field in the tool call dict.
# The OpenAI client used by providers often returns a type=None here.
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Convert the dict of tool calls by index to a list of tool calls to pass back in our response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
)
chat_response = OpenAIChatCompletion(
id=chat_response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=chunk_finish_reason,
index=0,
)
],
created=chunk_created,
model=chunk_model,
)
else:
# dump and reload to map to our pydantic types
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
# Execute tool calls if any
for choice in chat_response.choices:
if choice.message.tool_calls and tools:
@ -380,19 +287,13 @@ class OpenAIResponsesImpl:
else:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
# Create response object
response = OpenAIResponseObject(
created_at=chat_response.created,
id=f"resp-{uuid.uuid4()}",
model=model,
object="response",
status="completed",
output=output_messages,
)
logger.debug(f"OpenAI Responses response: {response}")
return output_messages
# Store response if requested
if store:
async def _store_response(
self,
response: OpenAIResponseObject,
input: str | list[OpenAIResponseInput],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
@ -421,17 +322,233 @@ class OpenAIResponsesImpl:
input=input_items_data,
)
async def create_openai_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
):
stream = False if stream is None else stream
output_messages: list[OpenAIResponseOutput] = []
# Input preprocessing
input = await self._prepend_previous_response(input, previous_response_id)
messages = await _convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
# Tool setup
chat_tools, mcp_tool_to_server, mcp_list_message = (
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
)
if mcp_list_message:
output_messages.append(mcp_list_message)
ctx = ChatCompletionContext(
model=model,
messages=messages,
tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
stream=stream,
temperature=temperature,
)
inference_result = await self.inference_api.openai_chat_completion(
model=model,
messages=messages,
tools=chat_tools,
stream=stream,
temperature=temperature,
)
if stream:
return self._create_streaming_response(
inference_result=inference_result,
ctx=ctx,
output_messages=output_messages,
input=input,
model=model,
store=store,
tools=tools,
)
else:
return await self._create_non_streaming_response(
inference_result=inference_result,
ctx=ctx,
output_messages=output_messages,
input=input,
model=model,
store=store,
tools=tools,
)
async def async_response() -> AsyncIterator[OpenAIResponseObjectStream]:
# TODO: response created should actually get emitted much earlier in the process
yield OpenAIResponseObjectStreamResponseCreated(response=response)
yield OpenAIResponseObjectStreamResponseCompleted(response=response)
async def _create_non_streaming_response(
self,
inference_result: Any,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
tools: list[OpenAIResponseInputTool] | None,
) -> OpenAIResponseObject:
chat_response = OpenAIChatCompletion(**inference_result.model_dump())
return async_response()
# Process response choices (tool execution and message creation)
output_messages.extend(
await self._process_response_choices(
chat_response=chat_response,
ctx=ctx,
tools=tools,
)
)
response = OpenAIResponseObject(
created_at=chat_response.created,
id=f"resp-{uuid.uuid4()}",
model=model,
object="response",
status="completed",
output=output_messages,
)
logger.debug(f"OpenAI Responses response: {response}")
# Store response if requested
if store:
await self._store_response(
response=response,
input=input,
)
return response
async def _create_streaming_response(
self,
inference_result: Any,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
tools: list[OpenAIResponseInputTool] | None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Create initial response and emit response.created immediately
response_id = f"resp-{uuid.uuid4()}"
created_at = int(time.time())
initial_response = OpenAIResponseObject(
created_at=created_at,
id=response_id,
model=model,
object="response",
status="in_progress",
output=output_messages.copy(),
)
# Emit response.created immediately
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
# For streaming, inference_result is an async iterator of chunks
# Stream chunks and emit delta events as they arrive
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
sequence_number = 0
# Create a placeholder message item for delta events
message_item_id = f"msg_{uuid.uuid4()}"
async for chunk in inference_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=0,
delta=chunk_choice.delta.content,
item_id=message_item_id,
output_index=0,
sequence_number=sequence_number,
)
# Collect content for final response
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks, using their index as the aggregation key
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
response_tool_call.function.arguments += tool_call.function.arguments
else:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Convert collected chunks to complete response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
)
chat_response_obj = OpenAIChatCompletion(
id=chat_response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=chunk_finish_reason,
index=0,
)
],
created=chunk_created,
model=chunk_model,
)
# Process response choices (tool execution and message creation)
output_messages.extend(
await self._process_response_choices(
chat_response=chat_response_obj,
ctx=ctx,
tools=tools,
)
)
# Create final response
final_response = OpenAIResponseObject(
created_at=created_at,
id=response_id,
model=model,
object="response",
status="completed",
output=output_messages,
)
if store:
await self._store_response(
response=final_response,
input=input,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
async def _convert_response_tools_to_chat_tools(
self, tools: list[OpenAIResponseInputTool]
) -> tuple[
@ -441,7 +558,6 @@ class OpenAIResponsesImpl:
]:
from llama_stack.apis.agents.openai_responses import (
MCPListToolsTool,
OpenAIResponseOutputMessageMCPListTools,
)
from llama_stack.apis.tools.tools import Tool
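The streaming path above emits `response.created` first, then one `output_text.delta` event per content chunk, and finally `response.completed` once tool-call arguments have been aggregated by index and executed. A minimal consumer sketch, assuming the event classes referenced above are importable from `llama_stack.apis.agents.openai_responses` and that the public entry point is `create_openai_response` (its signature is not fully shown in these hunks):

```python
# Sketch only: `impl` is an OpenAIResponsesImpl instance; the method name
# create_openai_response and the import location of the event classes are
# assumptions, since neither is shown verbatim in the diff above.
from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseObjectStreamResponseCompleted,
    OpenAIResponseObjectStreamResponseCreated,
    OpenAIResponseObjectStreamResponseOutputTextDelta,
)


async def print_streamed_response(impl, model: str, prompt: str) -> None:
    stream = await impl.create_openai_response(input=prompt, model=model, stream=True)
    async for event in stream:
        if isinstance(event, OpenAIResponseObjectStreamResponseCreated):
            print(f"started: {event.response.id}")
        elif isinstance(event, OpenAIResponseObjectStreamResponseOutputTextDelta):
            print(event.delta, end="", flush=True)  # incremental text
        elif isinstance(event, OpenAIResponseObjectStreamResponseCompleted):
            print(f"\ncompleted with status: {event.response.status}")
```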

View file

@ -75,6 +75,8 @@ class PromptGuardShield:
self.temperature = temperature
self.threshold = threshold
self.device = "cpu"
if torch.cuda.is_available():
self.device = "cuda"
# load model and tokenizer

View file

@ -25,14 +25,14 @@ from llama_stack.apis.tools import (
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
Tool,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
content_from_doc,
@ -49,7 +49,7 @@ def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
def __init__(
self,
config: RagToolRuntimeConfig,
@ -66,10 +66,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
async def shutdown(self):
pass
async def register_tool(self, tool: Tool) -> None:
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
async def insert(
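The rename from `register_tool`/`unregister_tool` on `ToolsProtocolPrivate` to `register_toolgroup`/`unregister_toolgroup` on `ToolGroupsProtocolPrivate` repeats for the Bing, Brave, MCP, Tavily, and Wolfram runtimes further down in this commit. A minimal sketch of what a custom tool runtime implements after the rename, assuming the method signatures shown in these hunks and that `ListToolDefsResponse` carries its tool definitions in a `data` field:

```python
# Sketch under the assumptions stated above; not taken verbatim from the diff.
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ListToolDefsResponse, ToolGroup, ToolRuntime
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate


class NoopToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime):
    async def initialize(self) -> None:
        pass

    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
        # stateless runtime: nothing to persist, mirroring the built-in providers
        pass

    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        return

    async def list_runtime_tools(
        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
    ) -> ListToolDefsResponse:
        return ListToolDefsResponse(data=[])
```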

View file

@ -19,10 +19,10 @@ def available_providers() -> list[ProviderSpec]:
api=Api.agents,
provider_type="inline::meta-reference",
pip_packages=[
"matplotlib",
"pillow",
"pandas",
"scikit-learn",
# "matplotlib",
# "pillow",
# "pandas",
# "scikit-learn",
]
+ kvstore_dependencies(),
module="llama_stack.providers.inline.agents.meta_reference",

View file

@ -13,7 +13,7 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec(
api=Api.eval,
provider_type="inline::meta-reference",
pip_packages=["tree_sitter", "pythainlp", "langdetect", "emoji", "nltk"],
# pip_packages=["tree_sitter", "pythainlp", "langdetect", "emoji", "nltk"],
module="llama_stack.providers.inline.eval.meta_reference",
config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig",
api_dependencies=[

View file

@ -20,16 +20,16 @@ def available_providers() -> list[ProviderSpec]:
api=Api.tool_runtime,
provider_type="inline::rag-runtime",
pip_packages=[
"blobfile",
"chardet",
"pypdf",
"tqdm",
"numpy",
"scikit-learn",
"scipy",
"nltk",
"sentencepiece",
"transformers",
# "blobfile",
# "chardet",
# "pypdf",
# "tqdm",
# "numpy",
# "scikit-learn",
# "scipy",
# "nltk",
# "sentencepiece",
# "transformers",
],
module="llama_stack.providers.inline.tool_runtime.rag",
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",

View file

@ -4,8 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from llama_stack.schema_utils import json_schema_type
@ -24,11 +25,27 @@ class VLLMInferenceAdapterConfig(BaseModel):
default="fake",
description="The API token",
)
tls_verify: bool = Field(
tls_verify: bool | str = Field(
default=True,
description="Whether to verify TLS certificates",
description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
)
@field_validator("tls_verify")
@classmethod
def validate_tls_verify(cls, v):
if isinstance(v, str):
# Check if it's a boolean string
if v.lower() in ("true", "false"):
return v.lower() == "true"
# Otherwise, treat it as a cert path
cert_path = Path(v).expanduser().resolve()
if not cert_path.exists():
raise ValueError(f"TLS certificate file does not exist: {v}")
if not cert_path.is_file():
raise ValueError(f"TLS certificate path is not a file: {v}")
return v
return v
@classmethod
def sample_run_config(
cls,

View file

@ -313,7 +313,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
return AsyncOpenAI(
base_url=self.config.url,
api_key=self.config.api_token,
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
http_client=httpx.AsyncClient(verify=self.config.tls_verify),
)
async def completion(
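Together these two hunks let `tls_verify` stay a plain boolean or point at a CA bundle, and the adapter now hands the value to `httpx.AsyncClient(verify=...)` unchanged (httpx accepts either form). A small sketch of the accepted shapes, assuming the config class is importable from the provider's `config` module:

```python
# Sketch only: the import path of VLLMInferenceAdapterConfig is an assumption.
import tempfile

from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig

# 1) Boolean, unchanged behaviour.
insecure = VLLMInferenceAdapterConfig(url="https://vllm.example.com/v1", tls_verify=False)

# 2) The strings "true"/"false" are coerced to booleans by the new validator.
coerced = VLLMInferenceAdapterConfig(url="https://vllm.example.com/v1", tls_verify="false")
assert coerced.tls_verify is False

# 3) Any other string is treated as a CA bundle path and must point at an existing file.
with tempfile.NamedTemporaryFile(suffix=".pem") as ca_bundle:
    pinned = VLLMInferenceAdapterConfig(
        url="https://vllm.example.com/v1",
        tls_verify=ca_bundle.name,
    )
    # later handed straight through: httpx.AsyncClient(verify=pinned.tls_verify)
```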

View file

@ -12,19 +12,19 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import BingSearchToolConfig
class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: BingSearchToolConfig):
self.config = config
self.url = "https://api.bing.microsoft.com/v7.0/search"
@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
async def initialize(self):
pass
async def register_tool(self, tool: Tool) -> None:
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
def _get_api_key(self) -> str:

View file

@ -11,30 +11,30 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import BraveSearchToolConfig
class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: BraveSearchToolConfig):
self.config = config
async def initialize(self):
pass
async def register_tool(self, tool: Tool) -> None:
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
def _get_api_key(self) -> str:

View file

@ -10,8 +10,8 @@ from pydantic import BaseModel
class MCPProviderDataValidator(BaseModel):
# mcp_endpoint => list of headers to send
mcp_headers: dict[str, list[str]] | None = None
# mcp_endpoint => dict of headers to send
mcp_headers: dict[str, dict[str, str]] | None = None
class MCPProviderConfig(BaseModel):

View file

@ -11,26 +11,33 @@ from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datatypes import Api
from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolGroup,
ToolInvocationResult,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools
from .config import MCPProviderConfig
logger = get_logger(__name__, category="tools")
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
self.config = config
async def initialize(self):
pass
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
@ -62,5 +69,5 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
headers.update(convert_header_list_to_dict(values))
headers.update(values)
return headers

View file

@ -12,29 +12,29 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import TavilySearchToolConfig
class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: TavilySearchToolConfig):
self.config = config
async def initialize(self):
pass
async def register_tool(self, tool: Tool) -> None:
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
def _get_api_key(self) -> str:

View file

@ -12,19 +12,19 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import WolframAlphaToolConfig
class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: WolframAlphaToolConfig):
self.config = config
self.url = "https://api.wolframalpha.com/v2/query"
@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
async def initialize(self):
pass
async def register_tool(self, tool: Tool) -> None:
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
def _get_api_key(self) -> str:

View file

@ -1402,9 +1402,8 @@ class OpenAIChatCompletionToLlamaStackMixin:
outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
):
id = f"chatcmpl-{uuid.uuid4()}"
for outstanding_response in outstanding_responses:
for i, outstanding_response in enumerate(outstanding_responses):
response = await outstanding_response
i = 0
async for chunk in response:
event = chunk.event
finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
@ -1459,7 +1458,6 @@ class OpenAIChatCompletionToLlamaStackMixin:
model=model,
object="chat.completion.chunk",
)
i = i + 1
async def _process_non_stream_response(
self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]

View file

@ -51,16 +51,6 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
raise
def convert_header_list_to_dict(header_list: list[str]) -> dict[str, str]:
headers = {}
for header in header_list:
parts = header.split(":")
if len(parts) == 2:
k, v = parts
headers[k.strip()] = v.strip()
return headers
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
tools = []
async with sse_client_wrapper(endpoint, headers) as session:

View file

@ -1,855 +0,0 @@
{
"bedrock": [
"aiosqlite",
"autoevals",
"blobfile",
"boto3",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"cerebras": [
"aiosqlite",
"autoevals",
"blobfile",
"cerebras_cloud_sdk",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"ci-tests": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"dell": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"fireworks": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"groq": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"hf-endpoint": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"hf-serverless": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"llama_api": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"meta-reference-gpu": [
"accelerate",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fairscale",
"faiss-cpu",
"fastapi",
"fbgemm-gpu-genai==1.1.2",
"fire",
"httpx",
"langdetect",
"lm-format-enforcer",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentence-transformers",
"sentencepiece",
"sqlalchemy[asyncio]",
"torch",
"torchao==0.8.0",
"torchvision",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"zmq"
],
"nvidia": [
"aiohttp",
"aiosqlite",
"blobfile",
"chardet",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"uvicorn"
],
"ollama": [
"aiohttp",
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"ollama",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"peft",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"requests",
"sqlalchemy[asyncio]",
"tqdm",
"tree_sitter",
"trl",
"uvicorn"
],
"open-benchmark": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"sqlite-vec",
"together",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"passthrough": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"remote-vllm": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"sambanova": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"starter": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"tgi": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"together": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"together",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"verification": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"vllm-gpu": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"vllm",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"watsonx": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"ibm_watson_machine_learning",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlalchemy[asyncio]",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
]
}

View file

@ -25,23 +25,7 @@ distribution_spec:
- inline::rag-runtime
- remote::model-context-protocol
- remote::wolfram-alpha
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/inference_store.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: remote::ollama
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: remote::ollama
provider_model_id: all-minilm:latest
model_type: embedding
image_type: container
image_type: conda
additional_pip_packages:
- sqlalchemy[asyncio]
- blobfile

View file

@ -13,8 +13,8 @@ from llama_stack.distribution.datatypes import (
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
#from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
#from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@ -32,7 +32,6 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::rag-runtime",
"remote::model-context-protocol",
"remote::wolfram-alpha",
],
@ -43,11 +42,11 @@ def get_distribution_template() -> DistributionTemplate:
provider_type="remote::ollama",
config=OllamaImplConfig.sample_run_config(),
)
vector_io_provider_faiss = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
#vector_io_provider_faiss = Provider(
# provider_id="faiss",
# provider_type="inline::faiss",
# config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
#)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="ollama",
@ -70,10 +69,6 @@ def get_distribution_template() -> DistributionTemplate:
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::wolfram_alpha",
provider_id="wolfram-alpha",

View file

@ -24,6 +24,10 @@ providers:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
- provider_id: chromadb
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:http://host.docker.internal:8000}
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard

View file

@ -2,9 +2,9 @@
import { useEffect, useState } from "react";
import { useParams } from "next/navigation";
import LlamaStackClient from "llama-stack-client";
import { ChatCompletion } from "@/lib/types";
import { ChatCompletionDetailView } from "@/components/chat-completions/chat-completion-detail";
import { client } from "@/lib/client";
export default function ChatCompletionDetailPage() {
const params = useParams();
@ -22,10 +22,6 @@ export default function ChatCompletionDetailPage() {
return;
}
const client = new LlamaStackClient({
baseURL: process.env.NEXT_PUBLIC_LLAMA_STACK_BASE_URL,
});
const fetchCompletionDetail = async () => {
setIsLoading(true);
setError(null);

View file

@ -1,45 +1,19 @@
"use client";
import React from "react";
import { usePathname, useParams } from "next/navigation";
import {
PageBreadcrumb,
BreadcrumbSegment,
} from "@/components/layout/page-breadcrumb";
import { truncateText } from "@/lib/truncate-text";
import LogsLayout from "@/components/layout/logs-layout";
export default function ChatCompletionsLayout({
children,
}: {
children: React.ReactNode;
}) {
const pathname = usePathname();
const params = useParams();
let segments: BreadcrumbSegment[] = [];
// Default for /logs/chat-completions
if (pathname === "/logs/chat-completions") {
segments = [{ label: "Chat Completions" }];
}
// For /logs/chat-completions/[id]
const idParam = params?.id;
if (idParam && typeof idParam === "string") {
segments = [
{ label: "Chat Completions", href: "/logs/chat-completions" },
{ label: `Details (${truncateText(idParam, 20)})` },
];
}
return (
<div className="container mx-auto p-4">
<>
{segments.length > 0 && (
<PageBreadcrumb segments={segments} className="mb-4" />
)}
<LogsLayout
sectionLabel="Chat Completions"
basePath="/logs/chat-completions"
>
{children}
</>
</div>
</LogsLayout>
);
}

View file

@ -1,9 +1,9 @@
"use client";
import { useEffect, useState } from "react";
import LlamaStackClient from "llama-stack-client";
import { ChatCompletion } from "@/lib/types";
import { ChatCompletionsTable } from "@/components/chat-completions/chat-completion-table";
import { ChatCompletionsTable } from "@/components/chat-completions/chat-completions-table";
import { client } from "@/lib/client";
export default function ChatCompletionsPage() {
const [completions, setCompletions] = useState<ChatCompletion[]>([]);
@ -11,9 +11,6 @@ export default function ChatCompletionsPage() {
const [error, setError] = useState<Error | null>(null);
useEffect(() => {
const client = new LlamaStackClient({
baseURL: process.env.NEXT_PUBLIC_LLAMA_STACK_BASE_URL,
});
const fetchCompletions = async () => {
setIsLoading(true);
setError(null);
@ -21,7 +18,7 @@ export default function ChatCompletionsPage() {
const response = await client.chat.completions.list();
const data = Array.isArray(response)
? response
: (response as any).data;
: (response as { data: ChatCompletion[] }).data;
if (Array.isArray(data)) {
setCompletions(data);
@ -46,7 +43,7 @@ export default function ChatCompletionsPage() {
return (
<ChatCompletionsTable
completions={completions}
data={completions}
isLoading={isLoading}
error={error}
/>

View file

@ -0,0 +1,125 @@
"use client";
import { useEffect, useState } from "react";
import { useParams } from "next/navigation";
import type { ResponseObject } from "llama-stack-client/resources/responses/responses";
import { OpenAIResponse, InputItemListResponse } from "@/lib/types";
import { ResponseDetailView } from "@/components/responses/responses-detail";
import { client } from "@/lib/client";
export default function ResponseDetailPage() {
const params = useParams();
const id = params.id as string;
const [responseDetail, setResponseDetail] = useState<OpenAIResponse | null>(
null,
);
const [inputItems, setInputItems] = useState<InputItemListResponse | null>(
null,
);
const [isLoading, setIsLoading] = useState<boolean>(true);
const [isLoadingInputItems, setIsLoadingInputItems] = useState<boolean>(true);
const [error, setError] = useState<Error | null>(null);
const [inputItemsError, setInputItemsError] = useState<Error | null>(null);
// Helper function to convert ResponseObject to OpenAIResponse
const convertResponseObject = (
responseData: ResponseObject,
): OpenAIResponse => {
return {
id: responseData.id,
created_at: responseData.created_at,
model: responseData.model,
object: responseData.object,
status: responseData.status,
output: responseData.output as OpenAIResponse["output"],
input: [], // ResponseObject doesn't include input; component uses inputItems prop instead
error: responseData.error,
parallel_tool_calls: responseData.parallel_tool_calls,
previous_response_id: responseData.previous_response_id,
temperature: responseData.temperature,
top_p: responseData.top_p,
truncation: responseData.truncation,
user: responseData.user,
};
};
useEffect(() => {
if (!id) {
setError(new Error("Response ID is missing."));
setIsLoading(false);
return;
}
const fetchResponseDetail = async () => {
setIsLoading(true);
setIsLoadingInputItems(true);
setError(null);
setInputItemsError(null);
setResponseDetail(null);
setInputItems(null);
try {
const [responseResult, inputItemsResult] = await Promise.allSettled([
client.responses.retrieve(id),
client.responses.inputItems.list(id, { order: "asc" }),
]);
// Handle response detail result
if (responseResult.status === "fulfilled") {
const convertedResponse = convertResponseObject(responseResult.value);
setResponseDetail(convertedResponse);
} else {
console.error(
`Error fetching response detail for ID ${id}:`,
responseResult.reason,
);
setError(
responseResult.reason instanceof Error
? responseResult.reason
: new Error("Failed to fetch response detail"),
);
}
// Handle input items result
if (inputItemsResult.status === "fulfilled") {
const inputItemsData =
inputItemsResult.value as unknown as InputItemListResponse;
setInputItems(inputItemsData);
} else {
console.error(
`Error fetching input items for response ID ${id}:`,
inputItemsResult.reason,
);
setInputItemsError(
inputItemsResult.reason instanceof Error
? inputItemsResult.reason
: new Error("Failed to fetch input items"),
);
}
} catch (err) {
console.error(`Unexpected error fetching data for ID ${id}:`, err);
setError(
err instanceof Error ? err : new Error("Unexpected error occurred"),
);
} finally {
setIsLoading(false);
setIsLoadingInputItems(false);
}
};
fetchResponseDetail();
}, [id]);
return (
<ResponseDetailView
response={responseDetail}
inputItems={inputItems}
isLoading={isLoading}
isLoadingInputItems={isLoadingInputItems}
error={error}
inputItemsError={inputItemsError}
id={id}
/>
);
}

View file

@ -0,0 +1,16 @@
"use client";
import React from "react";
import LogsLayout from "@/components/layout/logs-layout";
export default function ResponsesLayout({
children,
}: {
children: React.ReactNode;
}) {
return (
<LogsLayout sectionLabel="Responses" basePath="/logs/responses">
{children}
</LogsLayout>
);
}

View file

@ -1,7 +1,66 @@
export default function Responses() {
"use client";
import { useEffect, useState } from "react";
import type { ResponseListResponse } from "llama-stack-client/resources/responses/responses";
import { OpenAIResponse } from "@/lib/types";
import { ResponsesTable } from "@/components/responses/responses-table";
import { client } from "@/lib/client";
export default function ResponsesPage() {
const [responses, setResponses] = useState<OpenAIResponse[]>([]);
const [isLoading, setIsLoading] = useState<boolean>(true);
const [error, setError] = useState<Error | null>(null);
// Helper function to convert ResponseListResponse.Data to OpenAIResponse
const convertResponseListData = (
responseData: ResponseListResponse.Data,
): OpenAIResponse => {
return {
id: responseData.id,
created_at: responseData.created_at,
model: responseData.model,
object: responseData.object,
status: responseData.status,
output: responseData.output as OpenAIResponse["output"],
input: responseData.input as OpenAIResponse["input"],
error: responseData.error,
parallel_tool_calls: responseData.parallel_tool_calls,
previous_response_id: responseData.previous_response_id,
temperature: responseData.temperature,
top_p: responseData.top_p,
truncation: responseData.truncation,
user: responseData.user,
};
};
useEffect(() => {
const fetchResponses = async () => {
setIsLoading(true);
setError(null);
try {
const response = await client.responses.list();
const responseListData = response as ResponseListResponse;
const convertedResponses: OpenAIResponse[] = responseListData.data.map(
convertResponseListData,
);
setResponses(convertedResponses);
} catch (err) {
console.error("Error fetching responses:", err);
setError(
err instanceof Error ? err : new Error("Failed to fetch responses"),
);
setResponses([]);
} finally {
setIsLoading(false);
}
};
fetchResponses();
}, []);
return (
<div>
<h1>Under Construction</h1>
</div>
<ResponsesTable data={responses} isLoading={isLoading} error={error} />
);
}

View file

@ -75,7 +75,7 @@ describe("ChatCompletionDetailView", () => {
/>,
);
expect(
screen.getByText("No details found for completion ID: notfound-id."),
screen.getByText("No details found for ID: notfound-id."),
).toBeInTheDocument();
});

View file

@ -3,45 +3,14 @@
import { ChatMessage, ChatCompletion } from "@/lib/types";
import { ChatMessageItem } from "@/components/chat-completions/chat-messasge-item";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";
function ChatCompletionDetailLoadingView() {
return (
<>
<Skeleton className="h-8 w-3/4 mb-6" /> {/* Title Skeleton */}
<div className="flex flex-col md:flex-row gap-6">
<div className="flex-grow md:w-2/3 space-y-6">
{[...Array(2)].map((_, i) => (
<Card key={`main-skeleton-card-${i}`}>
<CardHeader>
<CardTitle>
<Skeleton className="h-6 w-1/2" />
</CardTitle>
</CardHeader>
<CardContent className="space-y-2">
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-3/4" />
</CardContent>
</Card>
))}
</div>
<div className="md:w-1/3">
<div className="p-4 border rounded-lg shadow-sm bg-white space-y-3">
<Skeleton className="h-6 w-1/3 mb-3" />{" "}
{/* Properties Title Skeleton */}
{[...Array(5)].map((_, i) => (
<div key={`prop-skeleton-${i}`} className="space-y-1">
<Skeleton className="h-4 w-1/4" />
<Skeleton className="h-4 w-1/2" />
</div>
))}
</div>
</div>
</div>
</>
);
}
import {
DetailLoadingView,
DetailErrorView,
DetailNotFoundView,
DetailLayout,
PropertiesCard,
PropertyItem,
} from "@/components/layout/detail-layout";
interface ChatCompletionDetailViewProps {
completion: ChatCompletion | null;
@ -56,39 +25,23 @@ export function ChatCompletionDetailView({
error,
id,
}: ChatCompletionDetailViewProps) {
const title = "Chat Completion Details";
if (error) {
return (
<>
{/* We still want a title for consistency on error pages */}
<h1 className="text-2xl font-bold mb-6">Chat Completion Details</h1>
<p>
Error loading details for ID {id}: {error.message}
</p>
</>
);
return <DetailErrorView title={title} id={id} error={error} />;
}
if (isLoading) {
return <ChatCompletionDetailLoadingView />;
return <DetailLoadingView title={title} />;
}
if (!completion) {
// This state means: not loading, no error, but no completion data
return (
<>
{/* We still want a title for consistency on not-found pages */}
<h1 className="text-2xl font-bold mb-6">Chat Completion Details</h1>
<p>No details found for completion ID: {id}.</p>
</>
);
return <DetailNotFoundView title={title} id={id} />;
}
// If no error, not loading, and completion exists, render the details:
return (
// Main content cards
const mainContent = (
<>
<h1 className="text-2xl font-bold mb-6">Chat Completion Details</h1>
<div className="flex flex-col md:flex-row gap-6">
<div className="flex-grow md:w-2/3 space-y-6">
<Card>
<CardHeader>
<CardTitle>Input</CardTitle>
@ -98,13 +51,15 @@ export function ChatCompletionDetailView({
<ChatMessageItem key={`input-msg-${index}`} message={msg} />
))}
{completion.choices?.[0]?.message?.tool_calls &&
Array.isArray(completion.choices[0].message.tool_calls) &&
!completion.input_messages?.some(
(im) =>
im.role === "assistant" &&
im.tool_calls &&
Array.isArray(im.tool_calls) &&
im.tool_calls.length > 0,
) &&
completion.choices[0].message.tool_calls.map(
)
? completion.choices[0].message.tool_calls.map(
(toolCall: any, index: number) => {
const assistantToolCallMessage: ChatMessage = {
role: "assistant",
@ -118,7 +73,8 @@ export function ChatCompletionDetailView({
/>
);
},
)}
)
: null}
</CardContent>
</Card>
@ -138,61 +94,52 @@ export function ChatCompletionDetailView({
)}
</CardContent>
</Card>
</div>
</>
);
<div className="md:w-1/3">
<Card>
<CardHeader>
<CardTitle>Properties</CardTitle>
</CardHeader>
<CardContent>
<ul className="space-y-2 text-sm text-gray-600">
<li>
<strong>Created:</strong>{" "}
<span className="text-gray-900 font-medium">
{new Date(completion.created * 1000).toLocaleString()}
</span>
</li>
<li>
<strong>ID:</strong>{" "}
<span className="text-gray-900 font-medium">
{completion.id}
</span>
</li>
<li>
<strong>Model:</strong>{" "}
<span className="text-gray-900 font-medium">
{completion.model}
</span>
</li>
<li className="pt-1 mt-1 border-t border-gray-200">
<strong>Finish Reason:</strong>{" "}
<span className="text-gray-900 font-medium">
{completion.choices?.[0]?.finish_reason || "N/A"}
</span>
</li>
{completion.choices?.[0]?.message?.tool_calls &&
completion.choices[0].message.tool_calls.length > 0 && (
<li className="pt-1 mt-1 border-t border-gray-200">
<strong>Functions/Tools Called:</strong>
// Properties sidebar
const sidebar = (
<PropertiesCard>
<PropertyItem
label="Created"
value={new Date(completion.created * 1000).toLocaleString()}
/>
<PropertyItem label="ID" value={completion.id} />
<PropertyItem label="Model" value={completion.model} />
<PropertyItem
label="Finish Reason"
value={completion.choices?.[0]?.finish_reason || "N/A"}
hasBorder
/>
{(() => {
const toolCalls = completion.choices?.[0]?.message?.tool_calls;
if (toolCalls && Array.isArray(toolCalls) && toolCalls.length > 0) {
return (
<PropertyItem
label="Functions/Tools Called"
value={
<div>
<ul className="list-disc list-inside pl-4 mt-1">
{completion.choices[0].message.tool_calls.map(
(toolCall: any, index: number) => (
{toolCalls.map((toolCall: any, index: number) => (
<li key={index}>
<span className="text-gray-900 font-medium">
{toolCall.function?.name || "N/A"}
</span>
</li>
),
)}
))}
</ul>
</li>
)}
</ul>
</CardContent>
</Card>
</div>
</div>
</>
}
hasBorder
/>
);
}
return null;
})()}
</PropertiesCard>
);
return (
<DetailLayout title={title} mainContent={mainContent} sidebar={sidebar} />
);
}

View file

@ -1,8 +1,8 @@
import React from "react";
import { render, screen, fireEvent } from "@testing-library/react";
import "@testing-library/jest-dom";
import { ChatCompletionsTable } from "./chat-completion-table";
import { ChatCompletion } from "@/lib/types"; // Assuming this path is correct
import { ChatCompletionsTable } from "./chat-completions-table";
import { ChatCompletion } from "@/lib/types";
// Mock next/navigation
const mockPush = jest.fn();
@ -13,21 +13,25 @@ jest.mock("next/navigation", () => ({
}));
// Mock helper functions
// These are hoisted, so their mocks are available throughout the file
jest.mock("@/lib/truncate-text");
jest.mock("@/lib/format-tool-call");
jest.mock("@/lib/format-message-content");
// Import the mocked functions to set up default or specific implementations
import { truncateText as originalTruncateText } from "@/lib/truncate-text";
import { formatToolCallToString as originalFormatToolCallToString } from "@/lib/format-tool-call";
import {
extractTextFromContentPart as originalExtractTextFromContentPart,
extractDisplayableText as originalExtractDisplayableText,
} from "@/lib/format-message-content";
// Cast to jest.Mock for typings
const truncateText = originalTruncateText as jest.Mock;
const formatToolCallToString = originalFormatToolCallToString as jest.Mock;
const extractTextFromContentPart =
originalExtractTextFromContentPart as jest.Mock;
const extractDisplayableText = originalExtractDisplayableText as jest.Mock;
describe("ChatCompletionsTable", () => {
const defaultProps = {
completions: [] as ChatCompletion[],
data: [] as ChatCompletion[],
isLoading: false,
error: null,
};
@ -36,28 +40,26 @@ describe("ChatCompletionsTable", () => {
// Reset all mocks before each test
mockPush.mockClear();
truncateText.mockClear();
formatToolCallToString.mockClear();
extractTextFromContentPart.mockClear();
extractDisplayableText.mockClear();
// Default pass-through implementation for tests not focusing on truncation/formatting
// Default pass-through implementations
truncateText.mockImplementation((text: string | undefined) => text);
formatToolCallToString.mockImplementation((toolCall: any) =>
toolCall && typeof toolCall === "object" && toolCall.name
? `[DefaultToolCall:${toolCall.name}]`
: "[InvalidToolCall]",
extractTextFromContentPart.mockImplementation((content: unknown) =>
typeof content === "string" ? content : "extracted text",
);
extractDisplayableText.mockImplementation(
(message: unknown) =>
(message as { content?: string })?.content || "extracted output",
);
});
test("renders without crashing with default props", () => {
render(<ChatCompletionsTable {...defaultProps} />);
// Check for a unique element that should be present in the non-empty, non-loading, non-error state
// For now, as per Task 1, we will test the empty state message
expect(screen.getByText("No chat completions found.")).toBeInTheDocument();
});
test("click on a row navigates to the correct URL", () => {
const { rerender } = render(<ChatCompletionsTable {...defaultProps} />);
// Simulate a scenario where a completion exists and is clicked
const mockCompletion: ChatCompletion = {
id: "comp_123",
object: "chat.completion",
@ -73,9 +75,12 @@ describe("ChatCompletionsTable", () => {
input_messages: [{ role: "user", content: "Test input" }],
};
rerender(
<ChatCompletionsTable {...defaultProps} completions={[mockCompletion]} />,
);
// Set up mocks to return expected values
extractTextFromContentPart.mockReturnValue("Test input");
extractDisplayableText.mockReturnValue("Test output");
render(<ChatCompletionsTable {...defaultProps} data={[mockCompletion]} />);
const row = screen.getByText("Test input").closest("tr");
if (row) {
fireEvent.click(row);
@ -91,14 +96,13 @@ describe("ChatCompletionsTable", () => {
<ChatCompletionsTable {...defaultProps} isLoading={true} />,
);
// The Skeleton component uses data-slot="skeleton"
const skeletonSelector = '[data-slot="skeleton"]';
// Check for skeleton in the table caption
const tableCaption = container.querySelector("caption");
expect(tableCaption).toBeInTheDocument();
if (tableCaption) {
const captionSkeleton = tableCaption.querySelector(skeletonSelector);
const captionSkeleton = tableCaption.querySelector(
'[data-slot="skeleton"]',
);
expect(captionSkeleton).toBeInTheDocument();
}
@ -107,16 +111,10 @@ describe("ChatCompletionsTable", () => {
expect(tableBody).toBeInTheDocument();
if (tableBody) {
const bodySkeletons = tableBody.querySelectorAll(
`td ${skeletonSelector}`,
'[data-slot="skeleton"]',
);
expect(bodySkeletons.length).toBeGreaterThan(0); // Ensure at least one skeleton cell exists
expect(bodySkeletons.length).toBeGreaterThan(0);
}
// General check: ensure multiple skeleton elements are present in the table overall
const allSkeletonsInTable = container.querySelectorAll(
`table ${skeletonSelector}`,
);
expect(allSkeletonsInTable.length).toBeGreaterThan(3); // e.g., caption + at least one row of 3 cells, or just a few
});
});
@ -140,14 +138,14 @@ describe("ChatCompletionsTable", () => {
{...defaultProps}
error={{ name: "Error", message: "" }}
/>,
); // Error with empty message
);
expect(
screen.getByText("Error fetching data: An unknown error occurred"),
).toBeInTheDocument();
});
test("renders default error message when error prop is an object without message", () => {
render(<ChatCompletionsTable {...defaultProps} error={{} as Error} />); // Empty error object
render(<ChatCompletionsTable {...defaultProps} error={{} as Error} />);
expect(
screen.getByText("Error fetching data: An unknown error occurred"),
).toBeInTheDocument();
@ -155,14 +153,8 @@ describe("ChatCompletionsTable", () => {
});
describe("Empty State", () => {
test('renders "No chat completions found." and no table when completions array is empty', () => {
render(
<ChatCompletionsTable
completions={[]}
isLoading={false}
error={null}
/>,
);
test('renders "No chat completions found." and no table when data array is empty', () => {
render(<ChatCompletionsTable data={[]} isLoading={false} error={null} />);
expect(
screen.getByText("No chat completions found."),
).toBeInTheDocument();
@ -179,7 +171,7 @@ describe("ChatCompletionsTable", () => {
{
id: "comp_1",
object: "chat.completion",
created: 1710000000, // Fixed timestamp for test
created: 1710000000,
model: "llama-test-model",
choices: [
{
@ -206,9 +198,22 @@ describe("ChatCompletionsTable", () => {
},
];
// Set up mocks to return expected values
extractTextFromContentPart.mockImplementation((content: unknown) => {
if (content === "Test input") return "Test input";
if (content === "Another input") return "Another input";
return "extracted text";
});
extractDisplayableText.mockImplementation((message: unknown) => {
const msg = message as { content?: string };
if (msg?.content === "Test output") return "Test output";
if (msg?.content === "Another output") return "Another output";
return "extracted output";
});
render(
<ChatCompletionsTable
completions={mockCompletions}
data={mockCompletions}
isLoading={false}
error={null}
/>,
@ -242,7 +247,7 @@ describe("ChatCompletionsTable", () => {
});
});
describe("Text Truncation and Tool Call Formatting", () => {
describe("Text Truncation and Content Extraction", () => {
test("truncates long input and output text", () => {
// Specific mock implementation for this test
truncateText.mockImplementation(
@ -259,6 +264,10 @@ describe("ChatCompletionsTable", () => {
"This is a very long input message that should be truncated.";
const longOutput =
"This is a very long output message that should also be truncated.";
extractTextFromContentPart.mockReturnValue(longInput);
extractDisplayableText.mockReturnValue(longOutput);
const mockCompletions = [
{
id: "comp_trunc",
@ -278,7 +287,7 @@ describe("ChatCompletionsTable", () => {
render(
<ChatCompletionsTable
completions={mockCompletions}
data={mockCompletions}
isLoading={false}
error={null}
/>,
@ -289,52 +298,50 @@ describe("ChatCompletionsTable", () => {
longInput.slice(0, 10) + "...",
);
expect(truncatedTexts.length).toBe(2); // one for input, one for output
// Optionally, verify each one is in the document if getAllByText doesn't throw on not found
truncatedTexts.forEach((textElement) =>
expect(textElement).toBeInTheDocument(),
);
});
test("formats tool call output using formatToolCallToString", () => {
// Specific mock implementation for this test
formatToolCallToString.mockImplementation(
(toolCall: any) => `[TOOL:${toolCall.name}]`,
);
// Ensure no truncation interferes for this specific test for clarity of tool call format
truncateText.mockImplementation((text: string | undefined) => text);
const toolCall = { name: "search", args: { query: "llama" } };
const mockCompletions = [
{
id: "comp_tool",
test("uses content extraction functions correctly", () => {
const mockCompletion = {
id: "comp_extract",
object: "chat.completion",
created: 1710003000,
model: "llama-tool-model",
model: "llama-extract-model",
choices: [
{
index: 0,
message: {
role: "assistant",
content: "Tool output", // Content that will be prepended
tool_calls: [toolCall],
},
message: { role: "assistant", content: "Extracted output" },
finish_reason: "stop",
},
],
input_messages: [{ role: "user", content: "Tool input" }],
},
];
input_messages: [{ role: "user", content: "Extracted input" }],
};
extractTextFromContentPart.mockReturnValue("Extracted input");
extractDisplayableText.mockReturnValue("Extracted output");
render(
<ChatCompletionsTable
completions={mockCompletions}
data={[mockCompletion]}
isLoading={false}
error={null}
/>,
);
// The component concatenates message.content and the formatted tool call
expect(screen.getByText("Tool output [TOOL:search]")).toBeInTheDocument();
// Verify the extraction functions were called
expect(extractTextFromContentPart).toHaveBeenCalledWith(
"Extracted input",
);
expect(extractDisplayableText).toHaveBeenCalledWith({
role: "assistant",
content: "Extracted output",
});
// Verify the extracted content is displayed
expect(screen.getByText("Extracted input")).toBeInTheDocument();
expect(screen.getByText("Extracted output")).toBeInTheDocument();
});
});
});

View file

@ -0,0 +1,43 @@
"use client";
import { ChatCompletion } from "@/lib/types";
import { LogsTable, LogTableRow } from "@/components/logs/logs-table";
import {
extractTextFromContentPart,
extractDisplayableText,
} from "@/lib/format-message-content";
interface ChatCompletionsTableProps {
data: ChatCompletion[];
isLoading: boolean;
error: Error | null;
}
function formatChatCompletionToRow(completion: ChatCompletion): LogTableRow {
return {
id: completion.id,
input: extractTextFromContentPart(completion.input_messages?.[0]?.content),
output: extractDisplayableText(completion.choices?.[0]?.message),
model: completion.model,
createdTime: new Date(completion.created * 1000).toLocaleString(),
detailPath: `/logs/chat-completions/${completion.id}`,
};
}
export function ChatCompletionsTable({
data,
isLoading,
error,
}: ChatCompletionsTableProps) {
const formattedData = data.map(formatChatCompletionToRow);
return (
<LogsTable
data={formattedData}
isLoading={isLoading}
error={error}
caption="A list of your recent chat completions."
emptyMessage="No chat completions found."
/>
);
}

View file

@ -4,45 +4,10 @@ import { ChatMessage } from "@/lib/types";
import React from "react";
import { formatToolCallToString } from "@/lib/format-tool-call";
import { extractTextFromContentPart } from "@/lib/format-message-content";
// Sub-component or helper for the common label + content structure
const MessageBlock: React.FC<{
label: string;
labelDetail?: string;
content: React.ReactNode;
}> = ({ label, labelDetail, content }) => {
return (
<div>
<p className="py-1 font-semibold text-gray-800 mb-1">
{label}
{labelDetail && (
<span className="text-xs text-gray-500 font-normal ml-1">
{labelDetail}
</span>
)}
</p>
<div className="py-1">{content}</div>
</div>
);
};
interface ToolCallBlockProps {
children: React.ReactNode;
className?: string;
}
const ToolCallBlock = ({ children, className }: ToolCallBlockProps) => {
// Common styling for both function call arguments and tool output blocks
// Let's use slate-50 background as it's good for code-like content.
const baseClassName =
"p-3 bg-slate-50 border border-slate-200 rounded-md text-sm";
return (
<div className={`${baseClassName} ${className || ""}`}>
<pre className="whitespace-pre-wrap text-xs">{children}</pre>
</div>
);
};
import {
MessageBlock,
ToolCallBlock,
} from "@/components/ui/message-components";
interface ChatMessageItemProps {
message: ChatMessage;
@@ -65,7 +30,11 @@ export function ChatMessageItem({ message }: ChatMessageItemProps) {
);
case "assistant":
if (message.tool_calls && message.tool_calls.length > 0) {
if (
message.tool_calls &&
Array.isArray(message.tool_calls) &&
message.tool_calls.length > 0
) {
return (
<>
{message.tool_calls.map((toolCall: any, index: number) => {

View file

@@ -0,0 +1,141 @@
import React from "react";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";
export function DetailLoadingView({ title }: { title: string }) {
return (
<>
<Skeleton className="h-8 w-3/4 mb-6" /> {/* Title Skeleton */}
<div className="flex flex-col md:flex-row gap-6">
<div className="flex-grow md:w-2/3 space-y-6">
{[...Array(2)].map((_, i) => (
<Card key={`main-skeleton-card-${i}`}>
<CardHeader>
<CardTitle>
<Skeleton className="h-6 w-1/2" />
</CardTitle>
</CardHeader>
<CardContent className="space-y-2">
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-3/4" />
</CardContent>
</Card>
))}
</div>
<div className="md:w-1/3">
<div className="p-4 border rounded-lg shadow-sm bg-white space-y-3">
<Skeleton className="h-6 w-1/3 mb-3" />{" "}
{/* Properties Title Skeleton */}
{[...Array(5)].map((_, i) => (
<div key={`prop-skeleton-${i}`} className="space-y-1">
<Skeleton className="h-4 w-1/4" />
<Skeleton className="h-4 w-1/2" />
</div>
))}
</div>
</div>
</div>
</>
);
}
export function DetailErrorView({
title,
id,
error,
}: {
title: string;
id: string;
error: Error;
}) {
return (
<>
<h1 className="text-2xl font-bold mb-6">{title}</h1>
<p>
Error loading details for ID {id}: {error.message}
</p>
</>
);
}
export function DetailNotFoundView({
title,
id,
}: {
title: string;
id: string;
}) {
return (
<>
<h1 className="text-2xl font-bold mb-6">{title}</h1>
<p>No details found for ID: {id}.</p>
</>
);
}
export interface PropertyItemProps {
label: string;
value: React.ReactNode;
className?: string;
hasBorder?: boolean;
}
export function PropertyItem({
label,
value,
className = "",
hasBorder = false,
}: PropertyItemProps) {
return (
<li
className={`${hasBorder ? "pt-1 mt-1 border-t border-gray-200" : ""} ${className}`}
>
<strong>{label}:</strong>{" "}
{typeof value === "string" || typeof value === "number" ? (
<span className="text-gray-900 font-medium">{value}</span>
) : (
value
)}
</li>
);
}
export interface PropertiesCardProps {
children: React.ReactNode;
}
export function PropertiesCard({ children }: PropertiesCardProps) {
return (
<Card>
<CardHeader>
<CardTitle>Properties</CardTitle>
</CardHeader>
<CardContent>
<ul className="space-y-2 text-sm text-gray-600">{children}</ul>
</CardContent>
</Card>
);
}
export interface DetailLayoutProps {
title: string;
mainContent: React.ReactNode;
sidebar: React.ReactNode;
}
export function DetailLayout({
title,
mainContent,
sidebar,
}: DetailLayoutProps) {
return (
<>
<h1 className="text-2xl font-bold mb-6">{title}</h1>
<div className="flex flex-col md:flex-row gap-6">
<div className="flex-grow md:w-2/3 space-y-6">{mainContent}</div>
<div className="md:w-1/3">{sidebar}</div>
</div>
</>
);
}
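
To show how these pieces are meant to compose, here is a small sketch of a detail page built only from the exports above; the card content is placeholder text.

```tsx
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import {
  DetailLayout,
  PropertiesCard,
  PropertyItem,
} from "@/components/layout/detail-layout";

export function ExampleDetailPage() {
  const mainContent = (
    <Card>
      <CardHeader>
        <CardTitle>Input</CardTitle>
      </CardHeader>
      <CardContent>Placeholder content</CardContent>
    </Card>
  );

  const sidebar = (
    <PropertiesCard>
      <PropertyItem label="ID" value="example_123" />
      <PropertyItem label="Status" value="completed" hasBorder />
    </PropertiesCard>
  );

  return (
    <DetailLayout
      title="Example Details"
      mainContent={mainContent}
      sidebar={sidebar}
    />
  );
}
```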

View file

@@ -0,0 +1,49 @@
"use client";
import React from "react";
import { usePathname, useParams } from "next/navigation";
import {
PageBreadcrumb,
BreadcrumbSegment,
} from "@/components/layout/page-breadcrumb";
import { truncateText } from "@/lib/truncate-text";
interface LogsLayoutProps {
children: React.ReactNode;
sectionLabel: string;
basePath: string;
}
export default function LogsLayout({
children,
sectionLabel,
basePath,
}: LogsLayoutProps) {
const pathname = usePathname();
const params = useParams();
let segments: BreadcrumbSegment[] = [];
if (pathname === basePath) {
segments = [{ label: sectionLabel }];
}
const idParam = params?.id;
if (idParam && typeof idParam === "string") {
segments = [
{ label: sectionLabel, href: basePath },
{ label: `Details (${truncateText(idParam, 20)})` },
];
}
return (
<div className="container mx-auto p-4">
<>
{segments.length > 0 && (
<PageBreadcrumb segments={segments} className="mb-4" />
)}
{children}
</>
</div>
);
}
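
A minimal sketch of a route-level layout built on this component. The file location and import alias are assumptions, while the `/logs/chat-completions` base path matches the detail path used elsewhere in this PR.

```tsx
// Hypothetical app/logs/chat-completions/layout.tsx
import React from "react";
import LogsLayout from "@/components/layout/logs-layout"; // assumed path

export default function ChatCompletionsLogsLayout({
  children,
}: {
  children: React.ReactNode;
}) {
  return (
    <LogsLayout
      sectionLabel="Chat Completions"
      basePath="/logs/chat-completions"
    >
      {children}
    </LogsLayout>
  );
}
```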

View file

@@ -0,0 +1,350 @@
import React from "react";
import { render, screen, fireEvent } from "@testing-library/react";
import "@testing-library/jest-dom";
import { LogsTable, LogTableRow } from "./logs-table";
// Mock next/navigation
const mockPush = jest.fn();
jest.mock("next/navigation", () => ({
useRouter: () => ({
push: mockPush,
}),
}));
// Mock helper functions
jest.mock("@/lib/truncate-text");
// Import the mocked functions
import { truncateText as originalTruncateText } from "@/lib/truncate-text";
// Cast to jest.Mock for typings
const truncateText = originalTruncateText as jest.Mock;
describe("LogsTable", () => {
const defaultProps = {
data: [] as LogTableRow[],
isLoading: false,
error: null,
caption: "Test table caption",
emptyMessage: "No data found",
};
beforeEach(() => {
// Reset all mocks before each test
mockPush.mockClear();
truncateText.mockClear();
// Default pass-through implementation
truncateText.mockImplementation((text: string | undefined) => text);
});
test("renders without crashing with default props", () => {
render(<LogsTable {...defaultProps} />);
expect(screen.getByText("No data found")).toBeInTheDocument();
});
test("click on a row navigates to the correct URL", () => {
const mockData: LogTableRow[] = [
{
id: "row_123",
input: "Test input",
output: "Test output",
model: "test-model",
createdTime: "2024-01-01 12:00:00",
detailPath: "/test/path/row_123",
},
];
render(<LogsTable {...defaultProps} data={mockData} />);
const row = screen.getByText("Test input").closest("tr");
if (row) {
fireEvent.click(row);
expect(mockPush).toHaveBeenCalledWith("/test/path/row_123");
} else {
throw new Error('Row with "Test input" not found for router mock test.');
}
});
describe("Loading State", () => {
test("renders skeleton UI when isLoading is true", () => {
const { container } = render(
<LogsTable {...defaultProps} isLoading={true} />,
);
// Check for skeleton in the table caption
const tableCaption = container.querySelector("caption");
expect(tableCaption).toBeInTheDocument();
if (tableCaption) {
const captionSkeleton = tableCaption.querySelector(
'[data-slot="skeleton"]',
);
expect(captionSkeleton).toBeInTheDocument();
}
// Check for skeletons in the table body cells
const tableBody = container.querySelector("tbody");
expect(tableBody).toBeInTheDocument();
if (tableBody) {
const bodySkeletons = tableBody.querySelectorAll(
'[data-slot="skeleton"]',
);
expect(bodySkeletons.length).toBeGreaterThan(0);
}
// Check that table headers are still rendered
expect(screen.getByText("Input")).toBeInTheDocument();
expect(screen.getByText("Output")).toBeInTheDocument();
expect(screen.getByText("Model")).toBeInTheDocument();
expect(screen.getByText("Created")).toBeInTheDocument();
});
test("renders correct number of skeleton rows", () => {
const { container } = render(
<LogsTable {...defaultProps} isLoading={true} />,
);
const skeletonRows = container.querySelectorAll("tbody tr");
expect(skeletonRows.length).toBe(3); // Should render 3 skeleton rows
});
});
describe("Error State", () => {
test("renders error message when error prop is provided", () => {
const errorMessage = "Network Error";
render(
<LogsTable
{...defaultProps}
error={{ name: "Error", message: errorMessage }}
/>,
);
expect(
screen.getByText(`Error fetching data: ${errorMessage}`),
).toBeInTheDocument();
});
test("renders default error message when error.message is not available", () => {
render(
<LogsTable {...defaultProps} error={{ name: "Error", message: "" }} />,
);
expect(
screen.getByText("Error fetching data: An unknown error occurred"),
).toBeInTheDocument();
});
test("renders default error message when error prop is an object without message", () => {
render(<LogsTable {...defaultProps} error={{} as Error} />);
expect(
screen.getByText("Error fetching data: An unknown error occurred"),
).toBeInTheDocument();
});
test("does not render table when in error state", () => {
render(
<LogsTable
{...defaultProps}
error={{ name: "Error", message: "Test error" }}
/>,
);
const table = screen.queryByRole("table");
expect(table).not.toBeInTheDocument();
});
});
describe("Empty State", () => {
test("renders custom empty message when data array is empty", () => {
render(
<LogsTable
{...defaultProps}
data={[]}
emptyMessage="Custom empty message"
/>,
);
expect(screen.getByText("Custom empty message")).toBeInTheDocument();
// Ensure that the table structure is NOT rendered in the empty state
const table = screen.queryByRole("table");
expect(table).not.toBeInTheDocument();
});
});
describe("Data Rendering", () => {
test("renders table caption, headers, and data correctly", () => {
const mockData: LogTableRow[] = [
{
id: "row_1",
input: "First input",
output: "First output",
model: "model-1",
createdTime: "2024-01-01 12:00:00",
detailPath: "/path/1",
},
{
id: "row_2",
input: "Second input",
output: "Second output",
model: "model-2",
createdTime: "2024-01-02 13:00:00",
detailPath: "/path/2",
},
];
render(
<LogsTable
{...defaultProps}
data={mockData}
caption="Custom table caption"
/>,
);
// Table caption
expect(screen.getByText("Custom table caption")).toBeInTheDocument();
// Table headers
expect(screen.getByText("Input")).toBeInTheDocument();
expect(screen.getByText("Output")).toBeInTheDocument();
expect(screen.getByText("Model")).toBeInTheDocument();
expect(screen.getByText("Created")).toBeInTheDocument();
// Data rows
expect(screen.getByText("First input")).toBeInTheDocument();
expect(screen.getByText("First output")).toBeInTheDocument();
expect(screen.getByText("model-1")).toBeInTheDocument();
expect(screen.getByText("2024-01-01 12:00:00")).toBeInTheDocument();
expect(screen.getByText("Second input")).toBeInTheDocument();
expect(screen.getByText("Second output")).toBeInTheDocument();
expect(screen.getByText("model-2")).toBeInTheDocument();
expect(screen.getByText("2024-01-02 13:00:00")).toBeInTheDocument();
});
test("applies correct CSS classes to table rows", () => {
const mockData: LogTableRow[] = [
{
id: "row_1",
input: "Test input",
output: "Test output",
model: "test-model",
createdTime: "2024-01-01 12:00:00",
detailPath: "/test/path",
},
];
render(<LogsTable {...defaultProps} data={mockData} />);
const row = screen.getByText("Test input").closest("tr");
expect(row).toHaveClass("cursor-pointer");
expect(row).toHaveClass("hover:bg-muted/50");
});
test("applies correct alignment to Created column", () => {
const mockData: LogTableRow[] = [
{
id: "row_1",
input: "Test input",
output: "Test output",
model: "test-model",
createdTime: "2024-01-01 12:00:00",
detailPath: "/test/path",
},
];
render(<LogsTable {...defaultProps} data={mockData} />);
const createdCell = screen.getByText("2024-01-01 12:00:00").closest("td");
expect(createdCell).toHaveClass("text-right");
});
});
describe("Text Truncation", () => {
test("truncates input and output text using truncateText function", () => {
// Mock truncateText to return truncated versions
truncateText.mockImplementation((text: string | undefined) => {
if (typeof text === "string" && text.length > 10) {
return text.slice(0, 10) + "...";
}
return text;
});
const longInput =
"This is a very long input text that should be truncated";
const longOutput =
"This is a very long output text that should be truncated";
const mockData: LogTableRow[] = [
{
id: "row_1",
input: longInput,
output: longOutput,
model: "test-model",
createdTime: "2024-01-01 12:00:00",
detailPath: "/test/path",
},
];
render(<LogsTable {...defaultProps} data={mockData} />);
// Verify truncateText was called
expect(truncateText).toHaveBeenCalledWith(longInput);
expect(truncateText).toHaveBeenCalledWith(longOutput);
// Verify truncated text is displayed
const truncatedTexts = screen.getAllByText("This is a ...");
expect(truncatedTexts).toHaveLength(2); // one for input, one for output
truncatedTexts.forEach((textElement) =>
expect(textElement).toBeInTheDocument(),
);
});
test("does not truncate model names", () => {
const mockData: LogTableRow[] = [
{
id: "row_1",
input: "Test input",
output: "Test output",
model: "very-long-model-name-that-should-not-be-truncated",
createdTime: "2024-01-01 12:00:00",
detailPath: "/test/path",
},
];
render(<LogsTable {...defaultProps} data={mockData} />);
// Model name should not be passed to truncateText
expect(truncateText).not.toHaveBeenCalledWith(
"very-long-model-name-that-should-not-be-truncated",
);
// Full model name should be displayed
expect(
screen.getByText("very-long-model-name-that-should-not-be-truncated"),
).toBeInTheDocument();
});
});
describe("Accessibility", () => {
test("table has proper role and structure", () => {
const mockData: LogTableRow[] = [
{
id: "row_1",
input: "Test input",
output: "Test output",
model: "test-model",
createdTime: "2024-01-01 12:00:00",
detailPath: "/test/path",
},
];
render(<LogsTable {...defaultProps} data={mockData} />);
const table = screen.getByRole("table");
expect(table).toBeInTheDocument();
const columnHeaders = screen.getAllByRole("columnheader");
expect(columnHeaders).toHaveLength(4);
const rows = screen.getAllByRole("row");
expect(rows).toHaveLength(2); // 1 header row + 1 data row
});
});
});

View file

@@ -1,12 +1,7 @@
"use client";
import { useRouter } from "next/navigation";
import { ChatCompletion } from "@/lib/types";
import { truncateText } from "@/lib/truncate-text";
import {
extractTextFromContentPart,
extractDisplayableText,
} from "@/lib/format-message-content";
import {
Table,
TableBody,
@@ -18,17 +13,31 @@ import {
} from "@/components/ui/table";
import { Skeleton } from "@/components/ui/skeleton";
interface ChatCompletionsTableProps {
completions: ChatCompletion[];
isLoading: boolean;
error: Error | null;
// Generic table row data interface
export interface LogTableRow {
id: string;
input: string;
output: string;
model: string;
createdTime: string;
detailPath: string;
}
export function ChatCompletionsTable({
completions,
interface LogsTableProps {
data: LogTableRow[];
isLoading: boolean;
error: Error | null;
caption: string;
emptyMessage: string;
}
export function LogsTable({
data,
isLoading,
error,
}: ChatCompletionsTableProps) {
caption,
emptyMessage,
}: LogsTableProps) {
const router = useRouter();
const tableHeader = (
@@ -77,41 +86,25 @@ export function ChatCompletionsTable({
);
}
if (completions.length === 0) {
return <p>No chat completions found.</p>;
if (data.length === 0) {
return <p>{emptyMessage}</p>;
}
return (
<Table>
<TableCaption>A list of your recent chat completions.</TableCaption>
<TableCaption>{caption}</TableCaption>
{tableHeader}
<TableBody>
{completions.map((completion) => (
{data.map((row) => (
<TableRow
key={completion.id}
onClick={() =>
router.push(`/logs/chat-completions/${completion.id}`)
}
key={row.id}
onClick={() => router.push(row.detailPath)}
className="cursor-pointer hover:bg-muted/50"
>
<TableCell>
{truncateText(
extractTextFromContentPart(
completion.input_messages?.[0]?.content,
),
)}
</TableCell>
<TableCell>
{(() => {
const message = completion.choices?.[0]?.message;
const outputText = extractDisplayableText(message);
return truncateText(outputText);
})()}
</TableCell>
<TableCell>{completion.model}</TableCell>
<TableCell className="text-right">
{new Date(completion.created * 1000).toLocaleString()}
</TableCell>
<TableCell>{truncateText(row.input)}</TableCell>
<TableCell>{truncateText(row.output)}</TableCell>
<TableCell>{row.model}</TableCell>
<TableCell className="text-right">{row.createdTime}</TableCell>
</TableRow>
))}
</TableBody>

View file

@@ -0,0 +1,56 @@
import { useFunctionCallGrouping } from "../hooks/function-call-grouping";
import { ItemRenderer } from "../items/item-renderer";
import { GroupedFunctionCallItemComponent } from "../items/grouped-function-call-item";
import {
isFunctionCallItem,
isFunctionCallOutputItem,
AnyResponseItem,
} from "../utils/item-types";
interface GroupedItemsDisplayProps {
items: AnyResponseItem[];
keyPrefix: string;
defaultRole?: string;
}
export function GroupedItemsDisplay({
items,
keyPrefix,
defaultRole = "unknown",
}: GroupedItemsDisplayProps) {
const groupedItems = useFunctionCallGrouping(items);
return (
<>
{groupedItems.map((groupedItem) => {
// If this is a function call with an output, render the grouped component
if (
groupedItem.outputItem &&
isFunctionCallItem(groupedItem.item) &&
isFunctionCallOutputItem(groupedItem.outputItem)
) {
return (
<GroupedFunctionCallItemComponent
key={`${keyPrefix}-${groupedItem.index}`}
functionCall={groupedItem.item}
output={groupedItem.outputItem}
index={groupedItem.index}
keyPrefix={keyPrefix}
/>
);
}
// Otherwise, render the individual item
return (
<ItemRenderer
key={`${keyPrefix}-${groupedItem.index}`}
item={groupedItem.item}
index={groupedItem.index}
keyPrefix={keyPrefix}
defaultRole={defaultRole}
/>
);
})}
</>
);
}

View file

@@ -0,0 +1,92 @@
import { useMemo } from "react";
import {
isFunctionCallOutputItem,
AnyResponseItem,
FunctionCallOutputItem,
} from "../utils/item-types";
export interface GroupedItem {
item: AnyResponseItem;
index: number;
outputItem?: AnyResponseItem;
outputIndex?: number;
}
/**
* Hook to group function calls with their corresponding outputs
* @param items Array of items to group
* @returns Array of grouped items with their outputs
*/
export function useFunctionCallGrouping(
items: AnyResponseItem[],
): GroupedItem[] {
return useMemo(() => {
const groupedItems: GroupedItem[] = [];
const processedIndices = new Set<number>();
// Build a map of call_id to indices for function_call_output items
const callIdToIndices = new Map<string, number[]>();
for (let i = 0; i < items.length; i++) {
const item = items[i];
if (isFunctionCallOutputItem(item)) {
if (!callIdToIndices.has(item.call_id)) {
callIdToIndices.set(item.call_id, []);
}
callIdToIndices.get(item.call_id)!.push(i);
}
}
// Process items and group function calls with their outputs
for (let i = 0; i < items.length; i++) {
if (processedIndices.has(i)) {
continue;
}
const currentItem = items[i];
if (
currentItem.type === "function_call" &&
"name" in currentItem &&
"call_id" in currentItem
) {
const functionCallId = currentItem.call_id as string;
let outputIndex = -1;
let outputItem: FunctionCallOutputItem | null = null;
const relatedIndices = callIdToIndices.get(functionCallId) || [];
for (const idx of relatedIndices) {
const potentialOutput = items[idx];
outputIndex = idx;
outputItem = potentialOutput as FunctionCallOutputItem;
break;
}
if (outputItem && outputIndex !== -1) {
// Group function call with its function_call_output
groupedItems.push({
item: currentItem,
index: i,
outputItem,
outputIndex,
});
// Mark both items as processed
processedIndices.add(i);
processedIndices.add(outputIndex);
// Matching function call and output found, skip to next item
continue;
}
}
// render normally
groupedItems.push({
item: currentItem,
index: i,
});
processedIndices.add(i);
}
return groupedItems;
}, [items]);
}
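
To make the pairing behaviour concrete, a small sketch of the hook in use. The item literals follow the `function_call` / `function_call_output` shapes exercised by the tests in this PR; the cast and the relative import paths are assumptions, since the demo's file location is hypothetical.

```tsx
import { AnyResponseItem } from "../utils/item-types";
import { useFunctionCallGrouping } from "./function-call-grouping"; // assumed sibling path

export function GroupingDemo() {
  const items = [
    {
      type: "function_call",
      call_id: "call_123",
      name: "get_weather",
      arguments: '{"city": "Tokyo"}',
    },
    {
      type: "function_call_output",
      call_id: "call_123",
      output: "sunny and warm",
    },
  ] as unknown as AnyResponseItem[];

  // Produces a single grouped entry: the output is attached via outputItem/outputIndex.
  const grouped = useFunctionCallGrouping(items);

  return <pre>{JSON.stringify(grouped, null, 2)}</pre>;
}
```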

View file

@@ -0,0 +1,29 @@
import {
MessageBlock,
ToolCallBlock,
} from "@/components/ui/message-components";
import { FunctionCallItem } from "../utils/item-types";
interface FunctionCallItemProps {
item: FunctionCallItem;
index: number;
keyPrefix: string;
}
export function FunctionCallItemComponent({
item,
index,
keyPrefix,
}: FunctionCallItemProps) {
const name = item.name || "unknown";
const args = item.arguments || "{}";
const formattedFunctionCall = `${name}(${args})`;
return (
<MessageBlock
key={`${keyPrefix}-${index}`}
label="Function Call"
content={<ToolCallBlock>{formattedFunctionCall}</ToolCallBlock>}
/>
);
}

View file

@@ -0,0 +1,37 @@
import {
MessageBlock,
ToolCallBlock,
} from "@/components/ui/message-components";
import { BaseItem } from "../utils/item-types";
interface GenericItemProps {
item: BaseItem;
index: number;
keyPrefix: string;
}
export function GenericItemComponent({
item,
index,
keyPrefix,
}: GenericItemProps) {
// Handle other types like function calls, tool outputs, etc.
const itemData = item as Record<string, unknown>;
const content = itemData.content
? typeof itemData.content === "string"
? itemData.content
: JSON.stringify(itemData.content, null, 2)
: JSON.stringify(itemData, null, 2);
const label = keyPrefix === "input" ? "Input" : "Output";
return (
<MessageBlock
key={`${keyPrefix}-${index}`}
label={label}
labelDetail={`(${itemData.type})`}
content={<ToolCallBlock>{content}</ToolCallBlock>}
/>
);
}

View file

@@ -0,0 +1,54 @@
import {
MessageBlock,
ToolCallBlock,
} from "@/components/ui/message-components";
import { FunctionCallItem, FunctionCallOutputItem } from "../utils/item-types";
interface GroupedFunctionCallItemProps {
functionCall: FunctionCallItem;
output: FunctionCallOutputItem;
index: number;
keyPrefix: string;
}
export function GroupedFunctionCallItemComponent({
functionCall,
output,
index,
keyPrefix,
}: GroupedFunctionCallItemProps) {
const name = functionCall.name || "unknown";
const args = functionCall.arguments || "{}";
// Extract the output content from function_call_output
let outputContent = "";
if (output.output) {
outputContent =
typeof output.output === "string"
? output.output
: JSON.stringify(output.output);
} else {
outputContent = JSON.stringify(output, null, 2);
}
const functionCallContent = (
<div>
<div className="mb-2">
<span className="text-sm text-gray-600">Arguments</span>
<ToolCallBlock>{`${name}(${args})`}</ToolCallBlock>
</div>
<div>
<span className="text-sm text-gray-600">Output</span>
<ToolCallBlock>{outputContent}</ToolCallBlock>
</div>
</div>
);
return (
<MessageBlock
key={`${keyPrefix}-${index}`}
label="Function Call"
content={functionCallContent}
/>
);
}

View file

@@ -0,0 +1,6 @@
export { MessageItemComponent } from "./message-item";
export { FunctionCallItemComponent } from "./function-call-item";
export { WebSearchItemComponent } from "./web-search-item";
export { GenericItemComponent } from "./generic-item";
export { GroupedFunctionCallItemComponent } from "./grouped-function-call-item";
export { ItemRenderer } from "./item-renderer";

View file

@@ -0,0 +1,60 @@
import {
isMessageItem,
isFunctionCallItem,
isWebSearchCallItem,
AnyResponseItem,
} from "../utils/item-types";
import { MessageItemComponent } from "./message-item";
import { FunctionCallItemComponent } from "./function-call-item";
import { WebSearchItemComponent } from "./web-search-item";
import { GenericItemComponent } from "./generic-item";
interface ItemRendererProps {
item: AnyResponseItem;
index: number;
keyPrefix: string;
defaultRole?: string;
}
export function ItemRenderer({
item,
index,
keyPrefix,
defaultRole = "unknown",
}: ItemRendererProps) {
if (isMessageItem(item)) {
return (
<MessageItemComponent
item={item}
index={index}
keyPrefix={keyPrefix}
defaultRole={defaultRole}
/>
);
}
if (isFunctionCallItem(item)) {
return (
<FunctionCallItemComponent
item={item}
index={index}
keyPrefix={keyPrefix}
/>
);
}
if (isWebSearchCallItem(item)) {
return (
<WebSearchItemComponent item={item} index={index} keyPrefix={keyPrefix} />
);
}
// Fallback to generic item for unknown types
return (
<GenericItemComponent
item={item as any}
index={index}
keyPrefix={keyPrefix}
/>
);
}

View file

@@ -0,0 +1,41 @@
import { MessageBlock } from "@/components/ui/message-components";
import { MessageItem } from "../utils/item-types";
interface MessageItemProps {
item: MessageItem;
index: number;
keyPrefix: string;
defaultRole?: string;
}
export function MessageItemComponent({
item,
index,
keyPrefix,
defaultRole = "unknown",
}: MessageItemProps) {
let content = "";
if (typeof item.content === "string") {
content = item.content;
} else if (Array.isArray(item.content)) {
content = item.content
.map((c) => {
return c.type === "input_text" || c.type === "output_text"
? c.text
: JSON.stringify(c);
})
.join(" ");
}
const role = item.role || defaultRole;
const label = role.charAt(0).toUpperCase() + role.slice(1);
return (
<MessageBlock
key={`${keyPrefix}-${index}`}
label={label}
content={content}
/>
);
}
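
For clarity on the flattening above: array content parts are joined with a single space, so the demo below would render a block labelled "User" containing "First part Second part". The cast and the demo's file location are assumptions.

```tsx
import { MessageItem } from "../utils/item-types";
import { MessageItemComponent } from "./message-item"; // assumed sibling path

const demoItem = {
  type: "message",
  role: "user",
  content: [
    { type: "input_text", text: "First part" },
    { type: "output_text", text: "Second part" },
  ],
} as unknown as MessageItem;

export const MessageItemDemo = () => (
  <MessageItemComponent item={demoItem} index={0} keyPrefix="input" />
);
```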

View file

@@ -0,0 +1,28 @@
import {
MessageBlock,
ToolCallBlock,
} from "@/components/ui/message-components";
import { WebSearchCallItem } from "../utils/item-types";
interface WebSearchItemProps {
item: WebSearchCallItem;
index: number;
keyPrefix: string;
}
export function WebSearchItemComponent({
item,
index,
keyPrefix,
}: WebSearchItemProps) {
const formattedWebSearch = `web_search_call(status: ${item.status})`;
return (
<MessageBlock
key={`${keyPrefix}-${index}`}
label="Function Call"
labelDetail="(Web Search)"
content={<ToolCallBlock>{formattedWebSearch}</ToolCallBlock>}
/>
);
}

View file

@@ -0,0 +1,777 @@
import React from "react";
import { render, screen } from "@testing-library/react";
import "@testing-library/jest-dom";
import { ResponseDetailView } from "./responses-detail";
import { OpenAIResponse, InputItemListResponse } from "@/lib/types";
describe("ResponseDetailView", () => {
const defaultProps = {
response: null,
inputItems: null,
isLoading: false,
isLoadingInputItems: false,
error: null,
inputItemsError: null,
id: "test_id",
};
describe("Loading State", () => {
test("renders loading skeleton when isLoading is true", () => {
const { container } = render(
<ResponseDetailView {...defaultProps} isLoading={true} />,
);
// Check for skeleton elements
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
expect(skeletons.length).toBeGreaterThan(0);
// The title is replaced by a skeleton when loading, so we shouldn't expect the text
});
});
describe("Error State", () => {
test("renders error message when error prop is provided", () => {
const errorMessage = "Network Error";
render(
<ResponseDetailView
{...defaultProps}
error={{ name: "Error", message: errorMessage }}
/>,
);
expect(screen.getByText("Responses Details")).toBeInTheDocument();
// The error message is split across elements, so we check for parts
expect(
screen.getByText(/Error loading details for ID/),
).toBeInTheDocument();
expect(screen.getByText(/test_id/)).toBeInTheDocument();
expect(screen.getByText(/Network Error/)).toBeInTheDocument();
});
test("renders default error message when error.message is not available", () => {
render(
<ResponseDetailView
{...defaultProps}
error={{ name: "Error", message: "" }}
/>,
);
expect(
screen.getByText(/Error loading details for ID/),
).toBeInTheDocument();
expect(screen.getByText(/test_id/)).toBeInTheDocument();
});
});
describe("Not Found State", () => {
test("renders not found message when response is null and not loading/error", () => {
render(<ResponseDetailView {...defaultProps} response={null} />);
expect(screen.getByText("Responses Details")).toBeInTheDocument();
// The message is split across elements
expect(screen.getByText(/No details found for ID:/)).toBeInTheDocument();
expect(screen.getByText(/test_id/)).toBeInTheDocument();
});
});
describe("Response Data Rendering", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "llama-test-model",
status: "completed",
output: [
{
type: "message",
role: "assistant",
content: "Test response output",
},
],
input: [
{
type: "message",
role: "user",
content: "Test input message",
},
],
temperature: 0.7,
top_p: 0.9,
parallel_tool_calls: true,
previous_response_id: "prev_resp_456",
};
test("renders response data with input and output sections", () => {
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
// Check main sections
expect(screen.getByText("Responses Details")).toBeInTheDocument();
expect(screen.getByText("Input")).toBeInTheDocument();
expect(screen.getByText("Output")).toBeInTheDocument();
// Check input content
expect(screen.getByText("Test input message")).toBeInTheDocument();
expect(screen.getByText("User")).toBeInTheDocument();
// Check output content
expect(screen.getByText("Test response output")).toBeInTheDocument();
expect(screen.getByText("Assistant")).toBeInTheDocument();
});
test("renders properties sidebar with all response metadata", () => {
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
// Check properties - use regex to handle text split across elements
expect(screen.getByText(/Created/)).toBeInTheDocument();
expect(
screen.getByText(new Date(1710000000 * 1000).toLocaleString()),
).toBeInTheDocument();
// Check for the specific ID label (not Previous Response ID)
expect(
screen.getByText((content, element) => {
return element?.tagName === "STRONG" && content === "ID:";
}),
).toBeInTheDocument();
expect(screen.getByText("resp_123")).toBeInTheDocument();
expect(screen.getByText(/Model/)).toBeInTheDocument();
expect(screen.getByText("llama-test-model")).toBeInTheDocument();
expect(screen.getByText(/Status/)).toBeInTheDocument();
expect(screen.getByText("completed")).toBeInTheDocument();
expect(screen.getByText(/Temperature/)).toBeInTheDocument();
expect(screen.getByText("0.7")).toBeInTheDocument();
expect(screen.getByText(/Top P/)).toBeInTheDocument();
expect(screen.getByText("0.9")).toBeInTheDocument();
expect(screen.getByText(/Parallel Tool Calls/)).toBeInTheDocument();
expect(screen.getByText("Yes")).toBeInTheDocument();
expect(screen.getByText(/Previous Response ID/)).toBeInTheDocument();
expect(screen.getByText("prev_resp_456")).toBeInTheDocument();
});
test("handles optional properties correctly", () => {
const minimalResponse: OpenAIResponse = {
id: "resp_minimal",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [],
};
render(
<ResponseDetailView {...defaultProps} response={minimalResponse} />,
);
// Should show required properties
expect(screen.getByText("resp_minimal")).toBeInTheDocument();
expect(screen.getByText("test-model")).toBeInTheDocument();
expect(screen.getByText("completed")).toBeInTheDocument();
// Should not show optional properties
expect(screen.queryByText("Temperature")).not.toBeInTheDocument();
expect(screen.queryByText("Top P")).not.toBeInTheDocument();
expect(screen.queryByText("Parallel Tool Calls")).not.toBeInTheDocument();
expect(
screen.queryByText("Previous Response ID"),
).not.toBeInTheDocument();
});
test("renders error information when response has error", () => {
const errorResponse: OpenAIResponse = {
...mockResponse,
error: {
code: "invalid_request",
message: "The request was invalid",
},
};
render(<ResponseDetailView {...defaultProps} response={errorResponse} />);
// The error is shown in the properties sidebar, not as a separate "Error" label
expect(
screen.getByText("invalid_request: The request was invalid"),
).toBeInTheDocument();
});
});
describe("Input Items Handling", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [{ type: "message", role: "assistant", content: "output" }],
input: [{ type: "message", role: "user", content: "fallback input" }],
};
test("shows loading state for input items", () => {
render(
<ResponseDetailView
{...defaultProps}
response={mockResponse}
isLoadingInputItems={true}
/>,
);
// Check for skeleton loading in input items section
const { container } = render(
<ResponseDetailView
{...defaultProps}
response={mockResponse}
isLoadingInputItems={true}
/>,
);
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
expect(skeletons.length).toBeGreaterThan(0);
});
test("shows error message for input items with fallback", () => {
render(
<ResponseDetailView
{...defaultProps}
response={mockResponse}
inputItemsError={{
name: "Error",
message: "Failed to load input items",
}}
/>,
);
expect(
screen.getByText(
"Error loading input items: Failed to load input items",
),
).toBeInTheDocument();
expect(
screen.getByText("Falling back to response input data."),
).toBeInTheDocument();
// Should still show fallback input data
expect(screen.getByText("fallback input")).toBeInTheDocument();
});
test("uses input items data when available", () => {
const mockInputItems: InputItemListResponse = {
object: "list",
data: [
{
type: "message",
role: "user",
content: "input from items API",
},
],
};
render(
<ResponseDetailView
{...defaultProps}
response={mockResponse}
inputItems={mockInputItems}
/>,
);
// Should show input items data, not response.input
expect(screen.getByText("input from items API")).toBeInTheDocument();
expect(screen.queryByText("fallback input")).not.toBeInTheDocument();
});
test("falls back to response.input when input items is empty", () => {
const emptyInputItems: InputItemListResponse = {
object: "list",
data: [],
};
render(
<ResponseDetailView
{...defaultProps}
response={mockResponse}
inputItems={emptyInputItems}
/>,
);
// Should show fallback input data
expect(screen.getByText("fallback input")).toBeInTheDocument();
});
test("shows no input message when no data available", () => {
const responseWithoutInput: OpenAIResponse = {
...mockResponse,
input: [],
};
render(
<ResponseDetailView
{...defaultProps}
response={responseWithoutInput}
inputItems={null}
/>,
);
expect(screen.getByText("No input data available.")).toBeInTheDocument();
});
});
describe("Input Display Components", () => {
test("renders string content input correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [
{
type: "message",
role: "user",
content: "Simple string input",
},
],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(screen.getByText("Simple string input")).toBeInTheDocument();
expect(screen.getByText("User")).toBeInTheDocument();
});
test("renders array content input correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [
{
type: "message",
role: "user",
content: [
{ type: "input_text", text: "First part" },
{ type: "output_text", text: "Second part" },
],
},
],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(screen.getByText("First part Second part")).toBeInTheDocument();
expect(screen.getByText("User")).toBeInTheDocument();
});
test("renders non-message input types correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [
{
type: "function_call",
content: "function call content",
},
],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(screen.getByText("function call content")).toBeInTheDocument();
// Use getAllByText to find the specific "Input" with the type detail
const inputElements = screen.getAllByText("Input");
expect(inputElements.length).toBeGreaterThan(0);
expect(screen.getByText("(function_call)")).toBeInTheDocument();
});
test("handles input with object content", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [
{
type: "custom_type",
content: JSON.stringify({ key: "value", nested: { data: "test" } }),
},
],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
// Should show JSON stringified content (without quotes around keys in the rendered output)
expect(screen.getByText(/key.*value/)).toBeInTheDocument();
// Use getAllByText to find the specific "Input" with the type detail
const inputElements = screen.getAllByText("Input");
expect(inputElements.length).toBeGreaterThan(0);
expect(screen.getByText("(custom_type)")).toBeInTheDocument();
});
test("renders function call input correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [
{
type: "function_call",
id: "call_456",
status: "completed",
name: "input_function",
arguments: '{"param": "value"}',
},
],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(
screen.getByText('input_function({"param": "value"})'),
).toBeInTheDocument();
expect(screen.getByText("Function Call")).toBeInTheDocument();
});
test("renders web search call input correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [
{
type: "web_search_call",
id: "search_789",
status: "completed",
},
],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(
screen.getByText("web_search_call(status: completed)"),
).toBeInTheDocument();
expect(screen.getByText("Function Call")).toBeInTheDocument();
expect(screen.getByText("(Web Search)")).toBeInTheDocument();
});
});
describe("Output Display Components", () => {
test("renders message output with string content", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "message",
role: "assistant",
content: "Simple string output",
},
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(screen.getByText("Simple string output")).toBeInTheDocument();
expect(screen.getByText("Assistant")).toBeInTheDocument();
});
test("renders message output with array content", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "message",
role: "assistant",
content: [
{ type: "output_text", text: "First output" },
{ type: "input_text", text: "Second output" },
],
},
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(
screen.getByText("First output Second output"),
).toBeInTheDocument();
expect(screen.getByText("Assistant")).toBeInTheDocument();
});
test("renders function call output correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "function_call",
id: "call_123",
status: "completed",
name: "search_function",
arguments: '{"query": "test"}',
},
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(
screen.getByText('search_function({"query": "test"})'),
).toBeInTheDocument();
expect(screen.getByText("Function Call")).toBeInTheDocument();
});
test("renders function call output without arguments", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "function_call",
id: "call_123",
status: "completed",
name: "simple_function",
},
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(screen.getByText("simple_function({})")).toBeInTheDocument();
expect(screen.getByText(/Function Call/)).toBeInTheDocument();
});
test("renders web search call output correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "web_search_call",
id: "search_123",
status: "completed",
},
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(
screen.getByText("web_search_call(status: completed)"),
).toBeInTheDocument();
expect(screen.getByText(/Function Call/)).toBeInTheDocument();
expect(screen.getByText("(Web Search)")).toBeInTheDocument();
});
test("renders unknown output types with JSON fallback", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "unknown_type",
custom_field: "custom_value",
data: { nested: "object" },
} as any,
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
// Should show JSON stringified content
expect(
screen.getByText(/custom_field.*custom_value/),
).toBeInTheDocument();
expect(screen.getByText("(unknown_type)")).toBeInTheDocument();
});
test("shows no output message when output array is empty", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(screen.getByText("No output data available.")).toBeInTheDocument();
});
test("groups function call with its output correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "function_call",
id: "call_123",
status: "completed",
name: "get_weather",
arguments: '{"city": "Tokyo"}',
},
{
type: "message",
role: "assistant",
call_id: "call_123",
content: "sunny and warm",
} as any, // Using any to bypass the type restriction for this test
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
// Should show the function call and message as separate items (not grouped)
expect(screen.getByText("Function Call")).toBeInTheDocument();
expect(
screen.getByText('get_weather({"city": "Tokyo"})'),
).toBeInTheDocument();
expect(screen.getByText("Assistant")).toBeInTheDocument();
expect(screen.getByText("sunny and warm")).toBeInTheDocument();
// Should NOT have the grouped "Arguments" and "Output" labels
expect(screen.queryByText("Arguments")).not.toBeInTheDocument();
});
test("groups function call with function_call_output correctly", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "function_call",
call_id: "call_123",
status: "completed",
name: "get_weather",
arguments: '{"city": "Tokyo"}',
},
{
type: "function_call_output",
id: "fc_68364957013081...",
status: "completed",
call_id: "call_123",
output: "sunny and warm",
} as any, // Using any to bypass the type restriction for this test
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
// Should show the function call grouped with its clean output
expect(screen.getByText("Function Call")).toBeInTheDocument();
expect(screen.getByText("Arguments")).toBeInTheDocument();
expect(
screen.getByText('get_weather({"city": "Tokyo"})'),
).toBeInTheDocument();
// Use getAllByText since there are multiple "Output" elements (card title and output label)
const outputElements = screen.getAllByText("Output");
expect(outputElements.length).toBeGreaterThan(0);
expect(screen.getByText("sunny and warm")).toBeInTheDocument();
});
});
describe("Edge Cases and Error Handling", () => {
test("handles missing role in message input", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [],
input: [
{
type: "message",
content: "Message without role",
},
],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
expect(screen.getByText("Message without role")).toBeInTheDocument();
expect(screen.getByText("Unknown")).toBeInTheDocument(); // Default role
});
test("handles missing name in function call output", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "function_call",
id: "call_123",
status: "completed",
},
],
input: [],
};
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
// When name is missing, it falls back to JSON.stringify of the entire output
const functionCallElements = screen.getAllByText(/function_call/);
expect(functionCallElements.length).toBeGreaterThan(0);
expect(screen.getByText(/call_123/)).toBeInTheDocument();
});
});
});

View file

@@ -0,0 +1,171 @@
"use client";
import { OpenAIResponse, InputItemListResponse } from "@/lib/types";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";
import {
DetailLoadingView,
DetailErrorView,
DetailNotFoundView,
DetailLayout,
PropertiesCard,
PropertyItem,
} from "@/components/layout/detail-layout";
import { GroupedItemsDisplay } from "./grouping/grouped-items-display";
interface ResponseDetailViewProps {
response: OpenAIResponse | null;
inputItems: InputItemListResponse | null;
isLoading: boolean;
isLoadingInputItems: boolean;
error: Error | null;
inputItemsError: Error | null;
id: string;
}
export function ResponseDetailView({
response,
inputItems,
isLoading,
isLoadingInputItems,
error,
inputItemsError,
id,
}: ResponseDetailViewProps) {
const title = "Responses Details";
if (error) {
return <DetailErrorView title={title} id={id} error={error} />;
}
if (isLoading) {
return <DetailLoadingView title={title} />;
}
if (!response) {
return <DetailNotFoundView title={title} id={id} />;
}
// Main content cards
const mainContent = (
<>
<Card>
<CardHeader>
<CardTitle>Input</CardTitle>
</CardHeader>
<CardContent>
{/* Show loading state for input items */}
{isLoadingInputItems ? (
<div className="space-y-2">
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-3/4" />
<Skeleton className="h-4 w-1/2" />
</div>
) : inputItemsError ? (
<div className="text-red-500 text-sm">
Error loading input items: {inputItemsError.message}
<br />
<span className="text-gray-500 text-xs">
Falling back to response input data.
</span>
</div>
) : null}
{/* Display input items if available, otherwise fall back to response.input */}
{(() => {
const dataToDisplay =
inputItems?.data && inputItems.data.length > 0
? inputItems.data
: response.input;
if (dataToDisplay && dataToDisplay.length > 0) {
return (
<GroupedItemsDisplay
items={dataToDisplay}
keyPrefix="input"
defaultRole="unknown"
/>
);
} else {
return (
<p className="text-gray-500 italic text-sm">
No input data available.
</p>
);
}
})()}
</CardContent>
</Card>
<Card>
<CardHeader>
<CardTitle>Output</CardTitle>
</CardHeader>
<CardContent>
{response.output?.length > 0 ? (
<GroupedItemsDisplay
items={response.output}
keyPrefix="output"
defaultRole="assistant"
/>
) : (
<p className="text-gray-500 italic text-sm">
No output data available.
</p>
)}
</CardContent>
</Card>
</>
);
// Properties sidebar
const sidebar = (
<PropertiesCard>
<PropertyItem
label="Created"
value={new Date(response.created_at * 1000).toLocaleString()}
/>
<PropertyItem label="ID" value={response.id} />
<PropertyItem label="Model" value={response.model} />
<PropertyItem label="Status" value={response.status} hasBorder />
{response.temperature && (
<PropertyItem
label="Temperature"
value={response.temperature}
hasBorder
/>
)}
{response.top_p && <PropertyItem label="Top P" value={response.top_p} />}
{response.parallel_tool_calls && (
<PropertyItem
label="Parallel Tool Calls"
value={response.parallel_tool_calls ? "Yes" : "No"}
/>
)}
{response.previous_response_id && (
<PropertyItem
label="Previous Response ID"
value={
<span className="text-xs">{response.previous_response_id}</span>
}
hasBorder
/>
)}
{response.error && (
<PropertyItem
label="Error"
value={
<span className="text-red-900 font-medium">
{response.error.code}: {response.error.message}
</span>
}
className="pt-1 mt-1 border-t border-red-200"
/>
)}
</PropertiesCard>
);
return (
<DetailLayout title={title} mainContent={mainContent} sidebar={sidebar} />
);
}
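
As a final sketch, a page could wire the detail view up as follows. The static `response` literal mirrors the shape used in the tests above; the import path and the absence of real data fetching are assumptions for illustration.

```tsx
"use client";

import { OpenAIResponse } from "@/lib/types";
import { ResponseDetailView } from "@/components/responses/responses-detail"; // assumed path

// Static stand-in: in the real app this would come from the client API.
const response: OpenAIResponse = {
  id: "resp_123",
  object: "response",
  created_at: 1710000000,
  model: "llama-test-model",
  status: "completed",
  output: [{ type: "message", role: "assistant", content: "Hello" }],
  input: [{ type: "message", role: "user", content: "Hi" }],
};

export default function ResponseDetailPage() {
  return (
    <ResponseDetailView
      response={response}
      inputItems={null}
      isLoading={false}
      isLoadingInputItems={false}
      error={null}
      inputItemsError={null}
      id={response.id}
    />
  );
}
```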

View file

@@ -0,0 +1,537 @@
import React from "react";
import { render, screen, fireEvent } from "@testing-library/react";
import "@testing-library/jest-dom";
import { ResponsesTable } from "./responses-table";
import { OpenAIResponse } from "@/lib/types";
// Mock next/navigation
const mockPush = jest.fn();
jest.mock("next/navigation", () => ({
useRouter: () => ({
push: mockPush,
}),
}));
// Mock helper functions
jest.mock("@/lib/truncate-text");
// Import the mocked functions
import { truncateText as originalTruncateText } from "@/lib/truncate-text";
// Cast to jest.Mock for typings
const truncateText = originalTruncateText as jest.Mock;
describe("ResponsesTable", () => {
const defaultProps = {
data: [] as OpenAIResponse[],
isLoading: false,
error: null,
};
beforeEach(() => {
// Reset all mocks before each test
mockPush.mockClear();
truncateText.mockClear();
// Default pass-through implementation
truncateText.mockImplementation((text: string | undefined) => text);
});
test("renders without crashing with default props", () => {
render(<ResponsesTable {...defaultProps} />);
expect(screen.getByText("No responses found.")).toBeInTheDocument();
});
test("click on a row navigates to the correct URL", () => {
const mockResponse: OpenAIResponse = {
id: "resp_123",
object: "response",
created_at: Math.floor(Date.now() / 1000),
model: "llama-test-model",
status: "completed",
output: [
{
type: "message",
role: "assistant",
content: "Test output",
},
],
input: [
{
type: "message",
role: "user",
content: "Test input",
},
],
};
render(<ResponsesTable {...defaultProps} data={[mockResponse]} />);
const row = screen.getByText("Test input").closest("tr");
if (row) {
fireEvent.click(row);
expect(mockPush).toHaveBeenCalledWith("/logs/responses/resp_123");
} else {
throw new Error('Row with "Test input" not found for router mock test.');
}
});
describe("Loading State", () => {
test("renders skeleton UI when isLoading is true", () => {
const { container } = render(
<ResponsesTable {...defaultProps} isLoading={true} />,
);
// Check for skeleton in the table caption
const tableCaption = container.querySelector("caption");
expect(tableCaption).toBeInTheDocument();
if (tableCaption) {
const captionSkeleton = tableCaption.querySelector(
'[data-slot="skeleton"]',
);
expect(captionSkeleton).toBeInTheDocument();
}
// Check for skeletons in the table body cells
const tableBody = container.querySelector("tbody");
expect(tableBody).toBeInTheDocument();
if (tableBody) {
const bodySkeletons = tableBody.querySelectorAll(
'[data-slot="skeleton"]',
);
expect(bodySkeletons.length).toBeGreaterThan(0);
}
});
});
describe("Error State", () => {
test("renders error message when error prop is provided", () => {
const errorMessage = "Network Error";
render(
<ResponsesTable
{...defaultProps}
error={{ name: "Error", message: errorMessage }}
/>,
);
expect(
screen.getByText(`Error fetching data: ${errorMessage}`),
).toBeInTheDocument();
});
test("renders default error message when error.message is not available", () => {
render(
<ResponsesTable
{...defaultProps}
error={{ name: "Error", message: "" }}
/>,
);
expect(
screen.getByText("Error fetching data: An unknown error occurred"),
).toBeInTheDocument();
});
test("renders default error message when error prop is an object without message", () => {
render(<ResponsesTable {...defaultProps} error={{} as Error} />);
expect(
screen.getByText("Error fetching data: An unknown error occurred"),
).toBeInTheDocument();
});
});
describe("Empty State", () => {
test('renders "No responses found." and no table when data array is empty', () => {
render(<ResponsesTable data={[]} isLoading={false} error={null} />);
expect(screen.getByText("No responses found.")).toBeInTheDocument();
// Ensure that the table structure is NOT rendered in the empty state
const table = screen.queryByRole("table");
expect(table).not.toBeInTheDocument();
});
});
describe("Data Rendering", () => {
test("renders table caption, headers, and response data correctly", () => {
const mockResponses = [
{
id: "resp_1",
object: "response" as const,
created_at: 1710000000,
model: "llama-test-model",
status: "completed",
output: [
{
type: "message" as const,
role: "assistant" as const,
content: "Test output",
},
],
input: [
{
type: "message",
role: "user",
content: "Test input",
},
],
},
{
id: "resp_2",
object: "response" as const,
created_at: 1710001000,
model: "llama-another-model",
status: "completed",
output: [
{
type: "message" as const,
role: "assistant" as const,
content: "Another output",
},
],
input: [
{
type: "message",
role: "user",
content: "Another input",
},
],
},
];
render(
<ResponsesTable data={mockResponses} isLoading={false} error={null} />,
);
// Table caption
expect(
screen.getByText("A list of your recent responses."),
).toBeInTheDocument();
// Table headers
expect(screen.getByText("Input")).toBeInTheDocument();
expect(screen.getByText("Output")).toBeInTheDocument();
expect(screen.getByText("Model")).toBeInTheDocument();
expect(screen.getByText("Created")).toBeInTheDocument();
// Data rows
expect(screen.getByText("Test input")).toBeInTheDocument();
expect(screen.getByText("Test output")).toBeInTheDocument();
expect(screen.getByText("llama-test-model")).toBeInTheDocument();
expect(
screen.getByText(new Date(1710000000 * 1000).toLocaleString()),
).toBeInTheDocument();
expect(screen.getByText("Another input")).toBeInTheDocument();
expect(screen.getByText("Another output")).toBeInTheDocument();
expect(screen.getByText("llama-another-model")).toBeInTheDocument();
expect(
screen.getByText(new Date(1710001000 * 1000).toLocaleString()),
).toBeInTheDocument();
});
});
describe("Input Text Extraction", () => {
test("extracts text from string content", () => {
const mockResponse: OpenAIResponse = {
id: "resp_string",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [{ type: "message", role: "assistant", content: "output" }],
input: [
{
type: "message",
role: "user",
content: "Simple string input",
},
],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
expect(screen.getByText("Simple string input")).toBeInTheDocument();
});
test("extracts text from array content with input_text type", () => {
const mockResponse: OpenAIResponse = {
id: "resp_array",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [{ type: "message", role: "assistant", content: "output" }],
input: [
{
type: "message",
role: "user",
content: [
{ type: "input_text", text: "Array input text" },
{ type: "input_text", text: "Should not be used" },
],
},
],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
expect(screen.getByText("Array input text")).toBeInTheDocument();
});
test("returns empty string when no message input found", () => {
const mockResponse: OpenAIResponse = {
id: "resp_no_input",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [{ type: "message", role: "assistant", content: "output" }],
input: [
{
type: "other_type",
content: "Not a message",
},
],
};
const { container } = render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
// Find the input cell (first cell in the data row) and verify it's empty
const inputCell = container.querySelector("tbody tr td:first-child");
expect(inputCell).toBeInTheDocument();
expect(inputCell).toHaveTextContent("");
});
});
describe("Output Text Extraction", () => {
test("extracts text from string message content", () => {
const mockResponse: OpenAIResponse = {
id: "resp_string_output",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "message",
role: "assistant",
content: "Simple string output",
},
],
input: [{ type: "message", content: "input" }],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
expect(screen.getByText("Simple string output")).toBeInTheDocument();
});
test("extracts text from array message content with output_text type", () => {
const mockResponse: OpenAIResponse = {
id: "resp_array_output",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "message",
role: "assistant",
content: [
{ type: "output_text", text: "Array output text" },
{ type: "output_text", text: "Should not be used" },
],
},
],
input: [{ type: "message", content: "input" }],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
expect(screen.getByText("Array output text")).toBeInTheDocument();
});
test("formats function call output", () => {
const mockResponse: OpenAIResponse = {
id: "resp_function_call",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "function_call",
id: "call_123",
status: "completed",
name: "search_function",
arguments: '{"query": "test"}',
},
],
input: [{ type: "message", content: "input" }],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
expect(
screen.getByText('search_function({"query": "test"})'),
).toBeInTheDocument();
});
test("formats function call output without arguments", () => {
const mockResponse: OpenAIResponse = {
id: "resp_function_no_args",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "function_call",
id: "call_123",
status: "completed",
name: "simple_function",
},
],
input: [{ type: "message", content: "input" }],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
expect(screen.getByText("simple_function({})")).toBeInTheDocument();
});
test("formats web search call output", () => {
const mockResponse: OpenAIResponse = {
id: "resp_web_search",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "web_search_call",
id: "search_123",
status: "completed",
},
],
input: [{ type: "message", content: "input" }],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
expect(
screen.getByText("web_search_call(status: completed)"),
).toBeInTheDocument();
});
test("falls back to JSON.stringify for unknown tool call types", () => {
const mockResponse: OpenAIResponse = {
id: "resp_unknown_tool",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "unknown_call",
id: "unknown_123",
status: "completed",
custom_field: "custom_value",
} as any,
],
input: [{ type: "message", content: "input" }],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
// Should contain the JSON stringified version
expect(screen.getByText(/unknown_call/)).toBeInTheDocument();
});
test("falls back to JSON.stringify for entire output when no message or tool call found", () => {
const mockResponse: OpenAIResponse = {
id: "resp_fallback",
object: "response",
created_at: 1710000000,
model: "test-model",
status: "completed",
output: [
{
type: "unknown_type",
data: "some data",
} as any,
],
input: [{ type: "message", content: "input" }],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
// Should contain the JSON stringified version of the output array
expect(screen.getByText(/unknown_type/)).toBeInTheDocument();
});
});
describe("Text Truncation", () => {
test("truncates long input and output text", () => {
// Specific mock implementation for this test
truncateText.mockImplementation(
(text: string | undefined, maxLength?: number) => {
const defaultTestMaxLength = 10;
const effectiveMaxLength = maxLength ?? defaultTestMaxLength;
return typeof text === "string" && text.length > effectiveMaxLength
? text.slice(0, effectiveMaxLength) + "..."
: text;
},
);
const longInput =
"This is a very long input message that should be truncated.";
const longOutput =
"This is a very long output message that should also be truncated.";
const mockResponse: OpenAIResponse = {
id: "resp_trunc",
object: "response",
created_at: 1710002000,
model: "llama-trunc-model",
status: "completed",
output: [
{
type: "message",
role: "assistant",
content: longOutput,
},
],
input: [
{
type: "message",
role: "user",
content: longInput,
},
],
};
render(
<ResponsesTable data={[mockResponse]} isLoading={false} error={null} />,
);
// The truncated text should be present for both input and output
const truncatedTexts = screen.getAllByText(
longInput.slice(0, 10) + "...",
);
expect(truncatedTexts.length).toBe(2); // one for input, one for output
truncatedTexts.forEach((textElement) =>
expect(textElement).toBeInTheDocument(),
);
});
});
});

View file

@ -0,0 +1,117 @@
"use client";
import {
OpenAIResponse,
ResponseInput,
ResponseInputMessageContent,
} from "@/lib/types";
import { LogsTable, LogTableRow } from "@/components/logs/logs-table";
import {
isMessageInput,
isMessageItem,
isFunctionCallItem,
isWebSearchCallItem,
MessageItem,
FunctionCallItem,
WebSearchCallItem,
} from "./utils/item-types";
interface ResponsesTableProps {
data: OpenAIResponse[];
isLoading: boolean;
error: Error | null;
}
function getInputText(response: OpenAIResponse): string {
const firstInput = response.input.find(isMessageInput);
if (firstInput) {
return extractContentFromItem(firstInput);
}
return "";
}
function getOutputText(response: OpenAIResponse): string {
const firstMessage = response.output.find((item) =>
isMessageItem(item as any),
);
if (firstMessage) {
const content = extractContentFromItem(firstMessage as MessageItem);
if (content) {
return content;
}
}
const functionCall = response.output.find((item) =>
isFunctionCallItem(item as any),
);
if (functionCall) {
return formatFunctionCall(functionCall as FunctionCallItem);
}
const webSearchCall = response.output.find((item) =>
isWebSearchCallItem(item as any),
);
if (webSearchCall) {
return formatWebSearchCall(webSearchCall as WebSearchCallItem);
}
return JSON.stringify(response.output);
}
function extractContentFromItem(item: {
content?: string | ResponseInputMessageContent[];
}): string {
if (!item.content) {
return "";
}
if (typeof item.content === "string") {
return item.content;
} else if (Array.isArray(item.content)) {
const textContent = item.content.find(
(c: ResponseInputMessageContent) =>
c.type === "input_text" || c.type === "output_text",
);
return textContent?.text || "";
}
return "";
}
function formatFunctionCall(functionCall: FunctionCallItem): string {
const args = functionCall.arguments || "{}";
const name = functionCall.name || "unknown";
return `${name}(${args})`;
}
function formatWebSearchCall(webSearchCall: WebSearchCallItem): string {
return `web_search_call(status: ${webSearchCall.status})`;
}
function formatResponseToRow(response: OpenAIResponse): LogTableRow {
return {
id: response.id,
input: getInputText(response),
output: getOutputText(response),
model: response.model,
createdTime: new Date(response.created_at * 1000).toLocaleString(),
detailPath: `/logs/responses/${response.id}`,
};
}
export function ResponsesTable({
data,
isLoading,
error,
}: ResponsesTableProps) {
const formattedData = data.map(formatResponseToRow);
return (
<LogsTable
data={formattedData}
isLoading={isLoading}
error={error}
caption="A list of your recent responses."
emptyMessage="No responses found."
/>
);
}
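
Note (not part of the diff): a minimal sketch of what `formatResponseToRow` above would produce for a function-call response, using only the helpers defined in this file; the `createdTime` value is locale-dependent.

```typescript
// Hedged sketch: a sample response and the row formatResponseToRow would
// derive from it, based only on the helpers defined in this file.
import { OpenAIResponse } from "@/lib/types";

const sample: OpenAIResponse = {
  id: "resp_abc",
  object: "response",
  created_at: 1710000000,
  model: "test-model",
  status: "completed",
  input: [{ type: "message", role: "user", content: "What's the weather?" }],
  output: [
    {
      type: "function_call",
      id: "call_1",
      status: "completed",
      name: "get_weather",
      arguments: '{"city": "Tokyo"}',
    },
  ],
};

// formatResponseToRow(sample) would yield approximately:
// {
//   id: "resp_abc",
//   input: "What's the weather?",
//   output: 'get_weather({"city": "Tokyo"})',
//   model: "test-model",
//   createdTime: "<locale-formatted date for 1710000000>",
//   detailPath: "/logs/responses/resp_abc",
// }
```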

View file

@ -0,0 +1,61 @@
/**
* Type guards for different item types in responses
*/
import type {
ResponseInput,
ResponseOutput,
ResponseMessage,
ResponseToolCall,
} from "@/lib/types";
export interface BaseItem {
type: string;
[key: string]: unknown;
}
export type MessageItem = ResponseMessage;
export type FunctionCallItem = ResponseToolCall & { type: "function_call" };
export type WebSearchCallItem = ResponseToolCall & { type: "web_search_call" };
export type FunctionCallOutputItem = BaseItem & {
type: "function_call_output";
call_id: string;
output?: string | object;
};
export type AnyResponseItem =
| ResponseInput
| ResponseOutput
| FunctionCallOutputItem;
export function isMessageInput(
item: ResponseInput,
): item is ResponseInput & { type: "message" } {
return item.type === "message";
}
export function isMessageItem(item: AnyResponseItem): item is MessageItem {
return item.type === "message" && "content" in item;
}
export function isFunctionCallItem(
item: AnyResponseItem,
): item is FunctionCallItem {
return item.type === "function_call" && "name" in item;
}
export function isWebSearchCallItem(
item: AnyResponseItem,
): item is WebSearchCallItem {
return item.type === "web_search_call";
}
export function isFunctionCallOutputItem(
item: AnyResponseItem,
): item is FunctionCallOutputItem {
return (
item.type === "function_call_output" &&
"call_id" in item &&
typeof (item as any).call_id === "string"
);
}
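
Note (not part of the diff): a short usage sketch showing how these guards narrow a mixed item list; the `summarize` helper is hypothetical.

```typescript
// Hypothetical helper demonstrating type narrowing with the guards above.
import {
  AnyResponseItem,
  isFunctionCallItem,
  isMessageItem,
} from "./item-types";

function summarize(items: AnyResponseItem[]): string[] {
  return items.map((item) => {
    if (isMessageItem(item)) {
      // item is narrowed to MessageItem (ResponseMessage) here
      return typeof item.content === "string" ? item.content : "[content parts]";
    }
    if (isFunctionCallItem(item)) {
      // item is narrowed to FunctionCallItem here
      return `${item.name ?? "unknown"}(${item.arguments ?? "{}"})`;
    }
    return item.type;
  });
}
```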

View file

@ -0,0 +1,49 @@
import React from "react";
export interface MessageBlockProps {
label: string;
labelDetail?: string;
content: React.ReactNode;
className?: string;
contentClassName?: string;
}
export const MessageBlock: React.FC<MessageBlockProps> = ({
label,
labelDetail,
content,
className = "",
contentClassName = "",
}) => {
return (
<div className={`mb-4 ${className}`}>
<p className="py-1 font-semibold text-gray-800 mb-1">
{label}
{labelDetail && (
<span className="text-xs text-gray-500 font-normal ml-1">
{labelDetail}
</span>
)}
</p>
<div className={`py-1 whitespace-pre-wrap ${contentClassName}`}>
{content}
</div>
</div>
);
};
export interface ToolCallBlockProps {
children: React.ReactNode;
className?: string;
}
export const ToolCallBlock = ({ children, className }: ToolCallBlockProps) => {
const baseClassName =
"p-3 bg-slate-50 border border-slate-200 rounded-md text-sm";
return (
<div className={`${baseClassName} ${className || ""}`}>
<pre className="whitespace-pre-wrap text-xs">{children}</pre>
</div>
);
};
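
Note (not part of the diff): a hedged rendering sketch combining the two blocks; the import path is an assumption, since the file name is not visible in this view.

```tsx
// Hypothetical usage: a user message followed by an assistant tool call.
import React from "react";
import { MessageBlock, ToolCallBlock } from "./message-block"; // path is an assumption

export function ExampleTurn() {
  return (
    <>
      <MessageBlock label="User" content="What's the weather in Tokyo?" />
      <MessageBlock
        label="Assistant"
        labelDetail="(tool call)"
        content={
          <ToolCallBlock>{'get_weather({"city": "Tokyo"})'}</ToolCallBlock>
        }
      />
    </>
  );
}
```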

View file

@ -0,0 +1,12 @@
import LlamaStackClient from "llama-stack-client";
import OpenAI from "openai";
export const client =
process.env.NEXT_PUBLIC_USE_OPENAI_CLIENT === "true" // useful for testing
? new OpenAI({
apiKey: process.env.NEXT_PUBLIC_OPENAI_API_KEY,
dangerouslyAllowBrowser: true,
})
: new LlamaStackClient({
baseURL: process.env.NEXT_PUBLIC_LLAMA_STACK_BASE_URL,
});

View file

@ -43,10 +43,14 @@ export function extractDisplayableText(
return "";
}
let textPart = extractTextFromContentPart(message.content);
const textPart = extractTextFromContentPart(message.content);
let toolCallPart = "";
if (message.tool_calls && message.tool_calls.length > 0) {
if (
message.tool_calls &&
Array.isArray(message.tool_calls) &&
message.tool_calls.length > 0
) {
// For summary, usually the first tool call is sufficient
toolCallPart = formatToolCallToString(message.tool_calls[0]);
}

View file

@ -18,20 +18,20 @@ export interface ImageUrlContentBlock {
export type ChatMessageContentPart =
| TextContentBlock
| ImageUrlContentBlock
| { type: string; [key: string]: any }; // Fallback for other potential types
| { type: string; [key: string]: unknown }; // Fallback for other potential types
export interface ChatMessage {
role: string;
content: string | ChatMessageContentPart[]; // Updated content type
name?: string | null;
tool_calls?: any | null; // This could also be refined to a more specific ToolCall[] type
tool_calls?: unknown | null; // This could also be refined to a more specific ToolCall[] type
}
export interface Choice {
message: ChatMessage;
finish_reason: string;
index: number;
logprobs?: any | null;
logprobs?: unknown | null;
}
export interface ChatCompletion {
@ -42,3 +42,62 @@ export interface ChatCompletion {
model: string;
input_messages: ChatMessage[];
}
// Response types for OpenAI Responses API
export interface ResponseInputMessageContent {
text?: string;
type: "input_text" | "input_image" | "output_text";
image_url?: string;
detail?: "low" | "high" | "auto";
}
export interface ResponseMessage {
content: string | ResponseInputMessageContent[];
role: "system" | "developer" | "user" | "assistant";
type: "message";
id?: string;
status?: string;
}
export interface ResponseToolCall {
id: string;
status: string;
type: "web_search_call" | "function_call";
arguments?: string;
call_id?: string;
name?: string;
}
export type ResponseOutput = ResponseMessage | ResponseToolCall;
export interface ResponseInput {
type: string;
content?: string | ResponseInputMessageContent[];
role?: string;
[key: string]: unknown; // Flexible for various input types
}
export interface OpenAIResponse {
id: string;
created_at: number;
model: string;
object: "response";
status: string;
output: ResponseOutput[];
input: ResponseInput[];
error?: {
code: string;
message: string;
};
parallel_tool_calls?: boolean;
previous_response_id?: string;
temperature?: number;
top_p?: number;
truncation?: string;
user?: string;
}
export interface InputItemListResponse {
data: ResponseInput[];
object: "list";
}
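
Note (not part of the diff): a brief sketch of pulling display text out of an `InputItemListResponse` using only the types defined above; the `inputTexts` helper is hypothetical.

```typescript
// Hypothetical helper: extract the first input_text from each input item.
import {
  InputItemListResponse,
  ResponseInputMessageContent,
} from "@/lib/types";

function inputTexts(list: InputItemListResponse): string[] {
  return list.data.map((item) => {
    if (typeof item.content === "string") {
      return item.content;
    }
    const textPart = (item.content ?? []).find(
      (c: ResponseInputMessageContent) => c.type === "input_text",
    );
    return textPart?.text ?? "";
  });
}
```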

View file

@ -19,6 +19,7 @@
"lucide-react": "^0.510.0",
"next": "15.3.2",
"next-themes": "^0.4.6",
"openai": "^4.103.0",
"react": "^19.0.0",
"react-dom": "^19.0.0",
"tailwind-merge": "^3.3.0"
@ -9092,7 +9093,7 @@
},
"node_modules/llama-stack-client": {
"version": "0.0.1-alpha.0",
"resolved": "git+ssh://git@github.com/stainless-sdks/llama-stack-node.git#efa814980d44b3b2c92944377a086915137b2134",
"resolved": "git+ssh://git@github.com/stainless-sdks/llama-stack-node.git#5d34d229fb53b6dad02da0f19f4b310b529c6b15",
"license": "Apache-2.0",
"dependencies": {
"@types/node": "^18.11.18",
@ -9804,6 +9805,51 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/openai": {
"version": "4.103.0",
"resolved": "https://registry.npmjs.org/openai/-/openai-4.103.0.tgz",
"integrity": "sha512-eWcz9kdurkGOFDtd5ySS5y251H2uBgq9+1a2lTBnjMMzlexJ40Am5t6Mu76SSE87VvitPa0dkIAp75F+dZVC0g==",
"license": "Apache-2.0",
"dependencies": {
"@types/node": "^18.11.18",
"@types/node-fetch": "^2.6.4",
"abort-controller": "^3.0.0",
"agentkeepalive": "^4.2.1",
"form-data-encoder": "1.7.2",
"formdata-node": "^4.3.2",
"node-fetch": "^2.6.7"
},
"bin": {
"openai": "bin/cli"
},
"peerDependencies": {
"ws": "^8.18.0",
"zod": "^3.23.8"
},
"peerDependenciesMeta": {
"ws": {
"optional": true
},
"zod": {
"optional": true
}
}
},
"node_modules/openai/node_modules/@types/node": {
"version": "18.19.103",
"resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.103.tgz",
"integrity": "sha512-hHTHp+sEz6SxFsp+SA+Tqrua3AbmlAw+Y//aEwdHrdZkYVRWdvWD3y5uPZ0flYOkgskaFWqZ/YGFm3FaFQ0pRw==",
"license": "MIT",
"dependencies": {
"undici-types": "~5.26.4"
}
},
"node_modules/openai/node_modules/undici-types": {
"version": "5.26.5",
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
"integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==",
"license": "MIT"
},
"node_modules/optionator": {
"version": "0.9.4",
"resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz",
@ -12223,7 +12269,7 @@
"version": "8.18.2",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.18.2.tgz",
"integrity": "sha512-DMricUmwGZUVr++AEAe2uiVM7UoO9MAVZMDu05UQOaUII0lp+zOzLLU4Xqh/JvTqklB1T4uELaaPBKyjE1r4fQ==",
"dev": true,
"devOptional": true,
"license": "MIT",
"engines": {
"node": ">=10.0.0"
@ -12334,7 +12380,7 @@
"version": "3.24.4",
"resolved": "https://registry.npmjs.org/zod/-/zod-3.24.4.tgz",
"integrity": "sha512-OdqJE9UDRPwWsrHjLN2F8bPxvwJBK22EHLWtanu0LSYr5YqzsaaW3RMgmjwr8Rypg5k+meEJdSPXJZXE/yqOMg==",
"dev": true,
"devOptional": true,
"license": "MIT",
"funding": {
"url": "https://github.com/sponsors/colinhacks"

View file

@ -19,7 +19,7 @@
"@radix-ui/react-tooltip": "^1.2.6",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"llama-stack-client": "github:stainless-sdks/llama-stack-node#ehhuang/dev",
"llama-stack-client": "0.2.8",
"lucide-react": "^0.510.0",
"next": "15.3.2",
"next-themes": "^0.4.6",

View file

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "llama_stack"
version = "0.2.7"
version = "0.2.8"
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
description = "Llama Stack"
readme = "README.md"
@ -21,13 +21,13 @@ classifiers = [
"Topic :: Scientific/Engineering :: Information Analysis",
]
dependencies = [
"blobfile",
"aiohttp",
"fire",
"httpx",
"huggingface-hub",
"jinja2>=3.1.6",
"jsonschema",
"llama-stack-client>=0.2.7",
"llama-stack-client>=0.2.8",
"openai>=1.66",
"prompt-toolkit",
"python-dotenv",
@ -36,6 +36,7 @@ dependencies = [
"requests",
"rich",
"setuptools",
"starlette",
"termcolor",
"tiktoken",
"pillow",
@ -43,6 +44,14 @@ dependencies = [
]
[project.optional-dependencies]
ui = [
"streamlit",
"pandas",
"llama-stack-client>=0.2.8",
"streamlit-option-menu",
]
[dependency-groups]
dev = [
"pytest",
"pytest-timeout",
@ -73,10 +82,11 @@ unit = [
"opentelemetry-exporter-otlp-proto-http",
"sqlalchemy",
"sqlalchemy[asyncio]>=2.0.41",
"blobfile",
]
# These are the core dependencies required for running integration tests. They are shared across all
# providers. If a provider requires additional dependencies, please add them to your environment
# separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra
# separately. If you are using "uv" to execute your tests, you can use the "--group" flag to specify extra
# dependencies.
test = [
"openai",
@ -112,12 +122,6 @@ docs = [
"sphinxcontrib.openapi",
]
codegen = ["rich", "pydantic", "jinja2>=3.1.6"]
ui = [
"streamlit",
"pandas",
"llama-stack-client>=0.2.7",
"streamlit-option-menu",
]
[project.urls]
Homepage = "https://github.com/meta-llama/llama-stack"
@ -138,7 +142,6 @@ explicit = true
[tool.uv.sources]
torch = [{ index = "pytorch-cpu" }]
torchvision = [{ index = "pytorch-cpu" }]
llama-stack = { workspace = true }
[tool.ruff]
line-length = 120
@ -332,10 +335,5 @@ init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true
[dependency-groups]
dev = [
"llama-stack",
]
[tool.ruff.lint.pep8-naming]
classmethod-decorators = ["classmethod", "pydantic.field_validator"]

View file

@ -1,63 +1,203 @@
# This file was autogenerated by uv via the following command:
# uv export --frozen --no-hashes --no-emit-project --output-file=requirements.txt
# uv export --frozen --no-hashes --no-emit-project --no-default-groups --output-file=requirements.txt
aiohappyeyeballs==2.5.0
# via aiohttp
aiohttp==3.11.13
# via llama-stack
aiosignal==1.3.2
# via aiohttp
annotated-types==0.7.0
# via pydantic
anyio==4.8.0
# via
# httpx
# llama-stack-client
# openai
# starlette
async-timeout==5.0.1 ; python_full_version < '3.11'
# via aiohttp
attrs==25.1.0
blobfile==3.0.0
# via
# aiohttp
# jsonschema
# referencing
certifi==2025.1.31
# via
# httpcore
# httpx
# requests
charset-normalizer==3.4.1
# via requests
click==8.1.8
# via llama-stack-client
colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
distro==1.9.0
# via
# llama-stack-client
# openai
ecdsa==0.19.1
# via python-jose
exceptiongroup==1.2.2 ; python_full_version < '3.11'
# via anyio
filelock==3.17.0
# via huggingface-hub
fire==0.7.0
# via llama-stack
frozenlist==1.5.0
# via
# aiohttp
# aiosignal
fsspec==2024.12.0
# via huggingface-hub
h11==0.16.0
# via
# httpcore
# llama-stack
httpcore==1.0.9
# via httpx
httpx==0.28.1
# via
# llama-stack
# llama-stack-client
# openai
huggingface-hub==0.29.0
# via llama-stack
idna==3.10
# via
# anyio
# httpx
# requests
# yarl
jinja2==3.1.6
# via llama-stack
jiter==0.8.2
# via openai
jsonschema==4.23.0
# via llama-stack
jsonschema-specifications==2024.10.1
llama-stack-client==0.2.7
lxml==5.3.1
# via jsonschema
llama-stack-client==0.2.8
# via llama-stack
markdown-it-py==3.0.0
# via rich
markupsafe==3.0.2
# via jinja2
mdurl==0.1.2
# via markdown-it-py
multidict==6.1.0
# via
# aiohttp
# yarl
numpy==2.2.3
# via pandas
openai==1.71.0
# via llama-stack
packaging==24.2
# via huggingface-hub
pandas==2.2.3
# via llama-stack-client
pillow==11.1.0
# via llama-stack
prompt-toolkit==3.0.50
# via
# llama-stack
# llama-stack-client
propcache==0.3.0
# via
# aiohttp
# yarl
pyaml==25.1.0
# via llama-stack-client
pyasn1==0.4.8
pycryptodomex==3.21.0
# via
# python-jose
# rsa
pydantic==2.10.6
# via
# llama-stack
# llama-stack-client
# openai
pydantic-core==2.27.2
# via pydantic
pygments==2.19.1
# via rich
python-dateutil==2.9.0.post0
# via pandas
python-dotenv==1.0.1
# via llama-stack
python-jose==3.4.0
# via llama-stack
pytz==2025.1
# via pandas
pyyaml==6.0.2
# via
# huggingface-hub
# pyaml
referencing==0.36.2
# via
# jsonschema
# jsonschema-specifications
regex==2024.11.6
# via tiktoken
requests==2.32.3
# via
# huggingface-hub
# llama-stack
# tiktoken
rich==13.9.4
# via
# llama-stack
# llama-stack-client
rpds-py==0.22.3
# via
# jsonschema
# referencing
rsa==4.9
# via python-jose
setuptools==80.8.0
# via llama-stack
six==1.17.0
# via
# ecdsa
# python-dateutil
sniffio==1.3.1
# via
# anyio
# llama-stack-client
# openai
starlette==0.45.3
# via llama-stack
termcolor==2.5.0
# via
# fire
# llama-stack
# llama-stack-client
tiktoken==0.9.0
# via llama-stack
tqdm==4.67.1
# via
# huggingface-hub
# llama-stack-client
# openai
typing-extensions==4.12.2
# via
# anyio
# huggingface-hub
# llama-stack-client
# multidict
# openai
# pydantic
# pydantic-core
# referencing
# rich
tzdata==2025.1
# via pandas
urllib3==2.3.0
# via requests
wcwidth==0.2.13
# via prompt-toolkit
yarl==1.18.3
# via aiohttp

View file

@ -7,7 +7,6 @@
import concurrent.futures
import importlib
import json
import subprocess
import sys
from collections.abc import Iterable
@ -108,21 +107,6 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[
return None, []
def generate_dependencies_file(change_tracker: ChangedPathTracker):
templates_dir = REPO_ROOT / "llama_stack" / "templates"
distribution_deps = {}
for template_dir in find_template_dirs(templates_dir):
name, deps = collect_template_dependencies(template_dir)
if name:
distribution_deps[name] = deps
deps_file = REPO_ROOT / "llama_stack" / "templates" / "dependencies.json"
change_tracker.add_paths(deps_file)
with open(deps_file, "w") as f:
f.write(json.dumps(distribution_deps, indent=2) + "\n")
def main():
templates_dir = REPO_ROOT / "llama_stack" / "templates"
change_tracker = ChangedPathTracker()
@ -143,8 +127,6 @@ def main():
list(executor.map(process_func, template_dirs))
progress.update(task, advance=len(template_dirs))
generate_dependencies_file(change_tracker)
if check_for_changes(change_tracker):
print(
"Distribution template changes detected. Please commit the changes.",

View file

@ -10,10 +10,10 @@ PYTHON_VERSION=${PYTHON_VERSION:-3.10}
command -v uv >/dev/null 2>&1 || { echo >&2 "uv is required but it's not installed. Exiting."; exit 1; }
uv python find $PYTHON_VERSION
uv python find "$PYTHON_VERSION"
FOUND_PYTHON=$?
if [ $FOUND_PYTHON -ne 0 ]; then
uv python install $PYTHON_VERSION
uv python install "$PYTHON_VERSION"
fi
uv run --python $PYTHON_VERSION --with-editable . --with-editable ".[dev]" --with-editable ".[unit]" pytest --asyncio-mode=auto -s -v tests/unit/ $@
uv run --python "$PYTHON_VERSION" --with-editable . --group unit pytest --asyncio-mode=auto -s -v tests/unit/ "$@"

View file

@ -6,7 +6,6 @@ dependencies = [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",

View file

@ -41,7 +41,6 @@ def openai_client(client_with_models):
],
],
)
@pytest.mark.skip(reason="Very flaky, sometimes there is a message not a function call, standard tool calling issues")
def test_responses_store(openai_client, client_with_models, text_model_id, stream, tools):
if isinstance(client_with_models, LlamaStackAsLibraryClient):
pytest.skip("OpenAI responses are not supported when testing with library client yet.")
@ -68,13 +67,15 @@ def test_responses_store(openai_client, client_with_models, text_model_id, strea
for chunk in response:
if response_id is None:
response_id = chunk.response.id
if not tools:
if chunk.type == "response.completed":
response_id = chunk.response.id
output_type = chunk.response.output[0].type
if output_type == "message":
content = chunk.response.output[0].content[0].text
else:
response_id = response.id
if not tools:
output_type = response.output[0].type
if output_type == "message":
content = response.output[0].content[0].text
# list responses - use the underlying HTTP client for endpoints not in SDK
@ -87,9 +88,8 @@ def test_responses_store(openai_client, client_with_models, text_model_id, strea
retrieved_response = client.responses.retrieve(response_id)
assert retrieved_response.id == response_id
assert retrieved_response.model == text_model_id
if tools:
assert retrieved_response.output[0].type == "function_call"
else:
assert retrieved_response.output[0].type == output_type, retrieved_response
if output_type == "message":
assert retrieved_response.output[0].content[0].text == content

View file

@ -224,6 +224,43 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
assert expected.lower() in "".join(streamed_content)
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:streaming_01",
"inference:chat_completion:streaming_02",
],
)
def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
provider = provider_from_model(client_with_models, text_model_id)
if provider.provider_type == "remote::ollama":
pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.")
tc = TestCase(test_case)
question = tc["question"]
expected = tc["expected"]
response = compat_client.chat.completions.create(
model=text_model_id,
messages=[{"role": "user", "content": question}],
stream=True,
timeout=120,  # Increase timeout to 2 minutes for large conversation history
n=2,
)
streamed_content = {}
for chunk in response:
for choice in chunk.choices:
if choice.delta.content:
streamed_content[choice.index] = (
streamed_content.get(choice.index, "") + choice.delta.content.lower().strip()
)
assert len(streamed_content) == 2
for i, content in streamed_content.items():
assert expected.lower() in content, f"Choice {i}: Expected {expected.lower()} in {content}"
@pytest.mark.parametrize(
"stream",
[
@ -253,6 +290,7 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
for chunk in response:
if response_id is None:
response_id = chunk.id
if chunk.choices[0].delta.content:
content += chunk.choices[0].delta.content
else:
response_id = response.id
@ -263,8 +301,8 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
retrieved_response = client.chat.completions.retrieve(response_id)
assert retrieved_response.id == response_id
assert retrieved_response.input_messages[0]["content"] == message
assert retrieved_response.choices[0].message.content == content
assert retrieved_response.input_messages[0]["content"] == message, retrieved_response
assert retrieved_response.choices[0].message.content == content, retrieved_response
@pytest.mark.parametrize(
@ -274,7 +312,6 @@ def test_inference_store(openai_client, client_with_models, text_model_id, strea
False,
],
)
@pytest.mark.skip(reason="Very flaky, tool calling really wacky on CI")
def test_inference_store_tool_calls(openai_client, client_with_models, text_model_id, stream):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
client = openai_client
@ -312,7 +349,9 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
for chunk in response:
if response_id is None:
response_id = chunk.id
content += chunk.choices[0].delta.content
if delta := chunk.choices[0].delta:
if delta.content:
content += delta.content
else:
response_id = response.id
content = response.choices[0].message.content
@ -323,5 +362,11 @@ def test_inference_store_tool_calls(openai_client, client_with_models, text_mode
retrieved_response = client.chat.completions.retrieve(response_id)
assert retrieved_response.id == response_id
assert retrieved_response.input_messages[0]["content"] == message
assert retrieved_response.choices[0].message.tool_calls[0].function.name == "get_weather"
assert retrieved_response.choices[0].message.tool_calls[0].function.arguments == '{"city":"Tokyo"}'
tool_calls = retrieved_response.choices[0].message.tool_calls
# sometimes the model doesn't output tool calls; when it doesn't, verify the stored content instead
if tool_calls:
assert len(tool_calls) == 1
assert tool_calls[0].function.name == "get_weather"
assert "tokyo" in tool_calls[0].function.arguments.lower()
else:
assert retrieved_response.choices[0].message.content == content

View file

@ -25,10 +25,12 @@ def test_web_search_tool(llama_stack_client, sample_search_query):
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
tools = llama_stack_client.tool_runtime.list_tools()
assert any(tool.identifier == "web_search" for tool in tools)
response = llama_stack_client.tool_runtime.invoke_tool(
tool_name="web_search", kwargs={"query": sample_search_query}
)
# Verify the response
assert response.content is not None
assert len(response.content) > 0
@ -49,11 +51,12 @@ def test_wolfram_alpha_tool(llama_stack_client, sample_wolfram_alpha_query):
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
tools = llama_stack_client.tool_runtime.list_tools()
assert any(tool.identifier == "wolfram_alpha" for tool in tools)
response = llama_stack_client.tool_runtime.invoke_tool(
tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
)
print(response.content)
assert response.content is not None
assert len(response.content) > 0
assert isinstance(response.content, str)

View file

@ -31,8 +31,7 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
test_toolgroup_id = MCP_TOOLGROUP_ID
uri = mcp_server["server_url"]
# registering itself should fail since it requires listing tools
with pytest.raises(Exception, match="Unauthorized"):
# registering should not raise an error anymore even if you don't specify the auth token
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id="model-context-protocol",
@ -41,27 +40,18 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
provider_data = {
"mcp_headers": {
uri: [
f"Authorization: Bearer {AUTH_TOKEN}",
],
uri: {
"Authorization": f"Bearer {AUTH_TOKEN}",
},
},
}
auth_headers = {
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
try:
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers)
except Exception as e:
# An error is OK since the toolgroup may not exist
print(f"Error unregistering toolgroup: {e}")
with pytest.raises(Exception, match="Unauthorized"):
llama_stack_client.tools.list()
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=dict(uri=uri),
extra_headers=auth_headers,
)
response = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id,
extra_headers=auth_headers,

View file

@ -51,7 +51,5 @@ def test_register_and_unregister_toolgroup(llama_stack_client):
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
# Verify tools are also unregistered
unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
assert isinstance(unregister_tools_list_response, list)
assert not unregister_tools_list_response
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)

View file

@ -15,7 +15,7 @@ from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataS
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models.models import Model, ModelType
from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
@ -101,11 +101,11 @@ class ToolGroupsImpl(Impl):
def __init__(self):
super().__init__(Api.tool_runtime)
async def register_tool(self, tool):
return tool
async def register_toolgroup(self, toolgroup: ToolGroup):
return toolgroup
async def unregister_tool(self, tool_name: str):
return tool_name
async def unregister_toolgroup(self, toolgroup_id: str):
return toolgroup_id
async def list_runtime_tools(self, toolgroup_id, mcp_endpoint):
return ListToolDefsResponse(

View file

@ -232,9 +232,17 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
# Check that we got the content from our mocked tool execution result
chunks = [chunk async for chunk in result]
assert len(chunks) > 0
assert chunks[0].response.output[0].type == "function_call"
assert chunks[0].response.output[0].name == "get_weather"
assert len(chunks) == 2 # Should have response.created and response.completed
# Check response.created event (should have empty output)
assert chunks[0].type == "response.created"
assert len(chunks[0].response.output) == 0
# Check response.completed event (should have the tool call)
assert chunks[1].type == "response.completed"
assert len(chunks[1].response.output) == 1
assert chunks[1].response.output[0].type == "function_call"
assert chunks[1].response.output[0].name == "get_weather"
@pytest.mark.asyncio
@ -620,3 +628,69 @@ async def test_responses_store_list_input_items_logic():
result = await responses_store.list_response_input_items("resp_123", limit=0, order=Order.asc)
assert result.object == "list"
assert len(result.data) == 0 # Should return no items
@pytest.mark.asyncio
async def test_store_response_uses_rehydrated_input_with_previous_response(
openai_responses_impl, mock_responses_store, mock_inference_api
):
"""Test that _store_response uses the full re-hydrated input (including previous responses)
rather than just the original input when previous_response_id is provided."""
# Setup - Create a previous response that should be included in the stored input
previous_response = OpenAIResponseObjectWithInput(
id="resp-previous-123",
object="response",
created_at=1234567890,
model="meta-llama/Llama-3.1-8B-Instruct",
status="completed",
input=[
OpenAIResponseMessage(
id="msg-prev-user", role="user", content=[OpenAIResponseInputMessageContentText(text="What is 2+2?")]
)
],
output=[
OpenAIResponseMessage(
id="msg-prev-assistant",
role="assistant",
content=[OpenAIResponseOutputMessageContentOutputText(text="2+2 equals 4.")],
)
],
)
mock_responses_store.get_response_object.return_value = previous_response
current_input = "Now what is 3+3?"
model = "meta-llama/Llama-3.1-8B-Instruct"
mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
# Execute - Create response with previous_response_id
result = await openai_responses_impl.create_openai_response(
input=current_input,
model=model,
previous_response_id="resp-previous-123",
store=True,
)
store_call_args = mock_responses_store.store_response_object.call_args
stored_input = store_call_args.kwargs["input"]
# Verify that the stored input contains the full re-hydrated conversation:
# 1. Previous user message
# 2. Previous assistant response
# 3. Current user message
assert len(stored_input) == 3
assert stored_input[0].role == "user"
assert stored_input[0].content[0].text == "What is 2+2?"
assert stored_input[1].role == "assistant"
assert stored_input[1].content[0].text == "2+2 equals 4."
assert stored_input[2].role == "user"
assert stored_input[2].content == "Now what is 3+3?"
# Verify the response itself is correct
assert result.model == model
assert result.status == "completed"

View file

@ -27,7 +27,7 @@ export TOGETHER_API_KEY=<your_together_api_key>
```
then run
```bash
uv run --with-editable ".[dev]" python tests/verifications/generate_report.py --run-tests
uv run python tests/verifications/generate_report.py --run-tests
```
## Running Tests

View file

@ -10,17 +10,17 @@ from tests.verifications.openai_api.fixtures.fixtures import _load_all_verificat
def pytest_generate_tests(metafunc):
"""Dynamically parametrize tests based on the selected provider and config."""
if "model" in metafunc.fixturenames:
model = metafunc.config.getoption("model")
if model:
metafunc.parametrize("model", [model])
return
provider = metafunc.config.getoption("provider")
if not provider:
print("Warning: --provider not specified. Skipping model parametrization.")
metafunc.parametrize("model", [])
return
model = metafunc.config.getoption("model")
if model:
metafunc.parametrize("model", [model])
return
try:
config_data = _load_all_verification_configs()
except (OSError, FileNotFoundError) as e:

View file

@ -77,11 +77,12 @@ test_response_image:
image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
output: "llama"
# the models are really poor at tool calling after seeing images :/
test_response_multi_turn_image:
test_name: test_response_multi_turn_image
test_params:
case:
- case_id: "llama_image_search"
- case_id: "llama_image_understanding"
turns:
- input:
- role: user
@ -91,7 +92,5 @@ test_response_multi_turn_image:
- type: input_image
image_url: "https://upload.wikimedia.org/wikipedia/commons/f/f7/Llamas%2C_Vernagt-Stausee%2C_Italy.jpg"
output: "llama"
- input: "Search the web using the search tool for the animal from the previous response. Your search query should be a single phrase that includes the animal's name and the words 'maverick', 'scout' and 'llm'"
tools:
- type: web_search
output: "model"
- input: "What country do you find this animal primarily in? What continent?"
output: "peru"

Some files were not shown because too many files have changed in this diff.