mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Merge branch 'main' into HuggingfacePostTrainingConfig-branch
This commit is contained in:
commit
5ce2f00650
188 changed files with 15561 additions and 8453 deletions
9
llama_stack/apis/batches/__init__.py
Normal file
9
llama_stack/apis/batches/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .batches import Batches, BatchObject, ListBatchesResponse
|
||||
|
||||
__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]
|
89
llama_stack/apis/batches/batches.py
Normal file
89
llama_stack/apis/batches/batches.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
try:
|
||||
from openai.types import Batch as BatchObject
|
||||
except ImportError as e:
|
||||
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListBatchesResponse(BaseModel):
|
||||
"""Response containing a list of batch objects."""
|
||||
|
||||
object: Literal["list"] = "list"
|
||||
data: list[BatchObject] = Field(..., description="List of batch objects")
|
||||
first_id: str | None = Field(default=None, description="ID of the first batch in the list")
|
||||
last_id: str | None = Field(default=None, description="ID of the last batch in the list")
|
||||
has_more: bool = Field(default=False, description="Whether there are more batches available")
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Batches(Protocol):
|
||||
"""Protocol for batch processing API operations.
|
||||
|
||||
The Batches API enables efficient processing of multiple requests in a single operation,
|
||||
particularly useful for processing large datasets, batch evaluation workflows, and
|
||||
cost-effective inference at scale.
|
||||
|
||||
Note: This API is currently under active development and may undergo changes.
|
||||
"""
|
||||
|
||||
@webmethod(route="/openai/v1/batches", method="POST")
|
||||
async def create_batch(
|
||||
self,
|
||||
input_file_id: str,
|
||||
endpoint: str,
|
||||
completion_window: Literal["24h"],
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> BatchObject:
|
||||
"""Create a new batch for processing multiple API requests.
|
||||
|
||||
:param input_file_id: The ID of an uploaded file containing requests for the batch.
|
||||
:param endpoint: The endpoint to be used for all requests in the batch.
|
||||
:param completion_window: The time window within which the batch should be processed.
|
||||
:param metadata: Optional metadata for the batch.
|
||||
:returns: The created batch object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}", method="GET")
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Retrieve information about a specific batch.
|
||||
|
||||
:param batch_id: The ID of the batch to retrieve.
|
||||
:returns: The batch object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST")
|
||||
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Cancel a batch that is in progress.
|
||||
|
||||
:param batch_id: The ID of the batch to cancel.
|
||||
:returns: The updated batch object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches", method="GET")
|
||||
async def list_batches(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ListBatchesResponse:
|
||||
"""List all batches for the current user.
|
||||
|
||||
:param after: A cursor for pagination; returns batches after this batch ID.
|
||||
:param limit: Number of batches to return (default 20, max 100).
|
||||
:returns: A list of batch objects.
|
||||
"""
|
||||
...
|
|
@ -72,3 +72,10 @@ class ModelTypeError(TypeError):
|
|||
f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'"
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ConflictError(ValueError):
|
||||
"""raised when an operation cannot be performed due to a conflict with the current state"""
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
|
|
|
@ -86,6 +86,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
|||
:cvar inference: Text generation, chat completions, and embeddings
|
||||
:cvar safety: Content moderation and safety shields
|
||||
:cvar agents: Agent orchestration and execution
|
||||
:cvar batches: Batch processing for asynchronous API requests
|
||||
:cvar vector_io: Vector database operations and queries
|
||||
:cvar datasetio: Dataset input/output operations
|
||||
:cvar scoring: Model output evaluation and scoring
|
||||
|
@ -108,6 +109,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
|||
inference = "inference"
|
||||
safety = "safety"
|
||||
agents = "agents"
|
||||
batches = "batches"
|
||||
vector_io = "vector_io"
|
||||
datasetio = "datasetio"
|
||||
scoring = "scoring"
|
||||
|
|
|
@ -22,6 +22,7 @@ class OpenAIFilePurpose(StrEnum):
|
|||
"""
|
||||
|
||||
ASSISTANTS = "assistants"
|
||||
BATCH = "batch"
|
||||
# TODO: Add other purposes as needed
|
||||
|
||||
|
||||
|
|
|
@ -1,207 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
PYPI_VERSION=${PYPI_VERSION:-}
|
||||
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
# Usage function
|
||||
usage() {
|
||||
echo "Usage: $0 --env-name <conda_env_name> --build-file-path <build_file_path> --normal-deps <pip_dependencies> [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
|
||||
echo "Example: $0 --env-name my-conda-env --build-file-path ./my-stack-build.yaml --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
env_name=""
|
||||
build_file_path=""
|
||||
normal_deps=""
|
||||
external_provider_deps=""
|
||||
optional_deps=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
key="$1"
|
||||
case "$key" in
|
||||
--env-name)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --env-name requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
env_name="$2"
|
||||
shift 2
|
||||
;;
|
||||
--build-file-path)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --build-file-path requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
build_file_path="$2"
|
||||
shift 2
|
||||
;;
|
||||
--normal-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --normal-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
normal_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--external-provider-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --external-provider-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
external_provider_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--optional-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --optional-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
optional_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1" >&2
|
||||
usage
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Check required arguments
|
||||
if [[ -z "$env_name" || -z "$build_file_path" || -z "$normal_deps" ]]; then
|
||||
echo "Error: --env-name, --build-file-path, and --normal-deps are required." >&2
|
||||
usage
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
|
||||
fi
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
ensure_conda_env_python310() {
|
||||
# Use only global variables set by flag parser
|
||||
local python_version="3.12"
|
||||
|
||||
if ! is_command_available conda; then
|
||||
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if conda env list | grep -q "^${env_name} "; then
|
||||
printf "Conda environment '${env_name}' exists. Checking Python version...\n"
|
||||
current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)
|
||||
if [ "$current_version" = "$python_version" ]; then
|
||||
printf "Environment '${env_name}' already has Python ${python_version}. No action needed.\n"
|
||||
else
|
||||
printf "Updating environment '${env_name}' to Python ${python_version}...\n"
|
||||
conda install -n "${env_name}" python="${python_version}" -y
|
||||
fi
|
||||
else
|
||||
printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}...\n"
|
||||
conda create -n "${env_name}" python="${python_version}" -y
|
||||
fi
|
||||
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda deactivate && conda activate "${env_name}"
|
||||
"$CONDA_PREFIX"/bin/pip install uv
|
||||
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
uv pip install fastapi libcst
|
||||
uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||
llama-stack=="$TEST_PYPI_VERSION" \
|
||||
"$normal_deps"
|
||||
if [ -n "$optional_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$optional_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
uv pip install $part
|
||||
done
|
||||
fi
|
||||
if [ -n "$external_provider_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$external_provider_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
uv pip install "$part"
|
||||
done
|
||||
fi
|
||||
else
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_STACK_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2
|
||||
exit 1
|
||||
fi
|
||||
printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n"
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
|
||||
else
|
||||
PYPI_VERSION="${PYPI_VERSION:-}"
|
||||
if [ -n "$PYPI_VERSION" ]; then
|
||||
SPEC_VERSION="llama-stack==${PYPI_VERSION}"
|
||||
else
|
||||
SPEC_VERSION="llama-stack"
|
||||
fi
|
||||
uv pip install --no-cache-dir "$SPEC_VERSION"
|
||||
fi
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2
|
||||
exit 1
|
||||
fi
|
||||
printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n"
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
printf "Installing pip dependencies\n"
|
||||
uv pip install $normal_deps
|
||||
if [ -n "$optional_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$optional_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
uv pip install $part
|
||||
done
|
||||
fi
|
||||
if [ -n "$external_provider_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$external_provider_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "Getting provider spec for module: $part and installing dependencies"
|
||||
package_name=$(echo "$part" | sed 's/[<>=!].*//')
|
||||
python3 -c "
|
||||
import importlib
|
||||
import sys
|
||||
try:
|
||||
module = importlib.import_module(f'$package_name.provider')
|
||||
spec = module.get_provider_spec()
|
||||
if hasattr(spec, 'pip_packages') and spec.pip_packages:
|
||||
print('\\n'.join(spec.pip_packages))
|
||||
except Exception as e:
|
||||
print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr)
|
||||
" | uv pip install -r -
|
||||
done
|
||||
fi
|
||||
fi
|
||||
mv "$build_file_path" "$CONDA_PREFIX"/llamastack-build.yaml
|
||||
echo "Build spec configuration saved at $CONDA_PREFIX/llamastack-build.yaml"
|
||||
}
|
||||
|
||||
ensure_conda_env_python310 "$env_name" "$build_file_path" "$normal_deps" "$optional_deps" "$external_provider_deps"
|
|
@ -151,23 +151,37 @@ run() {
|
|||
fi
|
||||
else
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_STACK_DIR" ]; then
|
||||
# only warn if DIR does not start with "git+"
|
||||
if [ ! -d "$LLAMA_STACK_DIR" ] && [[ "$LLAMA_STACK_DIR" != git+* ]]; then
|
||||
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR"
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
|
||||
# editable only if LLAMA_STACK_DIR does not start with "git+"
|
||||
if [[ "$LLAMA_STACK_DIR" != git+* ]]; then
|
||||
EDITABLE="-e"
|
||||
else
|
||||
EDITABLE=""
|
||||
fi
|
||||
uv pip install --no-cache-dir $EDITABLE "$LLAMA_STACK_DIR"
|
||||
else
|
||||
uv pip install --no-cache-dir llama-stack
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
# only warn if DIR does not start with "git+"
|
||||
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ] && [[ "$LLAMA_STACK_CLIENT_DIR" != git+* ]]; then
|
||||
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR"
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
|
||||
# editable only if LLAMA_STACK_CLIENT_DIR does not start with "git+"
|
||||
if [[ "$LLAMA_STACK_CLIENT_DIR" != git+* ]]; then
|
||||
EDITABLE="-e"
|
||||
else
|
||||
EDITABLE=""
|
||||
fi
|
||||
uv pip install --no-cache-dir $EDITABLE "$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
printf "Installing pip dependencies\n"
|
||||
|
|
|
@ -8,6 +8,7 @@ import inspect
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batches import Batches
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
|
@ -75,6 +76,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
|
|||
Api.agents: Agents,
|
||||
Api.inference: Inference,
|
||||
Api.inspect: Inspect,
|
||||
Api.batches: Batches,
|
||||
Api.vector_io: VectorIO,
|
||||
Api.vector_dbs: VectorDBs,
|
||||
Api.models: Models,
|
||||
|
|
|
@ -6,9 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
Message,
|
||||
)
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.safety.safety import ModerationObject
|
||||
from llama_stack.apis.shields import Shield
|
||||
|
@ -68,6 +66,7 @@ class SafetyRouter(Safety):
|
|||
list_shields_response = await self.routing_table.list_shields()
|
||||
|
||||
matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
|
||||
|
||||
if not matches:
|
||||
raise ValueError(f"No shield associated with provider_resource id {model}")
|
||||
if len(matches) > 1:
|
||||
|
|
|
@ -32,6 +32,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
|||
from openai import BadRequestError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError
|
||||
|
@ -128,6 +129,10 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
|
|||
]
|
||||
},
|
||||
)
|
||||
elif isinstance(exc, ConflictError):
|
||||
return HTTPException(status_code=409, detail=str(exc))
|
||||
elif isinstance(exc, ResourceNotFoundError):
|
||||
return HTTPException(status_code=404, detail=str(exc))
|
||||
elif isinstance(exc, ValueError):
|
||||
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
|
||||
elif isinstance(exc, BadRequestError):
|
||||
|
|
|
@ -28,6 +28,7 @@ distribution_spec:
|
|||
- provider_type: inline::localfs
|
||||
safety:
|
||||
- provider_type: inline::llama-guard
|
||||
- provider_type: inline::code-scanner
|
||||
agents:
|
||||
- provider_type: inline::meta-reference
|
||||
telemetry:
|
||||
|
@ -48,6 +49,8 @@ distribution_spec:
|
|||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
- provider_type: inline::reference
|
||||
image_type: venv
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
|
|
|
@ -2,6 +2,7 @@ version: 2
|
|||
image_name: ci-tests
|
||||
apis:
|
||||
- agents
|
||||
- batches
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
|
@ -134,6 +135,8 @@ providers:
|
|||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
- provider_id: code-scanner
|
||||
provider_type: inline::code-scanner
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
@ -204,6 +207,13 @@ providers:
|
|||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
- provider_id: reference
|
||||
provider_type: inline::reference
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/batches.db
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/registry.db
|
||||
|
@ -215,6 +225,9 @@ shields:
|
|||
- shield_id: llama-guard
|
||||
provider_id: ${env.SAFETY_MODEL:+llama-guard}
|
||||
provider_shield_id: ${env.SAFETY_MODEL:=}
|
||||
- shield_id: code-scanner
|
||||
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
|
||||
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
|
|
|
@ -28,6 +28,7 @@ distribution_spec:
|
|||
- provider_type: inline::localfs
|
||||
safety:
|
||||
- provider_type: inline::llama-guard
|
||||
- provider_type: inline::code-scanner
|
||||
agents:
|
||||
- provider_type: inline::meta-reference
|
||||
telemetry:
|
||||
|
@ -48,6 +49,8 @@ distribution_spec:
|
|||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
- provider_type: inline::reference
|
||||
image_type: venv
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
|
|
|
@ -2,6 +2,7 @@ version: 2
|
|||
image_name: starter
|
||||
apis:
|
||||
- agents
|
||||
- batches
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
|
@ -134,6 +135,8 @@ providers:
|
|||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
- provider_id: code-scanner
|
||||
provider_type: inline::code-scanner
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
@ -204,6 +207,13 @@ providers:
|
|||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
- provider_id: reference
|
||||
provider_type: inline::reference
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/batches.db
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/registry.db
|
||||
|
@ -215,6 +225,9 @@ shields:
|
|||
- shield_id: llama-guard
|
||||
provider_id: ${env.SAFETY_MODEL:+llama-guard}
|
||||
provider_shield_id: ${env.SAFETY_MODEL:=}
|
||||
- shield_id: code-scanner
|
||||
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
|
||||
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
|
|
|
@ -15,19 +15,14 @@ from llama_stack.core.datatypes import (
|
|||
ToolGroupInput,
|
||||
)
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.distributions.template import (
|
||||
DistributionTemplate,
|
||||
RunConfigSettings,
|
||||
)
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
|
||||
from llama_stack.providers.datatypes import RemoteProviderSpec
|
||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.milvus.config import (
|
||||
MilvusVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
|
||||
SQLiteVectorIOConfig,
|
||||
)
|
||||
|
@ -119,7 +114,10 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
BuildProvider(provider_type="remote::pgvector"),
|
||||
],
|
||||
"files": [BuildProvider(provider_type="inline::localfs")],
|
||||
"safety": [BuildProvider(provider_type="inline::llama-guard")],
|
||||
"safety": [
|
||||
BuildProvider(provider_type="inline::llama-guard"),
|
||||
BuildProvider(provider_type="inline::code-scanner"),
|
||||
],
|
||||
"agents": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"post_training": [BuildProvider(provider_type="inline::huggingface")],
|
||||
|
@ -139,6 +137,9 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
BuildProvider(provider_type="remote::model-context-protocol"),
|
||||
],
|
||||
"batches": [
|
||||
BuildProvider(provider_type="inline::reference"),
|
||||
],
|
||||
}
|
||||
files_provider = Provider(
|
||||
provider_id="meta-reference-files",
|
||||
|
@ -167,6 +168,11 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
provider_id="${env.SAFETY_MODEL:+llama-guard}",
|
||||
provider_shield_id="${env.SAFETY_MODEL:=}",
|
||||
),
|
||||
ShieldInput(
|
||||
shield_id="code-scanner",
|
||||
provider_id="${env.CODE_SCANNER_MODEL:+code-scanner}",
|
||||
provider_shield_id="${env.CODE_SCANNER_MODEL:=}",
|
||||
),
|
||||
]
|
||||
|
||||
return DistributionTemplate(
|
||||
|
|
|
@ -7,13 +7,11 @@
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from logging.config import dictConfig
|
||||
|
||||
from rich.console import Console
|
||||
from rich.errors import MarkupError
|
||||
from rich.logging import RichHandler
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.core.datatypes import LoggingConfig
|
||||
|
||||
|
@ -66,7 +64,6 @@ def config_to_category_levels(category: str, level: str):
|
|||
category_levels["root"] = level_value
|
||||
elif category in CATEGORIES:
|
||||
category_levels[category] = level_value
|
||||
logging.info(f"Setting '{category}' category to level '{level}'.")
|
||||
else:
|
||||
logging.warning(f"Unknown logging category: {category}. No changes made.")
|
||||
return category_levels
|
||||
|
@ -256,7 +253,6 @@ def get_logger(
|
|||
|
||||
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
|
||||
if env_config:
|
||||
cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", color="yellow", file=sys.stderr)
|
||||
_category_levels.update(parse_environment_config(env_config))
|
||||
|
||||
log_file = os.environ.get("LLAMA_STACK_LOG_FILE")
|
||||
|
|
|
@ -48,8 +48,8 @@ from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
|||
|
||||
from .agent_instance import ChatAgent
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
from .openai_responses import OpenAIResponsesImpl
|
||||
from .persistence import AgentInfo
|
||||
from .responses.openai_responses import OpenAIResponsesImpl
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -0,0 +1,271 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIDeleteResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAISystemMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
||||
from .streaming import StreamingResponseOrchestrator
|
||||
from .tool_executor import ToolExecutor
|
||||
from .types import ChatCompletionContext
|
||||
from .utils import (
|
||||
convert_response_input_to_chat_messages,
|
||||
convert_response_text_to_chat_response_format,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="responses")
|
||||
|
||||
|
||||
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
||||
input_items: ListOpenAIResponseInputItem
|
||||
response: OpenAIResponseObject
|
||||
|
||||
|
||||
class OpenAIResponsesImpl:
|
||||
def __init__(
|
||||
self,
|
||||
inference_api: Inference,
|
||||
tool_groups_api: ToolGroups,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
responses_store: ResponsesStore,
|
||||
vector_io_api: VectorIO, # VectorIO
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.responses_store = responses_store
|
||||
self.vector_io_api = vector_io_api
|
||||
self.tool_executor = ToolExecutor(
|
||||
tool_groups_api=tool_groups_api,
|
||||
tool_runtime_api=tool_runtime_api,
|
||||
vector_io_api=vector_io_api,
|
||||
)
|
||||
|
||||
async def _prepend_previous_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
previous_response_id: str | None = None,
|
||||
):
|
||||
if previous_response_id:
|
||||
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
|
||||
|
||||
# previous response input items
|
||||
new_input_items = previous_response_with_input.input
|
||||
|
||||
# previous response output items
|
||||
new_input_items.extend(previous_response_with_input.output)
|
||||
|
||||
# new input items from the current request
|
||||
if isinstance(input, str):
|
||||
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
|
||||
else:
|
||||
new_input_items.extend(input)
|
||||
|
||||
input = new_input_items
|
||||
|
||||
return input
|
||||
|
||||
async def _prepend_instructions(self, messages, instructions):
|
||||
if instructions:
|
||||
messages.insert(0, OpenAISystemMessageParam(content=instructions))
|
||||
|
||||
async def get_openai_response(
|
||||
self,
|
||||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
response_with_input = await self.responses_store.get_response_object(response_id)
|
||||
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
|
||||
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 50,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
return await self.responses_store.list_responses(after, limit, model, order)
|
||||
|
||||
async def list_openai_response_input_items(
|
||||
self,
|
||||
response_id: str,
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
include: list[str] | None = None,
|
||||
limit: int | None = 20,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseInputItem:
|
||||
"""List input items for a given OpenAI response.
|
||||
|
||||
:param response_id: The ID of the response to retrieve input items for.
|
||||
:param after: An item ID to list items after, used for pagination.
|
||||
:param before: An item ID to list items before, used for pagination.
|
||||
:param include: Additional fields to include in the response.
|
||||
:param limit: A limit on the number of objects to be returned.
|
||||
:param order: The order to return the input items in.
|
||||
:returns: An ListOpenAIResponseInputItem.
|
||||
"""
|
||||
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
|
||||
|
||||
async def _store_response(
|
||||
self,
|
||||
response: OpenAIResponseObject,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
) -> None:
|
||||
new_input_id = f"msg_{uuid.uuid4()}"
|
||||
if isinstance(input, str):
|
||||
# synthesize a message from the input string
|
||||
input_content = OpenAIResponseInputMessageContentText(text=input)
|
||||
input_content_item = OpenAIResponseMessage(
|
||||
role="user",
|
||||
content=[input_content],
|
||||
id=new_input_id,
|
||||
)
|
||||
input_items_data = [input_content_item]
|
||||
else:
|
||||
# we already have a list of messages
|
||||
input_items_data = []
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseMessage):
|
||||
# These may or may not already have an id, so dump to dict, check for id, and add if missing
|
||||
input_item_dict = input_item.model_dump()
|
||||
if "id" not in input_item_dict:
|
||||
input_item_dict["id"] = new_input_id
|
||||
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
|
||||
else:
|
||||
input_items_data.append(input_item)
|
||||
|
||||
await self.responses_store.store_response_object(
|
||||
response_object=response,
|
||||
input=input_items_data,
|
||||
)
|
||||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
):
|
||||
stream = bool(stream)
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||
|
||||
stream_gen = self._create_streaming_response(
|
||||
input=input,
|
||||
model=model,
|
||||
instructions=instructions,
|
||||
previous_response_id=previous_response_id,
|
||||
store=store,
|
||||
temperature=temperature,
|
||||
text=text,
|
||||
tools=tools,
|
||||
max_infer_iters=max_infer_iters,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return stream_gen
|
||||
else:
|
||||
response = None
|
||||
async for stream_chunk in stream_gen:
|
||||
if stream_chunk.type == "response.completed":
|
||||
if response is not None:
|
||||
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
|
||||
response = stream_chunk.response
|
||||
# don't leave the generator half complete!
|
||||
|
||||
if response is None:
|
||||
raise ValueError("The response stream never completed")
|
||||
return response
|
||||
|
||||
async def _create_streaming_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
instructions: str | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
temperature: float | None = None,
|
||||
text: OpenAIResponseText | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Input preprocessing
|
||||
input = await self._prepend_previous_response(input, previous_response_id)
|
||||
messages = await convert_response_input_to_chat_messages(input)
|
||||
await self._prepend_instructions(messages, instructions)
|
||||
|
||||
# Structured outputs
|
||||
response_format = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
ctx = ChatCompletionContext(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_tools=tools,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
# Create orchestrator and delegate streaming logic
|
||||
response_id = f"resp-{uuid.uuid4()}"
|
||||
created_at = int(time.time())
|
||||
|
||||
orchestrator = StreamingResponseOrchestrator(
|
||||
inference_api=self.inference_api,
|
||||
ctx=ctx,
|
||||
response_id=response_id,
|
||||
created_at=created_at,
|
||||
text=text,
|
||||
max_infer_iters=max_infer_iters,
|
||||
tool_executor=self.tool_executor,
|
||||
)
|
||||
|
||||
# Stream the response
|
||||
final_response = None
|
||||
async for stream_chunk in orchestrator.create_response():
|
||||
if stream_chunk.type == "response.completed":
|
||||
final_response = stream_chunk.response
|
||||
yield stream_chunk
|
||||
|
||||
# Store the response if requested
|
||||
if store and final_response:
|
||||
await self._store_response(
|
||||
response=final_response,
|
||||
input=input,
|
||||
)
|
||||
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
return await self.responses_store.delete_response_object(response_id)
|
|
@ -0,0 +1,634 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
AllowedToolsFilter,
|
||||
MCPListToolsTool,
|
||||
OpenAIResponseContentPartOutputText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
OpenAIResponseObjectStreamResponseContentPartAdded,
|
||||
OpenAIResponseObjectStreamResponseContentPartDone,
|
||||
OpenAIResponseObjectStreamResponseCreated,
|
||||
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
|
||||
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
|
||||
OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta,
|
||||
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone,
|
||||
OpenAIResponseObjectStreamResponseMcpListToolsCompleted,
|
||||
OpenAIResponseObjectStreamResponseMcpListToolsInProgress,
|
||||
OpenAIResponseObjectStreamResponseOutputItemAdded,
|
||||
OpenAIResponseObjectStreamResponseOutputItemDone,
|
||||
OpenAIResponseObjectStreamResponseOutputTextDelta,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseText,
|
||||
WebSearchToolTypes,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChoice,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .types import ChatCompletionContext, ChatCompletionResult
|
||||
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
|
||||
|
||||
logger = get_logger(name=__name__, category="responses")
|
||||
|
||||
|
||||
class StreamingResponseOrchestrator:
|
||||
def __init__(
|
||||
self,
|
||||
inference_api: Inference,
|
||||
ctx: ChatCompletionContext,
|
||||
response_id: str,
|
||||
created_at: int,
|
||||
text: OpenAIResponseText,
|
||||
max_infer_iters: int,
|
||||
tool_executor, # Will be the tool execution logic from the main class
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
self.ctx = ctx
|
||||
self.response_id = response_id
|
||||
self.created_at = created_at
|
||||
self.text = text
|
||||
self.max_infer_iters = max_infer_iters
|
||||
self.tool_executor = tool_executor
|
||||
self.sequence_number = 0
|
||||
# Store MCP tool mapping that gets built during tool processing
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||
|
||||
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Initialize output messages
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
# Create initial response and emit response.created immediately
|
||||
initial_response = OpenAIResponseObject(
|
||||
created_at=self.created_at,
|
||||
id=self.response_id,
|
||||
model=self.ctx.model,
|
||||
object="response",
|
||||
status="in_progress",
|
||||
output=output_messages.copy(),
|
||||
text=self.text,
|
||||
)
|
||||
|
||||
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
|
||||
|
||||
# Process all tools (including MCP tools) and emit streaming events
|
||||
if self.ctx.response_tools:
|
||||
async for stream_event in self._process_tools(self.ctx.response_tools, output_messages):
|
||||
yield stream_event
|
||||
|
||||
n_iter = 0
|
||||
messages = self.ctx.messages.copy()
|
||||
|
||||
while True:
|
||||
completion_result = await self.inference_api.openai_chat_completion(
|
||||
model=self.ctx.model,
|
||||
messages=messages,
|
||||
tools=self.ctx.chat_tools,
|
||||
stream=True,
|
||||
temperature=self.ctx.temperature,
|
||||
response_format=self.ctx.response_format,
|
||||
)
|
||||
|
||||
# Process streaming chunks and build complete response
|
||||
completion_result_data = None
|
||||
async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages):
|
||||
if isinstance(stream_event_or_result, ChatCompletionResult):
|
||||
completion_result_data = stream_event_or_result
|
||||
else:
|
||||
yield stream_event_or_result
|
||||
if not completion_result_data:
|
||||
raise ValueError("Streaming chunk processor failed to return completion data")
|
||||
current_response = self._build_chat_completion(completion_result_data)
|
||||
|
||||
function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
|
||||
current_response, messages
|
||||
)
|
||||
|
||||
# Handle choices with no tool calls
|
||||
for choice in current_response.choices:
|
||||
if not (choice.message.tool_calls and self.ctx.response_tools):
|
||||
output_messages.append(await convert_chat_choice_to_response_message(choice))
|
||||
|
||||
# Execute tool calls and coordinate results
|
||||
async for stream_event in self._coordinate_tool_execution(
|
||||
function_tool_calls,
|
||||
non_function_tool_calls,
|
||||
completion_result_data,
|
||||
output_messages,
|
||||
next_turn_messages,
|
||||
):
|
||||
yield stream_event
|
||||
|
||||
if not function_tool_calls and not non_function_tool_calls:
|
||||
break
|
||||
|
||||
if function_tool_calls:
|
||||
logger.info("Exiting inference loop since there is a function (client-side) tool call")
|
||||
break
|
||||
|
||||
n_iter += 1
|
||||
if n_iter >= self.max_infer_iters:
|
||||
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}")
|
||||
break
|
||||
|
||||
messages = next_turn_messages
|
||||
|
||||
# Create final response
|
||||
final_response = OpenAIResponseObject(
|
||||
created_at=self.created_at,
|
||||
id=self.response_id,
|
||||
model=self.ctx.model,
|
||||
object="response",
|
||||
status="completed",
|
||||
text=self.text,
|
||||
output=output_messages,
|
||||
)
|
||||
|
||||
# Emit response.completed
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
|
||||
|
||||
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
|
||||
"""Separate tool calls into function and non-function categories."""
|
||||
function_tool_calls = []
|
||||
non_function_tool_calls = []
|
||||
next_turn_messages = messages.copy()
|
||||
|
||||
for choice in current_response.choices:
|
||||
next_turn_messages.append(choice.message)
|
||||
|
||||
if choice.message.tool_calls and self.ctx.response_tools:
|
||||
for tool_call in choice.message.tool_calls:
|
||||
if is_function_tool_call(tool_call, self.ctx.response_tools):
|
||||
function_tool_calls.append(tool_call)
|
||||
else:
|
||||
non_function_tool_calls.append(tool_call)
|
||||
|
||||
return function_tool_calls, non_function_tool_calls, next_turn_messages
|
||||
|
||||
async def _process_streaming_chunks(
|
||||
self, completion_result, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream | ChatCompletionResult]:
|
||||
"""Process streaming chunks and emit events, returning completion data."""
|
||||
# Initialize result tracking
|
||||
chat_response_id = ""
|
||||
chat_response_content = []
|
||||
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
|
||||
chunk_created = 0
|
||||
chunk_model = ""
|
||||
chunk_finish_reason = ""
|
||||
|
||||
# Create a placeholder message item for delta events
|
||||
message_item_id = f"msg_{uuid.uuid4()}"
|
||||
# Track tool call items for streaming events
|
||||
tool_call_item_ids: dict[int, str] = {}
|
||||
# Track content parts for streaming events
|
||||
content_part_emitted = False
|
||||
|
||||
async for chunk in completion_result:
|
||||
chat_response_id = chunk.id
|
||||
chunk_created = chunk.created
|
||||
chunk_model = chunk.model
|
||||
for chunk_choice in chunk.choices:
|
||||
# Emit incremental text content as delta events
|
||||
if chunk_choice.delta.content:
|
||||
# Emit content_part.added event for first text chunk
|
||||
if not content_part_emitted:
|
||||
content_part_emitted = True
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseContentPartAdded(
|
||||
response_id=self.response_id,
|
||||
item_id=message_item_id,
|
||||
part=OpenAIResponseContentPartOutputText(
|
||||
text="", # Will be filled incrementally via text deltas
|
||||
),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
|
||||
content_index=0,
|
||||
delta=chunk_choice.delta.content,
|
||||
item_id=message_item_id,
|
||||
output_index=0,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Collect content for final response
|
||||
chat_response_content.append(chunk_choice.delta.content or "")
|
||||
if chunk_choice.finish_reason:
|
||||
chunk_finish_reason = chunk_choice.finish_reason
|
||||
|
||||
# Aggregate tool call arguments across chunks
|
||||
if chunk_choice.delta.tool_calls:
|
||||
for tool_call in chunk_choice.delta.tool_calls:
|
||||
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
|
||||
# Create new tool call entry if this is the first chunk for this index
|
||||
is_new_tool_call = response_tool_call is None
|
||||
if is_new_tool_call:
|
||||
tool_call_dict: dict[str, Any] = tool_call.model_dump()
|
||||
tool_call_dict.pop("type", None)
|
||||
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
|
||||
chat_response_tool_calls[tool_call.index] = response_tool_call
|
||||
|
||||
# Create item ID for this tool call for streaming events
|
||||
tool_call_item_id = f"fc_{uuid.uuid4()}"
|
||||
tool_call_item_ids[tool_call.index] = tool_call_item_id
|
||||
|
||||
# Emit output_item.added event for the new function call
|
||||
self.sequence_number += 1
|
||||
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
|
||||
arguments="", # Will be filled incrementally via delta events
|
||||
call_id=tool_call.id or "",
|
||||
name=tool_call.function.name if tool_call.function else "",
|
||||
id=tool_call_item_id,
|
||||
status="in_progress",
|
||||
)
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=function_call_item,
|
||||
output_index=len(output_messages),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Stream tool call arguments as they arrive (differentiate between MCP and function calls)
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
tool_call_item_id = tool_call_item_ids[tool_call.index]
|
||||
self.sequence_number += 1
|
||||
|
||||
# Check if this is an MCP tool call
|
||||
is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server
|
||||
if is_mcp_tool:
|
||||
# Emit MCP-specific argument delta event
|
||||
yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(
|
||||
delta=tool_call.function.arguments,
|
||||
item_id=tool_call_item_id,
|
||||
output_index=len(output_messages),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
else:
|
||||
# Emit function call argument delta event
|
||||
yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(
|
||||
delta=tool_call.function.arguments,
|
||||
item_id=tool_call_item_id,
|
||||
output_index=len(output_messages),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Accumulate arguments for final response (only for subsequent chunks)
|
||||
if not is_new_tool_call:
|
||||
response_tool_call.function.arguments = (
|
||||
response_tool_call.function.arguments or ""
|
||||
) + tool_call.function.arguments
|
||||
|
||||
# Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
|
||||
for tool_call_index in sorted(chat_response_tool_calls.keys()):
|
||||
tool_call_item_id = tool_call_item_ids[tool_call_index]
|
||||
final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or ""
|
||||
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
|
||||
|
||||
# Check if this is an MCP tool call
|
||||
is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
|
||||
self.sequence_number += 1
|
||||
done_event_cls = (
|
||||
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone
|
||||
if is_mcp_tool
|
||||
else OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone
|
||||
)
|
||||
yield done_event_cls(
|
||||
arguments=final_arguments,
|
||||
item_id=tool_call_item_id,
|
||||
output_index=len(output_messages),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit content_part.done event if text content was streamed (before content gets cleared)
|
||||
if content_part_emitted:
|
||||
final_text = "".join(chat_response_content)
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseContentPartDone(
|
||||
response_id=self.response_id,
|
||||
item_id=message_item_id,
|
||||
part=OpenAIResponseContentPartOutputText(
|
||||
text=final_text,
|
||||
),
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Clear content when there are tool calls (OpenAI spec behavior)
|
||||
if chat_response_tool_calls:
|
||||
chat_response_content = []
|
||||
|
||||
yield ChatCompletionResult(
|
||||
response_id=chat_response_id,
|
||||
content=chat_response_content,
|
||||
tool_calls=chat_response_tool_calls,
|
||||
created=chunk_created,
|
||||
model=chunk_model,
|
||||
finish_reason=chunk_finish_reason,
|
||||
message_item_id=message_item_id,
|
||||
tool_call_item_ids=tool_call_item_ids,
|
||||
content_part_emitted=content_part_emitted,
|
||||
)
|
||||
|
||||
def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatCompletion:
|
||||
"""Build OpenAIChatCompletion from ChatCompletionResult."""
|
||||
# Convert collected chunks to complete response
|
||||
if result.tool_calls:
|
||||
tool_calls = [result.tool_calls[i] for i in sorted(result.tool_calls.keys())]
|
||||
else:
|
||||
tool_calls = None
|
||||
|
||||
assistant_message = OpenAIAssistantMessageParam(
|
||||
content=result.content_text,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
return OpenAIChatCompletion(
|
||||
id=result.response_id,
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
message=assistant_message,
|
||||
finish_reason=result.finish_reason,
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
created=result.created,
|
||||
model=result.model,
|
||||
)
|
||||
|
||||
async def _coordinate_tool_execution(
|
||||
self,
|
||||
function_tool_calls: list,
|
||||
non_function_tool_calls: list,
|
||||
completion_result_data: ChatCompletionResult,
|
||||
output_messages: list[OpenAIResponseOutput],
|
||||
next_turn_messages: list,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Coordinate execution of both function and non-function tool calls."""
|
||||
# Execute non-function tool calls
|
||||
for tool_call in non_function_tool_calls:
|
||||
# Find the item_id for this tool call
|
||||
matching_item_id = None
|
||||
for index, item_id in completion_result_data.tool_call_item_ids.items():
|
||||
response_tool_call = completion_result_data.tool_calls.get(index)
|
||||
if response_tool_call and response_tool_call.id == tool_call.id:
|
||||
matching_item_id = item_id
|
||||
break
|
||||
|
||||
# Use a fallback item_id if not found
|
||||
if not matching_item_id:
|
||||
matching_item_id = f"tc_{uuid.uuid4()}"
|
||||
|
||||
# Execute tool call with streaming
|
||||
tool_call_log = None
|
||||
tool_response_message = None
|
||||
async for result in self.tool_executor.execute_tool_call(
|
||||
tool_call,
|
||||
self.ctx,
|
||||
self.sequence_number,
|
||||
len(output_messages),
|
||||
matching_item_id,
|
||||
self.mcp_tool_to_server,
|
||||
):
|
||||
if result.stream_event:
|
||||
# Forward streaming events
|
||||
self.sequence_number = result.sequence_number
|
||||
yield result.stream_event
|
||||
|
||||
if result.final_output_message is not None:
|
||||
tool_call_log = result.final_output_message
|
||||
tool_response_message = result.final_input_message
|
||||
self.sequence_number = result.sequence_number
|
||||
|
||||
if tool_call_log:
|
||||
output_messages.append(tool_call_log)
|
||||
|
||||
# Emit output_item.done event for completed non-function tool call
|
||||
if matching_item_id:
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=tool_call_log,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
if tool_response_message:
|
||||
next_turn_messages.append(tool_response_message)
|
||||
|
||||
# Execute function tool calls (client-side)
|
||||
for tool_call in function_tool_calls:
|
||||
# Find the item_id for this tool call from our tracking dictionary
|
||||
matching_item_id = None
|
||||
for index, item_id in completion_result_data.tool_call_item_ids.items():
|
||||
response_tool_call = completion_result_data.tool_calls.get(index)
|
||||
if response_tool_call and response_tool_call.id == tool_call.id:
|
||||
matching_item_id = item_id
|
||||
break
|
||||
|
||||
# Use existing item_id or create new one if not found
|
||||
final_item_id = matching_item_id or f"fc_{uuid.uuid4()}"
|
||||
|
||||
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
|
||||
arguments=tool_call.function.arguments or "",
|
||||
call_id=tool_call.id,
|
||||
name=tool_call.function.name or "",
|
||||
id=final_item_id,
|
||||
status="completed",
|
||||
)
|
||||
output_messages.append(function_call_item)
|
||||
|
||||
# Emit output_item.done event for completed function call
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=function_call_item,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
async def _process_tools(
|
||||
self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Process all tools and emit appropriate streaming events."""
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from llama_stack.apis.tools import Tool
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
|
||||
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool.parameters
|
||||
},
|
||||
)
|
||||
return convert_tooldef_to_openai_tool(tool_def)
|
||||
|
||||
# Initialize chat_tools if not already set
|
||||
if self.ctx.chat_tools is None:
|
||||
self.ctx.chat_tools = []
|
||||
|
||||
for input_tool in tools:
|
||||
if input_tool.type == "function":
|
||||
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
|
||||
elif input_tool.type in WebSearchToolTypes:
|
||||
tool_name = "web_search"
|
||||
# Need to access tool_groups_api from tool_executor
|
||||
tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
|
||||
elif input_tool.type == "file_search":
|
||||
tool_name = "knowledge_search"
|
||||
tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
|
||||
elif input_tool.type == "mcp":
|
||||
async for stream_event in self._process_mcp_tool(input_tool, output_messages):
|
||||
yield stream_event
|
||||
else:
|
||||
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
|
||||
|
||||
async def _process_mcp_tool(
|
||||
self, mcp_tool: OpenAIResponseInputToolMCP, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Process an MCP tool configuration and emit appropriate streaming events."""
|
||||
from llama_stack.providers.utils.tools.mcp import list_mcp_tools
|
||||
|
||||
# Emit mcp_list_tools.in_progress
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
try:
|
||||
# Parse allowed/never allowed tools
|
||||
always_allowed = None
|
||||
never_allowed = None
|
||||
if mcp_tool.allowed_tools:
|
||||
if isinstance(mcp_tool.allowed_tools, list):
|
||||
always_allowed = mcp_tool.allowed_tools
|
||||
elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
|
||||
always_allowed = mcp_tool.allowed_tools.always
|
||||
never_allowed = mcp_tool.allowed_tools.never
|
||||
|
||||
# Call list_mcp_tools
|
||||
tool_defs = await list_mcp_tools(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
)
|
||||
|
||||
# Create the MCP list tools message
|
||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||
id=f"mcp_list_{uuid.uuid4()}",
|
||||
server_label=mcp_tool.server_label,
|
||||
tools=[],
|
||||
)
|
||||
|
||||
# Process tools and update context
|
||||
for t in tool_defs.data:
|
||||
if never_allowed and t.name in never_allowed:
|
||||
continue
|
||||
if not always_allowed or t.name in always_allowed:
|
||||
# Add to chat tools for inference
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=t.name,
|
||||
description=t.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in t.parameters
|
||||
},
|
||||
)
|
||||
openai_tool = convert_tooldef_to_openai_tool(tool_def)
|
||||
if self.ctx.chat_tools is None:
|
||||
self.ctx.chat_tools = []
|
||||
self.ctx.chat_tools.append(openai_tool)
|
||||
|
||||
# Add to MCP tool mapping
|
||||
if t.name in self.mcp_tool_to_server:
|
||||
raise ValueError(f"Duplicate tool name {t.name} found for server {mcp_tool.server_label}")
|
||||
self.mcp_tool_to_server[t.name] = mcp_tool
|
||||
|
||||
# Add to MCP list message
|
||||
mcp_list_message.tools.append(
|
||||
MCPListToolsTool(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
p.name: {
|
||||
"type": p.parameter_type,
|
||||
"description": p.description,
|
||||
}
|
||||
for p in t.parameters
|
||||
},
|
||||
"required": [p.name for p in t.parameters if p.required],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Add the MCP list message to output
|
||||
output_messages.append(mcp_list_message)
|
||||
|
||||
# Emit output_item.added for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit mcp_list_tools.completed
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
# Emit output_item.done for the MCP list tools message
|
||||
self.sequence_number += 1
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id=self.response_id,
|
||||
item=mcp_list_message,
|
||||
output_index=len(output_messages) - 1,
|
||||
sequence_number=self.sequence_number,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# TODO: Emit mcp_list_tools.failed event if needed
|
||||
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
|
||||
raise
|
|
@ -0,0 +1,379 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseObjectStreamResponseMcpCallCompleted,
|
||||
OpenAIResponseObjectStreamResponseMcpCallFailed,
|
||||
OpenAIResponseObjectStreamResponseMcpCallInProgress,
|
||||
OpenAIResponseObjectStreamResponseWebSearchCallCompleted,
|
||||
OpenAIResponseObjectStreamResponseWebSearchCallInProgress,
|
||||
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
|
||||
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||
OpenAIResponseOutputMessageFileSearchToolCallResults,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIImageURL,
|
||||
OpenAIToolMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .types import ChatCompletionContext, ToolExecutionResult
|
||||
|
||||
logger = get_logger(name=__name__, category="responses")
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
tool_groups_api: ToolGroups,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
vector_io_api: VectorIO,
|
||||
):
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.vector_io_api = vector_io_api
|
||||
|
||||
async def execute_tool_call(
|
||||
self,
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
ctx: ChatCompletionContext,
|
||||
sequence_number: int,
|
||||
output_index: int,
|
||||
item_id: str,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
tool_call_id = tool_call.id
|
||||
function = tool_call.function
|
||||
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
|
||||
|
||||
if not function or not tool_call_id or not function.name:
|
||||
yield ToolExecutionResult(sequence_number=sequence_number)
|
||||
return
|
||||
|
||||
# Emit progress events for tool execution start
|
||||
async for event_result in self._emit_progress_events(
|
||||
function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server
|
||||
):
|
||||
sequence_number = event_result.sequence_number
|
||||
yield event_result
|
||||
|
||||
# Execute the actual tool call
|
||||
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
|
||||
|
||||
# Emit completion events for tool execution
|
||||
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
|
||||
async for event_result in self._emit_completion_events(
|
||||
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
|
||||
):
|
||||
sequence_number = event_result.sequence_number
|
||||
yield event_result
|
||||
|
||||
# Build result messages from tool execution
|
||||
output_message, input_message = await self._build_result_messages(
|
||||
function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
|
||||
)
|
||||
|
||||
# Yield the final result
|
||||
yield ToolExecutionResult(
|
||||
sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
|
||||
)
|
||||
|
||||
async def _execute_knowledge_search_via_vector_store(
|
||||
self,
|
||||
query: str,
|
||||
response_file_search_tool: OpenAIResponseInputToolFileSearch,
|
||||
) -> ToolInvocationResult:
|
||||
"""Execute knowledge search using vector_stores.search API with filters support."""
|
||||
search_results = []
|
||||
|
||||
# Create search tasks for all vector stores
|
||||
async def search_single_store(vector_store_id):
|
||||
try:
|
||||
search_response = await self.vector_io_api.openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
filters=response_file_search_tool.filters,
|
||||
max_num_results=response_file_search_tool.max_num_results,
|
||||
ranking_options=response_file_search_tool.ranking_options,
|
||||
rewrite_query=False,
|
||||
)
|
||||
return search_response.data
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
|
||||
return []
|
||||
|
||||
# Run all searches in parallel using gather
|
||||
search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
|
||||
all_results = await asyncio.gather(*search_tasks)
|
||||
|
||||
# Flatten results
|
||||
for results in all_results:
|
||||
search_results.extend(results)
|
||||
|
||||
# Convert search results to tool result format matching memory.py
|
||||
# Format the results as interleaved content similar to memory.py
|
||||
content_items = []
|
||||
content_items.append(
|
||||
TextContentItem(
|
||||
text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
|
||||
)
|
||||
)
|
||||
|
||||
for i, result_item in enumerate(search_results):
|
||||
chunk_text = result_item.content[0].text if result_item.content else ""
|
||||
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
|
||||
if result_item.attributes:
|
||||
metadata_text += f", attributes: {result_item.attributes}"
|
||||
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
|
||||
content_items.append(TextContentItem(text=text_content))
|
||||
|
||||
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
|
||||
content_items.append(
|
||||
TextContentItem(
|
||||
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
|
||||
)
|
||||
)
|
||||
|
||||
return ToolInvocationResult(
|
||||
content=content_items,
|
||||
metadata={
|
||||
"document_ids": [r.file_id for r in search_results],
|
||||
"chunks": [r.content[0].text if r.content else "" for r in search_results],
|
||||
"scores": [r.score for r in search_results],
|
||||
},
|
||||
)
|
||||
|
||||
async def _emit_progress_events(
|
||||
self,
|
||||
function_name: str,
|
||||
ctx: ChatCompletionContext,
|
||||
sequence_number: int,
|
||||
output_index: int,
|
||||
item_id: str,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
"""Emit progress events for tool execution start."""
|
||||
# Emit in_progress event based on tool type (only for tools with specific streaming events)
|
||||
progress_event = None
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function_name == "web_search":
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
# Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec
|
||||
|
||||
if progress_event:
|
||||
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
|
||||
|
||||
# For web search, emit searching event
|
||||
if function_name == "web_search":
|
||||
sequence_number += 1
|
||||
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||
|
||||
async def _execute_tool(
|
||||
self,
|
||||
function_name: str,
|
||||
tool_kwargs: dict,
|
||||
ctx: ChatCompletionContext,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> tuple[Exception | None, any]:
|
||||
"""Execute the tool and return error exception and result."""
|
||||
error_exc = None
|
||||
result = None
|
||||
|
||||
try:
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
|
||||
|
||||
mcp_tool = mcp_tool_to_server[function_name]
|
||||
result = await invoke_mcp_tool(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
elif function_name == "knowledge_search":
|
||||
response_file_search_tool = next(
|
||||
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
|
||||
None,
|
||||
)
|
||||
if response_file_search_tool:
|
||||
# Use vector_stores.search API instead of knowledge_search tool
|
||||
# to support filters and ranking_options
|
||||
query = tool_kwargs.get("query", "")
|
||||
result = await self._execute_knowledge_search_via_vector_store(
|
||||
query=query,
|
||||
response_file_search_tool=response_file_search_tool,
|
||||
)
|
||||
else:
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
error_exc = e
|
||||
|
||||
return error_exc, result
|
||||
|
||||
async def _emit_completion_events(
|
||||
self,
|
||||
function_name: str,
|
||||
ctx: ChatCompletionContext,
|
||||
sequence_number: int,
|
||||
output_index: int,
|
||||
item_id: str,
|
||||
has_error: bool,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
"""Emit completion or failure events for tool execution."""
|
||||
completion_event = None
|
||||
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
if has_error:
|
||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
else:
|
||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function_name == "web_search":
|
||||
sequence_number += 1
|
||||
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
# Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec
|
||||
|
||||
if completion_event:
|
||||
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
|
||||
|
||||
async def _build_result_messages(
|
||||
self,
|
||||
function,
|
||||
tool_call_id: str,
|
||||
tool_kwargs: dict,
|
||||
ctx: ChatCompletionContext,
|
||||
error_exc: Exception | None,
|
||||
result: any,
|
||||
has_error: bool,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> tuple[any, any]:
|
||||
"""Build output and input messages from tool execution results."""
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
# Build output message
|
||||
if mcp_tool_to_server and function.name in mcp_tool_to_server:
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
)
|
||||
|
||||
message = OpenAIResponseOutputMessageMCPCall(
|
||||
id=tool_call_id,
|
||||
arguments=function.arguments,
|
||||
name=function.name,
|
||||
server_label=mcp_tool_to_server[function.name].server_label,
|
||||
)
|
||||
if error_exc:
|
||||
message.error = str(error_exc)
|
||||
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
|
||||
message.error = f"Error (code {result.error_code}): {result.error_message}"
|
||||
elif result and result.content:
|
||||
message.output = interleaved_content_as_str(result.content)
|
||||
else:
|
||||
if function.name == "web_search":
|
||||
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=tool_call_id,
|
||||
status="completed",
|
||||
)
|
||||
if has_error:
|
||||
message.status = "failed"
|
||||
elif function.name == "knowledge_search":
|
||||
message = OpenAIResponseOutputMessageFileSearchToolCall(
|
||||
id=tool_call_id,
|
||||
queries=[tool_kwargs.get("query", "")],
|
||||
status="completed",
|
||||
)
|
||||
if result and "document_ids" in result.metadata:
|
||||
message.results = []
|
||||
for i, doc_id in enumerate(result.metadata["document_ids"]):
|
||||
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
|
||||
score = result.metadata["scores"][i] if "scores" in result.metadata else None
|
||||
message.results.append(
|
||||
OpenAIResponseOutputMessageFileSearchToolCallResults(
|
||||
file_id=doc_id,
|
||||
filename=doc_id,
|
||||
text=text,
|
||||
score=score,
|
||||
attributes={},
|
||||
)
|
||||
)
|
||||
if has_error:
|
||||
message.status = "failed"
|
||||
else:
|
||||
raise ValueError(f"Unknown tool {function.name} called")
|
||||
|
||||
# Build input message
|
||||
input_message = None
|
||||
if result and result.content:
|
||||
if isinstance(result.content, str):
|
||||
content = result.content
|
||||
elif isinstance(result.content, list):
|
||||
content = []
|
||||
for item in result.content:
|
||||
if isinstance(item, TextContentItem):
|
||||
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
|
||||
elif isinstance(item, ImageContentItem):
|
||||
if item.image.data:
|
||||
url = f"data:image;base64,{item.image.data}"
|
||||
else:
|
||||
url = item.image.url
|
||||
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(item)}")
|
||||
content.append(part)
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(result.content)}")
|
||||
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
|
||||
else:
|
||||
text = str(error_exc) if error_exc else "Tool execution failed"
|
||||
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
|
||||
|
||||
return message, input_message
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseOutput,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam
|
||||
|
||||
|
||||
class ToolExecutionResult(BaseModel):
|
||||
"""Result of streaming tool execution."""
|
||||
|
||||
stream_event: OpenAIResponseObjectStream | None = None
|
||||
sequence_number: int
|
||||
final_output_message: OpenAIResponseOutput | None = None
|
||||
final_input_message: OpenAIMessageParam | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionResult:
|
||||
"""Result of processing streaming chat completion chunks."""
|
||||
|
||||
response_id: str
|
||||
content: list[str]
|
||||
tool_calls: dict[int, OpenAIChatCompletionToolCall]
|
||||
created: int
|
||||
model: str
|
||||
finish_reason: str
|
||||
message_item_id: str # For streaming events
|
||||
tool_call_item_ids: dict[int, str] # For streaming events
|
||||
content_part_emitted: bool # Tracking state
|
||||
|
||||
@property
|
||||
def content_text(self) -> str:
|
||||
"""Get joined content as string."""
|
||||
return "".join(self.content)
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
"""Check if there are any tool calls."""
|
||||
return bool(self.tool_calls)
|
||||
|
||||
|
||||
class ChatCompletionContext(BaseModel):
|
||||
model: str
|
||||
messages: list[OpenAIMessageParam]
|
||||
response_tools: list[OpenAIResponseInputTool] | None = None
|
||||
chat_tools: list[ChatCompletionToolParam] | None = None
|
||||
temperature: float | None
|
||||
response_format: OpenAIResponseFormatParam
|
|
@ -0,0 +1,169 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputMessageContent,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContent,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseText,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIImageURL,
|
||||
OpenAIJSONSchema,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAIResponseFormatText,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
|
||||
|
||||
async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
|
||||
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
|
||||
output_content = ""
|
||||
if isinstance(choice.message.content, str):
|
||||
output_content = choice.message.content
|
||||
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
|
||||
output_content = choice.message.content.text
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
|
||||
)
|
||||
|
||||
return OpenAIResponseMessage(
|
||||
id=f"msg_{uuid.uuid4()}",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
|
||||
async def convert_response_content_to_chat_content(
|
||||
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
|
||||
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
||||
"""
|
||||
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
|
||||
|
||||
The content schemas of each API look similar, but are not exactly the same.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
converted_parts = []
|
||||
for content_part in content:
|
||||
if isinstance(content_part, OpenAIResponseInputMessageContentText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
|
||||
if content_part.image_url:
|
||||
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
|
||||
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
|
||||
elif isinstance(content_part, str):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
|
||||
)
|
||||
return converted_parts
|
||||
|
||||
|
||||
async def convert_response_input_to_chat_messages(
|
||||
input: str | list[OpenAIResponseInput],
|
||||
) -> list[OpenAIMessageParam]:
|
||||
"""
|
||||
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
|
||||
"""
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
if isinstance(input, list):
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
|
||||
messages.append(
|
||||
OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
tool_call_id=input_item.call_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id=input_item.call_id,
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name=input_item.name,
|
||||
arguments=input_item.arguments,
|
||||
),
|
||||
)
|
||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||
else:
|
||||
content = await convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await get_message_type_by_role(input_item.role)
|
||||
if message_type is None:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
|
||||
)
|
||||
messages.append(message_type(content=content))
|
||||
else:
|
||||
messages.append(OpenAIUserMessageParam(content=input))
|
||||
return messages
|
||||
|
||||
|
||||
async def convert_response_text_to_chat_response_format(
|
||||
text: OpenAIResponseText,
|
||||
) -> OpenAIResponseFormatParam:
|
||||
"""
|
||||
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
|
||||
"""
|
||||
if not text.format or text.format["type"] == "text":
|
||||
return OpenAIResponseFormatText(type="text")
|
||||
if text.format["type"] == "json_object":
|
||||
return OpenAIResponseFormatJSONObject()
|
||||
if text.format["type"] == "json_schema":
|
||||
return OpenAIResponseFormatJSONSchema(
|
||||
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
|
||||
)
|
||||
raise ValueError(f"Unsupported text format: {text.format}")
|
||||
|
||||
|
||||
async def get_message_type_by_role(role: str):
|
||||
role_to_type = {
|
||||
"user": OpenAIUserMessageParam,
|
||||
"system": OpenAISystemMessageParam,
|
||||
"assistant": OpenAIAssistantMessageParam,
|
||||
"developer": OpenAIDeveloperMessageParam,
|
||||
}
|
||||
return role_to_type.get(role)
|
||||
|
||||
|
||||
def is_function_tool_call(
|
||||
tool_call: OpenAIChatCompletionToolCall,
|
||||
tools: list[OpenAIResponseInputTool],
|
||||
) -> bool:
|
||||
if not tool_call.function:
|
||||
return False
|
||||
for t in tools:
|
||||
if t.type == "function" and t.name == tool_call.function.name:
|
||||
return True
|
||||
return False
|
5
llama_stack/providers/inline/batches/__init__.py
Normal file
5
llama_stack/providers/inline/batches/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
36
llama_stack/providers/inline/batches/reference/__init__.py
Normal file
36
llama_stack/providers/inline/batches/reference/__init__.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.core.datatypes import AccessRule, Api
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .batches import ReferenceBatchesImpl
|
||||
from .config import ReferenceBatchesImplConfig
|
||||
|
||||
__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
|
||||
kvstore = await kvstore_impl(config.kvstore)
|
||||
inference_api: Inference | None = deps.get(Api.inference)
|
||||
files_api: Files | None = deps.get(Api.files)
|
||||
models_api: Models | None = deps.get(Api.models)
|
||||
|
||||
if inference_api is None:
|
||||
raise ValueError("Inference API is required but not provided in dependencies")
|
||||
if files_api is None:
|
||||
raise ValueError("Files API is required but not provided in dependencies")
|
||||
if models_api is None:
|
||||
raise ValueError("Models API is required but not provided in dependencies")
|
||||
|
||||
impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore)
|
||||
await impl.initialize()
|
||||
return impl
|
580
llama_stack/providers/inline/batches/reference/batches.py
Normal file
580
llama_stack/providers/inline/batches/reference/batches.py
Normal file
|
@ -0,0 +1,580 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Any, Literal
|
||||
|
||||
from openai.types.batch import BatchError, Errors
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
from llama_stack.apis.files import Files, OpenAIFilePurpose
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
from .config import ReferenceBatchesImplConfig
|
||||
|
||||
BATCH_PREFIX = "batch:"
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AsyncBytesIO:
|
||||
"""
|
||||
Async-compatible BytesIO wrapper to allow async file-like operations.
|
||||
|
||||
We use this when uploading files to the Files API, as it expects an
|
||||
async file-like object.
|
||||
"""
|
||||
|
||||
def __init__(self, data: bytes):
|
||||
self._buffer = BytesIO(data)
|
||||
|
||||
async def read(self, n=-1):
|
||||
return self._buffer.read(n)
|
||||
|
||||
async def seek(self, pos, whence=0):
|
||||
return self._buffer.seek(pos, whence)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._buffer.close()
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._buffer, name)
|
||||
|
||||
|
||||
class BatchRequest(BaseModel):
|
||||
line_num: int
|
||||
custom_id: str
|
||||
method: str
|
||||
url: str
|
||||
body: dict[str, Any]
|
||||
|
||||
|
||||
def convert_to_openai_message_param(msg: dict[str, Any]) -> OpenAIMessageParam:
|
||||
"""Convert a message dictionary to OpenAIMessageParam based on role."""
|
||||
role = msg.get("role")
|
||||
|
||||
if role == "user":
|
||||
return OpenAIUserMessageParam(**msg)
|
||||
elif role == "system":
|
||||
return OpenAISystemMessageParam(**msg)
|
||||
elif role == "assistant":
|
||||
return OpenAIAssistantMessageParam(**msg)
|
||||
elif role == "tool":
|
||||
return OpenAIToolMessageParam(**msg)
|
||||
elif role == "developer":
|
||||
return OpenAIDeveloperMessageParam(**msg)
|
||||
else:
|
||||
raise ValueError(f"Unknown message role: {role}")
|
||||
|
||||
|
||||
class ReferenceBatchesImpl(Batches):
|
||||
"""Reference implementation of the Batches API.
|
||||
|
||||
This implementation processes batch files by making individual requests
|
||||
to the inference API and generates output files with results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ReferenceBatchesImplConfig,
|
||||
inference_api: Inference,
|
||||
files_api: Files,
|
||||
models_api: Models,
|
||||
kvstore: KVStore,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.kvstore = kvstore
|
||||
self.inference_api = inference_api
|
||||
self.files_api = files_api
|
||||
self.models_api = models_api
|
||||
self._processing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches)
|
||||
self._update_batch_lock = asyncio.Lock()
|
||||
|
||||
# this is to allow tests to disable background processing
|
||||
self.process_batches = True
|
||||
|
||||
async def initialize(self) -> None:
|
||||
# TODO: start background processing of existing tasks
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown the batches provider."""
|
||||
if self._processing_tasks:
|
||||
# don't cancel tasks - just let them stop naturally on shutdown
|
||||
# cancelling would mark batches as "cancelled" in the database
|
||||
logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks")
|
||||
|
||||
# TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
|
||||
async def create_batch(
|
||||
self,
|
||||
input_file_id: str,
|
||||
endpoint: str,
|
||||
completion_window: Literal["24h"],
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> BatchObject:
|
||||
"""
|
||||
Create a new batch for processing multiple API requests.
|
||||
|
||||
Error handling by levels -
|
||||
0. Input param handling, results in 40x errors before processing, e.g.
|
||||
- Wrong completion_window
|
||||
- Invalid metadata types
|
||||
- Unknown endpoint
|
||||
-> no batch created
|
||||
1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
|
||||
- input_file_id missing
|
||||
- invalid json in file
|
||||
- missing custom_id, method, url, body
|
||||
- invalid model
|
||||
- streaming
|
||||
-> batch created, validation sends to failed status
|
||||
2. Processing errors, result in error_file_id entries, e.g.
|
||||
- Any error returned from inference endpoint
|
||||
-> batch created, goes to completed status
|
||||
"""
|
||||
|
||||
# TODO: set expiration time for garbage collection
|
||||
|
||||
if endpoint not in ["/v1/chat/completions"]:
|
||||
raise ValueError(
|
||||
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
|
||||
)
|
||||
|
||||
if completion_window != "24h":
|
||||
raise ValueError(
|
||||
f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
|
||||
)
|
||||
|
||||
batch_id = f"batch_{uuid.uuid4().hex[:16]}"
|
||||
current_time = int(time.time())
|
||||
|
||||
batch = BatchObject(
|
||||
id=batch_id,
|
||||
object="batch",
|
||||
endpoint=endpoint,
|
||||
input_file_id=input_file_id,
|
||||
completion_window=completion_window,
|
||||
status="validating",
|
||||
created_at=current_time,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
|
||||
|
||||
if self.process_batches:
|
||||
task = asyncio.create_task(self._process_batch(batch_id))
|
||||
self._processing_tasks[batch_id] = task
|
||||
|
||||
return batch
|
||||
|
||||
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Cancel a batch that is in progress."""
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
|
||||
if batch.status in ["cancelled", "cancelling"]:
|
||||
return batch
|
||||
|
||||
if batch.status in ["completed", "failed", "expired"]:
|
||||
raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
|
||||
|
||||
await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
|
||||
|
||||
if batch_id in self._processing_tasks:
|
||||
self._processing_tasks[batch_id].cancel()
|
||||
# note: task removal and status="cancelled" handled in finally block of _process_batch
|
||||
|
||||
return await self.retrieve_batch(batch_id)
|
||||
|
||||
async def list_batches(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ListBatchesResponse:
|
||||
"""
|
||||
List all batches, eventually only for the current user.
|
||||
|
||||
With no notion of user, we return all batches.
|
||||
"""
|
||||
batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff")
|
||||
|
||||
batches = []
|
||||
for batch_data in batch_values:
|
||||
if batch_data:
|
||||
batches.append(BatchObject.model_validate_json(batch_data))
|
||||
|
||||
batches.sort(key=lambda b: b.created_at, reverse=True)
|
||||
|
||||
start_idx = 0
|
||||
if after:
|
||||
for i, batch in enumerate(batches):
|
||||
if batch.id == after:
|
||||
start_idx = i + 1
|
||||
break
|
||||
|
||||
page_batches = batches[start_idx : start_idx + limit]
|
||||
has_more = (start_idx + limit) < len(batches)
|
||||
|
||||
first_id = page_batches[0].id if page_batches else None
|
||||
last_id = page_batches[-1].id if page_batches else None
|
||||
|
||||
return ListBatchesResponse(
|
||||
data=page_batches,
|
||||
first_id=first_id,
|
||||
last_id=last_id,
|
||||
has_more=has_more,
|
||||
)
|
||||
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Retrieve information about a specific batch."""
|
||||
batch_data = await self.kvstore.get(f"batch:{batch_id}")
|
||||
if not batch_data:
|
||||
raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")
|
||||
|
||||
return BatchObject.model_validate_json(batch_data)
|
||||
|
||||
async def _update_batch(self, batch_id: str, **updates) -> None:
|
||||
"""Update batch fields in kvstore."""
|
||||
async with self._update_batch_lock:
|
||||
try:
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
|
||||
# batch processing is async. once cancelling, only allow "cancelled" status updates
|
||||
if batch.status == "cancelling" and updates.get("status") != "cancelled":
|
||||
logger.info(
|
||||
f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}"
|
||||
)
|
||||
return
|
||||
|
||||
if "errors" in updates:
|
||||
updates["errors"] = updates["errors"].model_dump()
|
||||
|
||||
batch_dict = batch.model_dump()
|
||||
batch_dict.update(updates)
|
||||
|
||||
await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update batch {batch_id}: {e}")
|
||||
|
||||
async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]:
|
||||
"""
|
||||
Read & validate input, return errors and valid input.
|
||||
|
||||
Validation of
|
||||
- input_file_id existance
|
||||
- valid json
|
||||
- custom_id, method, url, body presence and valid
|
||||
- no streaming
|
||||
"""
|
||||
requests: list[BatchRequest] = []
|
||||
errors: list[BatchError] = []
|
||||
try:
|
||||
await self.files_api.openai_retrieve_file(batch.input_file_id)
|
||||
except Exception:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=None,
|
||||
message=f"Cannot find file {batch.input_file_id}.",
|
||||
param="input_file_id",
|
||||
)
|
||||
)
|
||||
return errors, requests
|
||||
|
||||
# TODO(SECURITY): do something about large files
|
||||
file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
|
||||
file_content = file_content_response.body.decode("utf-8")
|
||||
for line_num, line in enumerate(file_content.strip().split("\n"), 1):
|
||||
if line.strip(): # skip empty lines
|
||||
try:
|
||||
request = json.loads(line)
|
||||
|
||||
if not isinstance(request, dict):
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=line_num,
|
||||
message="Each line must be a JSON dictionary object",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
valid = True
|
||||
|
||||
for param, expected_type, type_string in [
|
||||
("custom_id", str, "string"),
|
||||
("method", str, "string"),
|
||||
("url", str, "string"),
|
||||
("body", dict, "JSON dictionary object"),
|
||||
]:
|
||||
if param not in request:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="missing_required_parameter",
|
||||
line=line_num,
|
||||
message=f"Missing required parameter: {param}",
|
||||
param=param,
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
elif not isinstance(request[param], expected_type):
|
||||
param_name = "URL" if param == "url" else param.capitalize()
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=line_num,
|
||||
message=f"{param_name} must be a {type_string}",
|
||||
param=param,
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_url",
|
||||
line=line_num,
|
||||
message="URL provided for this request does not match the batch endpoint",
|
||||
param="url",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
if (body := request.get("body")) and isinstance(body, dict):
|
||||
if body.get("stream", False):
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="streaming_unsupported",
|
||||
line=line_num,
|
||||
message="Streaming is not supported in batch processing",
|
||||
param="body.stream",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
for param, expected_type, type_string in [
|
||||
("model", str, "a string"),
|
||||
# messages is specific to /v1/chat/completions
|
||||
# we could skip validating messages here and let inference fail. however,
|
||||
# that would be a very expensive way to find out messages is wrong.
|
||||
("messages", list, "an array"), # TODO: allow messages to be a string?
|
||||
]:
|
||||
if param not in body:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=line_num,
|
||||
message=f"{param.capitalize()} parameter is required",
|
||||
param=f"body.{param}",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
elif not isinstance(body[param], expected_type):
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=line_num,
|
||||
message=f"{param.capitalize()} must be {type_string}",
|
||||
param=f"body.{param}",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
if "model" in body and isinstance(body["model"], str):
|
||||
try:
|
||||
await self.models_api.get_model(body["model"])
|
||||
except Exception:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="model_not_found",
|
||||
line=line_num,
|
||||
message=f"Model '{body['model']}' does not exist or is not supported",
|
||||
param="body.model",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
if valid:
|
||||
assert isinstance(url, str), "URL must be a string" # for mypy
|
||||
assert isinstance(body, dict), "Body must be a dictionary" # for mypy
|
||||
requests.append(
|
||||
BatchRequest(
|
||||
line_num=line_num,
|
||||
url=url,
|
||||
method=request["method"],
|
||||
custom_id=request["custom_id"],
|
||||
body=body,
|
||||
),
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_json_line",
|
||||
line=line_num,
|
||||
message="This line is not parseable as valid JSON.",
|
||||
)
|
||||
)
|
||||
|
||||
return errors, requests
|
||||
|
||||
async def _process_batch(self, batch_id: str) -> None:
|
||||
"""Background task to process a batch of requests."""
|
||||
try:
|
||||
logger.info(f"Starting batch processing for {batch_id}")
|
||||
async with self._batch_semaphore: # semaphore to limit concurrency
|
||||
logger.info(f"Acquired semaphore for batch {batch_id}")
|
||||
await self._process_batch_impl(batch_id)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Batch processing cancelled for {batch_id}")
|
||||
await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time()))
|
||||
except Exception as e:
|
||||
logger.error(f"Batch processing failed for {batch_id}: {e}")
|
||||
await self._update_batch(
|
||||
batch_id,
|
||||
status="failed",
|
||||
failed_at=int(time.time()),
|
||||
errors=Errors(data=[BatchError(code="internal_error", message=str(e))]),
|
||||
)
|
||||
finally:
|
||||
self._processing_tasks.pop(batch_id, None)
|
||||
|
||||
async def _process_batch_impl(self, batch_id: str) -> None:
|
||||
"""Implementation of batch processing logic."""
|
||||
errors: list[BatchError] = []
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
|
||||
errors, requests = await self._validate_input(batch)
|
||||
if errors:
|
||||
await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors))
|
||||
logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors")
|
||||
return
|
||||
|
||||
logger.info(f"Processing {len(requests)} requests for batch {batch_id}")
|
||||
|
||||
total_requests = len(requests)
|
||||
await self._update_batch(
|
||||
batch_id,
|
||||
status="in_progress",
|
||||
request_counts={"total": total_requests, "completed": 0, "failed": 0},
|
||||
)
|
||||
|
||||
error_results = []
|
||||
success_results = []
|
||||
completed_count = 0
|
||||
failed_count = 0
|
||||
|
||||
for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch):
|
||||
# we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk]
|
||||
|
||||
chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True)
|
||||
|
||||
for result in chunk_results:
|
||||
if isinstance(result, dict) and result.get("error") is not None: # error response from inference
|
||||
failed_count += 1
|
||||
error_results.append(result)
|
||||
elif isinstance(result, dict) and result.get("response") is not None: # successful inference
|
||||
completed_count += 1
|
||||
success_results.append(result)
|
||||
else: # unexpected result
|
||||
failed_count += 1
|
||||
errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}"))
|
||||
|
||||
await self._update_batch(
|
||||
batch_id,
|
||||
request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count},
|
||||
)
|
||||
|
||||
if errors:
|
||||
await self._update_batch(
|
||||
batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
output_file_id = await self._create_output_file(batch_id, success_results, "success")
|
||||
await self._update_batch(batch_id, output_file_id=output_file_id)
|
||||
|
||||
error_file_id = await self._create_output_file(batch_id, error_results, "error")
|
||||
await self._update_batch(batch_id, error_file_id=error_file_id)
|
||||
|
||||
await self._update_batch(batch_id, status="completed", completed_at=int(time.time()))
|
||||
|
||||
logger.info(
|
||||
f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed"
|
||||
)
|
||||
except Exception as e:
|
||||
# note: errors is empty at this point, so we don't lose anything by ignoring it
|
||||
await self._update_batch(
|
||||
batch_id,
|
||||
status="failed",
|
||||
failed_at=int(time.time()),
|
||||
errors=Errors(data=[BatchError(code="output_failed", message=str(e))]),
|
||||
)
|
||||
|
||||
async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict:
|
||||
"""Process a single request from the batch."""
|
||||
request_id = f"batch_req_{batch_id}_{request.line_num}"
|
||||
|
||||
try:
|
||||
# TODO(SECURITY): review body for security issues
|
||||
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
|
||||
chat_response = await self.inference_api.openai_chat_completion(**request.body)
|
||||
|
||||
# this is for mypy, we don't allow streaming so we'll get the right type
|
||||
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
|
||||
return {
|
||||
"id": request_id,
|
||||
"custom_id": request.custom_id,
|
||||
"response": {
|
||||
"status_code": 200,
|
||||
"request_id": request_id, # TODO: should this be different?
|
||||
"body": chat_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
|
||||
return {
|
||||
"id": request_id,
|
||||
"custom_id": request.custom_id,
|
||||
"error": {"type": "request_failed", "message": str(e)},
|
||||
}
|
||||
|
||||
async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str:
|
||||
"""
|
||||
Create an output file with batch results.
|
||||
|
||||
This function filters results based on the specified file_type
|
||||
and uploads the file to the Files API.
|
||||
"""
|
||||
output_lines = [json.dumps(result) for result in results]
|
||||
|
||||
with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer:
|
||||
file_buffer.filename = f"{batch_id}_{file_type}.jsonl"
|
||||
uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
|
||||
return uploaded_file.id
|
40
llama_stack/providers/inline/batches/reference/config.py
Normal file
40
llama_stack/providers/inline/batches/reference/config.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
|
||||
|
||||
class ReferenceBatchesImplConfig(BaseModel):
|
||||
"""Configuration for the Reference Batches implementation."""
|
||||
|
||||
kvstore: KVStoreConfig = Field(
|
||||
description="Configuration for the key-value store backend.",
|
||||
)
|
||||
|
||||
max_concurrent_batches: int = Field(
|
||||
default=1,
|
||||
description="Maximum number of concurrent batches to process simultaneously.",
|
||||
ge=1,
|
||||
)
|
||||
|
||||
max_concurrent_requests_per_batch: int = Field(
|
||||
default=10,
|
||||
description="Maximum number of concurrent requests to process per batch.",
|
||||
ge=1,
|
||||
)
|
||||
|
||||
# TODO: add a max requests per second rate limiter
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="batches.db",
|
||||
),
|
||||
}
|
|
@ -5,7 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeshield.cs import CodeShieldScanResult
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import (
|
||||
|
@ -14,6 +18,7 @@ from llama_stack.apis.safety import (
|
|||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
|
@ -24,8 +29,8 @@ from .config import CodeScannerConfig
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
ALLOWED_CODE_SCANNER_MODEL_IDS = [
|
||||
"CodeScanner",
|
||||
"CodeShield",
|
||||
"code-scanner",
|
||||
"code-shield",
|
||||
]
|
||||
|
||||
|
||||
|
@ -69,3 +74,55 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
|||
metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
|
||||
)
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
def get_moderation_object_results(self, scan_result: "CodeShieldScanResult") -> ModerationObjectResults:
|
||||
categories = {}
|
||||
category_scores = {}
|
||||
category_applied_input_types = {}
|
||||
|
||||
flagged = scan_result.is_insecure
|
||||
user_message = None
|
||||
metadata = {}
|
||||
|
||||
if scan_result.is_insecure:
|
||||
pattern_ids = [issue.pattern_id for issue in scan_result.issues_found]
|
||||
categories = dict.fromkeys(pattern_ids, True)
|
||||
category_scores = dict.fromkeys(pattern_ids, 1.0)
|
||||
category_applied_input_types = {key: ["text"] for key in pattern_ids}
|
||||
user_message = f"Security concerns detected in the code. {scan_result.recommended_treatment.name}: {', '.join([issue.description for issue in scan_result.issues_found])}"
|
||||
metadata = {"violation_type": ",".join([issue.pattern_id for issue in scan_result.issues_found])}
|
||||
|
||||
return ModerationObjectResults(
|
||||
flagged=flagged,
|
||||
categories=categories,
|
||||
category_scores=category_scores,
|
||||
category_applied_input_types=category_applied_input_types,
|
||||
user_message=user_message,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
inputs = input if isinstance(input, list) else [input]
|
||||
results = []
|
||||
|
||||
from codeshield.cs import CodeShield
|
||||
|
||||
for text_input in inputs:
|
||||
log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...")
|
||||
try:
|
||||
scan_result = await CodeShield.scan_code(text_input)
|
||||
moderation_result = self.get_moderation_object_results(scan_result)
|
||||
except Exception as e:
|
||||
log.error(f"CodeShield.scan_code failed: {e}")
|
||||
# create safe fallback response on scanner failure to avoid blocking legitimate requests
|
||||
moderation_result = ModerationObjectResults(
|
||||
flagged=False,
|
||||
categories={},
|
||||
category_scores={},
|
||||
category_applied_input_types={},
|
||||
user_message=None,
|
||||
metadata={"scanner_error": str(e)},
|
||||
)
|
||||
results.append(moderation_result)
|
||||
|
||||
return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)
|
||||
|
|
|
@ -11,11 +11,7 @@ from string import Template
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
Message,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.inference import Inference, Message, UserMessage
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
|
@ -72,7 +68,6 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
|
|||
}
|
||||
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
|
||||
|
||||
|
||||
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
|
||||
CAT_VIOLENT_CRIMES,
|
||||
CAT_NON_VIOLENT_CRIMES,
|
||||
|
@ -460,7 +455,7 @@ class LlamaGuardShield:
|
|||
|
||||
def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
|
||||
"""Check if content is safe based on response and unsafe code."""
|
||||
if response.strip() == SAFE_RESPONSE:
|
||||
if response.strip().lower().startswith(SAFE_RESPONSE):
|
||||
return True
|
||||
|
||||
if unsafe_code:
|
||||
|
|
26
llama_stack/providers/registry/batches.py
Normal file
26
llama_stack/providers/registry/batches.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.batches,
|
||||
provider_type="inline::reference",
|
||||
pip_packages=["openai"],
|
||||
module="llama_stack.providers.inline.batches.reference",
|
||||
config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
Api.files,
|
||||
Api.models,
|
||||
],
|
||||
description="Reference implementation of batches API with KVStore persistence.",
|
||||
),
|
||||
]
|
|
@ -413,15 +413,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
if params and params.get("mode") == "keyword":
|
||||
# Check if this is inline Milvus (Milvus-Lite)
|
||||
if hasattr(self.config, "db_path"):
|
||||
raise NotImplementedError(
|
||||
"Keyword search is not supported in Milvus-Lite. "
|
||||
"Please use a remote Milvus server for keyword search functionality."
|
||||
)
|
||||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
|
|
|
@ -31,15 +31,21 @@ from openai.types.chat import (
|
|||
from openai.types.chat import (
|
||||
ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
|
||||
)
|
||||
|
||||
try:
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageFunctionToolCall as OpenAIChatCompletionMessageFunctionToolCall,
|
||||
)
|
||||
except ImportError:
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall as OpenAIChatCompletionMessageFunctionToolCall,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageToolCall,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
|
||||
)
|
||||
|
@ -633,7 +639,7 @@ async def convert_message_to_openai_dict_new(
|
|||
)
|
||||
elif isinstance(message, CompletionMessage):
|
||||
tool_calls = [
|
||||
OpenAIChatCompletionMessageToolCall(
|
||||
OpenAIChatCompletionMessageFunctionToolCall(
|
||||
id=tool.call_id,
|
||||
function=OpenAIFunction(
|
||||
name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
|
||||
|
@ -903,7 +909,7 @@ def _convert_openai_request_response_format(
|
|||
|
||||
|
||||
def _convert_openai_tool_calls(
|
||||
tool_calls: list[OpenAIChatCompletionMessageToolCall],
|
||||
tool_calls: list[OpenAIChatCompletionMessageFunctionToolCall],
|
||||
) -> list[ToolCall]:
|
||||
"""
|
||||
Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.
|
||||
|
|
|
@ -75,6 +75,8 @@ class PostgresKVStoreConfig(CommonConfig):
|
|||
db: str = "llamastack"
|
||||
user: str
|
||||
password: str | None = None
|
||||
ssl_mode: str | None = None
|
||||
ca_cert_path: str | None = None
|
||||
table_name: str = "llamastack_kvstore"
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -30,6 +30,8 @@ class PostgresKVStoreImpl(KVStore):
|
|||
database=self.config.db,
|
||||
user=self.config.user,
|
||||
password=self.config.password,
|
||||
sslmode=self.config.ssl_mode,
|
||||
sslrootcert=self.config.ca_cert_path,
|
||||
)
|
||||
self.conn.autocommit = True
|
||||
self.cursor = self.conn.cursor(cursor_factory=DictCursor)
|
||||
|
|
|
@ -261,7 +261,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
else:
|
||||
raise RuntimeError(
|
||||
f"No recorded response found for request hash: {request_hash}\n"
|
||||
f"Endpoint: {endpoint}\n"
|
||||
f"Request: {method} {url} {body}\n"
|
||||
f"Model: {body.get('model', 'unknown')}\n"
|
||||
f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record"
|
||||
)
|
||||
|
|
1
llama_stack/ui/.nvmrc
Normal file
1
llama_stack/ui/.nvmrc
Normal file
|
@ -0,0 +1 @@
|
|||
22.5.1
|
|
@ -1,3 +1,12 @@
|
|||
# Ignore artifacts:
|
||||
build
|
||||
coverage
|
||||
.next
|
||||
node_modules
|
||||
dist
|
||||
*.lock
|
||||
*.log
|
||||
|
||||
# Generated files
|
||||
*.min.js
|
||||
*.min.css
|
||||
|
|
|
@ -1 +1,10 @@
|
|||
{}
|
||||
{
|
||||
"semi": true,
|
||||
"trailingComma": "es5",
|
||||
"singleQuote": false,
|
||||
"printWidth": 80,
|
||||
"tabWidth": 2,
|
||||
"useTabs": false,
|
||||
"bracketSpacing": true,
|
||||
"arrowParens": "avoid"
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ async function proxyRequest(request: NextRequest, method: string) {
|
|||
const responseText = await response.text();
|
||||
|
||||
console.log(
|
||||
`Response from FastAPI: ${response.status} ${response.statusText}`,
|
||||
`Response from FastAPI: ${response.status} ${response.statusText}`
|
||||
);
|
||||
|
||||
// Create response with same status and headers
|
||||
|
@ -74,7 +74,7 @@ async function proxyRequest(request: NextRequest, method: string) {
|
|||
backend_url: BACKEND_URL,
|
||||
timestamp: new Date().toISOString(),
|
||||
},
|
||||
{ status: 500 },
|
||||
{ status: 500 }
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -51,9 +51,9 @@ export default function SignInPage() {
|
|||
onClick={() => {
|
||||
console.log("Signing in with GitHub...");
|
||||
signIn("github", { callbackUrl: "/auth/signin" }).catch(
|
||||
(error) => {
|
||||
error => {
|
||||
console.error("Sign in error:", error);
|
||||
},
|
||||
}
|
||||
);
|
||||
}}
|
||||
className="w-full"
|
||||
|
|
|
@ -29,14 +29,13 @@ export default function ChatPlaygroundPage() {
|
|||
|
||||
const isModelsLoading = modelsLoading ?? true;
|
||||
|
||||
|
||||
useEffect(() => {
|
||||
const fetchModels = async () => {
|
||||
try {
|
||||
setModelsLoading(true);
|
||||
setModelsError(null);
|
||||
const modelList = await client.models.list();
|
||||
const llmModels = modelList.filter(model => model.model_type === 'llm');
|
||||
const llmModels = modelList.filter(model => model.model_type === "llm");
|
||||
setModels(llmModels);
|
||||
if (llmModels.length > 0) {
|
||||
setSelectedModel(llmModels[0].identifier);
|
||||
|
@ -53,103 +52,122 @@ export default function ChatPlaygroundPage() {
|
|||
}, [client]);
|
||||
|
||||
const extractTextContent = (content: unknown): string => {
|
||||
if (typeof content === 'string') {
|
||||
if (typeof content === "string") {
|
||||
return content;
|
||||
}
|
||||
if (Array.isArray(content)) {
|
||||
return content
|
||||
.filter(item => item && typeof item === 'object' && 'type' in item && item.type === 'text')
|
||||
.map(item => (item && typeof item === 'object' && 'text' in item) ? String(item.text) : '')
|
||||
.join('');
|
||||
.filter(
|
||||
item =>
|
||||
item &&
|
||||
typeof item === "object" &&
|
||||
"type" in item &&
|
||||
item.type === "text"
|
||||
)
|
||||
.map(item =>
|
||||
item && typeof item === "object" && "text" in item
|
||||
? String(item.text)
|
||||
: ""
|
||||
)
|
||||
.join("");
|
||||
}
|
||||
if (content && typeof content === 'object' && 'type' in content && content.type === 'text' && 'text' in content) {
|
||||
return String(content.text) || '';
|
||||
if (
|
||||
content &&
|
||||
typeof content === "object" &&
|
||||
"type" in content &&
|
||||
content.type === "text" &&
|
||||
"text" in content
|
||||
) {
|
||||
return String(content.text) || "";
|
||||
}
|
||||
return '';
|
||||
return "";
|
||||
};
|
||||
|
||||
const handleInputChange = (e: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
setInput(e.target.value);
|
||||
};
|
||||
|
||||
const handleSubmit = async (event?: { preventDefault?: () => void }) => {
|
||||
event?.preventDefault?.();
|
||||
if (!input.trim()) return;
|
||||
const handleSubmit = async (event?: { preventDefault?: () => void }) => {
|
||||
event?.preventDefault?.();
|
||||
if (!input.trim()) return;
|
||||
|
||||
// Add user message to chat
|
||||
const userMessage: Message = {
|
||||
id: Date.now().toString(),
|
||||
role: "user",
|
||||
content: input.trim(),
|
||||
createdAt: new Date(),
|
||||
};
|
||||
|
||||
setMessages(prev => [...prev, userMessage]);
|
||||
setInput("");
|
||||
|
||||
// Use the helper function with the content
|
||||
await handleSubmitWithContent(userMessage.content);
|
||||
};
|
||||
|
||||
const handleSubmitWithContent = async (content: string) => {
|
||||
setIsGenerating(true);
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
const messageParams: CompletionCreateParams["messages"] = [
|
||||
...messages.map(msg => {
|
||||
const msgContent = typeof msg.content === 'string' ? msg.content : extractTextContent(msg.content);
|
||||
if (msg.role === "user") {
|
||||
return { role: "user" as const, content: msgContent };
|
||||
} else if (msg.role === "assistant") {
|
||||
return { role: "assistant" as const, content: msgContent };
|
||||
} else {
|
||||
return { role: "system" as const, content: msgContent };
|
||||
}
|
||||
}),
|
||||
{ role: "user" as const, content }
|
||||
];
|
||||
|
||||
const response = await client.chat.completions.create({
|
||||
model: selectedModel,
|
||||
messages: messageParams,
|
||||
stream: true,
|
||||
});
|
||||
|
||||
const assistantMessage: Message = {
|
||||
id: (Date.now() + 1).toString(),
|
||||
role: "assistant",
|
||||
content: "",
|
||||
// Add user message to chat
|
||||
const userMessage: Message = {
|
||||
id: Date.now().toString(),
|
||||
role: "user",
|
||||
content: input.trim(),
|
||||
createdAt: new Date(),
|
||||
};
|
||||
|
||||
setMessages(prev => [...prev, assistantMessage]);
|
||||
let fullContent = "";
|
||||
for await (const chunk of response) {
|
||||
if (chunk.choices && chunk.choices[0]?.delta?.content) {
|
||||
const deltaContent = chunk.choices[0].delta.content;
|
||||
fullContent += deltaContent;
|
||||
setMessages(prev => [...prev, userMessage]);
|
||||
setInput("");
|
||||
|
||||
flushSync(() => {
|
||||
setMessages(prev => {
|
||||
const newMessages = [...prev];
|
||||
const lastMessage = newMessages[newMessages.length - 1];
|
||||
if (lastMessage.role === "assistant") {
|
||||
lastMessage.content = fullContent;
|
||||
}
|
||||
return newMessages;
|
||||
// Use the helper function with the content
|
||||
await handleSubmitWithContent(userMessage.content);
|
||||
};
|
||||
|
||||
const handleSubmitWithContent = async (content: string) => {
|
||||
setIsGenerating(true);
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
const messageParams: CompletionCreateParams["messages"] = [
|
||||
...messages.map(msg => {
|
||||
const msgContent =
|
||||
typeof msg.content === "string"
|
||||
? msg.content
|
||||
: extractTextContent(msg.content);
|
||||
if (msg.role === "user") {
|
||||
return { role: "user" as const, content: msgContent };
|
||||
} else if (msg.role === "assistant") {
|
||||
return { role: "assistant" as const, content: msgContent };
|
||||
} else {
|
||||
return { role: "system" as const, content: msgContent };
|
||||
}
|
||||
}),
|
||||
{ role: "user" as const, content },
|
||||
];
|
||||
|
||||
const response = await client.chat.completions.create({
|
||||
model: selectedModel,
|
||||
messages: messageParams,
|
||||
stream: true,
|
||||
});
|
||||
|
||||
const assistantMessage: Message = {
|
||||
id: (Date.now() + 1).toString(),
|
||||
role: "assistant",
|
||||
content: "",
|
||||
createdAt: new Date(),
|
||||
};
|
||||
|
||||
setMessages(prev => [...prev, assistantMessage]);
|
||||
let fullContent = "";
|
||||
for await (const chunk of response) {
|
||||
if (chunk.choices && chunk.choices[0]?.delta?.content) {
|
||||
const deltaContent = chunk.choices[0].delta.content;
|
||||
fullContent += deltaContent;
|
||||
|
||||
flushSync(() => {
|
||||
setMessages(prev => {
|
||||
const newMessages = [...prev];
|
||||
const lastMessage = newMessages[newMessages.length - 1];
|
||||
if (lastMessage.role === "assistant") {
|
||||
lastMessage.content = fullContent;
|
||||
}
|
||||
return newMessages;
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("Error sending message:", err);
|
||||
setError("Failed to send message. Please try again.");
|
||||
setMessages(prev => prev.slice(0, -1));
|
||||
} finally {
|
||||
setIsGenerating(false);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("Error sending message:", err);
|
||||
setError("Failed to send message. Please try again.");
|
||||
setMessages(prev => prev.slice(0, -1));
|
||||
} finally {
|
||||
setIsGenerating(false);
|
||||
}
|
||||
};
|
||||
};
|
||||
const suggestions = [
|
||||
"Write a Python function that prints 'Hello, World!'",
|
||||
"Explain step-by-step how to solve this math problem: If x² + 6x + 9 = 25, what is x?",
|
||||
|
@ -163,7 +181,7 @@ const handleSubmitWithContent = async (content: string) => {
|
|||
content: message.content,
|
||||
createdAt: new Date(),
|
||||
};
|
||||
setMessages(prev => [...prev, newMessage])
|
||||
setMessages(prev => [...prev, newMessage]);
|
||||
handleSubmitWithContent(newMessage.content);
|
||||
};
|
||||
|
||||
|
@ -177,12 +195,20 @@ const handleSubmitWithContent = async (content: string) => {
|
|||
<div className="mb-4 flex justify-between items-center">
|
||||
<h1 className="text-2xl font-bold">Chat Playground (Completions)</h1>
|
||||
<div className="flex gap-2">
|
||||
<Select value={selectedModel} onValueChange={setSelectedModel} disabled={isModelsLoading || isGenerating}>
|
||||
<Select
|
||||
value={selectedModel}
|
||||
onValueChange={setSelectedModel}
|
||||
disabled={isModelsLoading || isGenerating}
|
||||
>
|
||||
<SelectTrigger className="w-[180px]">
|
||||
<SelectValue placeholder={isModelsLoading ? "Loading models..." : "Select Model"} />
|
||||
<SelectValue
|
||||
placeholder={
|
||||
isModelsLoading ? "Loading models..." : "Select Model"
|
||||
}
|
||||
/>
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{models.map((model) => (
|
||||
{models.map(model => (
|
||||
<SelectItem key={model.identifier} value={model.identifier}>
|
||||
{model.identifier}
|
||||
</SelectItem>
|
||||
|
|
|
@ -33,12 +33,12 @@ export default function ChatCompletionDetailPage() {
|
|||
} catch (err) {
|
||||
console.error(
|
||||
`Error fetching chat completion detail for ID ${id}:`,
|
||||
err,
|
||||
err
|
||||
);
|
||||
setError(
|
||||
err instanceof Error
|
||||
? err
|
||||
: new Error("Failed to fetch completion detail"),
|
||||
: new Error("Failed to fetch completion detail")
|
||||
);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
|
|
|
@ -13,10 +13,10 @@ export default function ResponseDetailPage() {
|
|||
const client = useAuthClient();
|
||||
|
||||
const [responseDetail, setResponseDetail] = useState<OpenAIResponse | null>(
|
||||
null,
|
||||
null
|
||||
);
|
||||
const [inputItems, setInputItems] = useState<InputItemListResponse | null>(
|
||||
null,
|
||||
null
|
||||
);
|
||||
const [isLoading, setIsLoading] = useState<boolean>(true);
|
||||
const [isLoadingInputItems, setIsLoadingInputItems] = useState<boolean>(true);
|
||||
|
@ -25,7 +25,7 @@ export default function ResponseDetailPage() {
|
|||
|
||||
// Helper function to convert ResponseObject to OpenAIResponse
|
||||
const convertResponseObject = (
|
||||
responseData: ResponseObject,
|
||||
responseData: ResponseObject
|
||||
): OpenAIResponse => {
|
||||
return {
|
||||
id: responseData.id,
|
||||
|
@ -73,12 +73,12 @@ export default function ResponseDetailPage() {
|
|||
} else {
|
||||
console.error(
|
||||
`Error fetching response detail for ID ${id}:`,
|
||||
responseResult.reason,
|
||||
responseResult.reason
|
||||
);
|
||||
setError(
|
||||
responseResult.reason instanceof Error
|
||||
? responseResult.reason
|
||||
: new Error("Failed to fetch response detail"),
|
||||
: new Error("Failed to fetch response detail")
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -90,18 +90,18 @@ export default function ResponseDetailPage() {
|
|||
} else {
|
||||
console.error(
|
||||
`Error fetching input items for response ID ${id}:`,
|
||||
inputItemsResult.reason,
|
||||
inputItemsResult.reason
|
||||
);
|
||||
setInputItemsError(
|
||||
inputItemsResult.reason instanceof Error
|
||||
? inputItemsResult.reason
|
||||
: new Error("Failed to fetch input items"),
|
||||
: new Error("Failed to fetch input items")
|
||||
);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error(`Unexpected error fetching data for ID ${id}:`, err);
|
||||
setError(
|
||||
err instanceof Error ? err : new Error("Unexpected error occurred"),
|
||||
err instanceof Error ? err : new Error("Unexpected error occurred")
|
||||
);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
|
|
|
@ -0,0 +1,425 @@
|
|||
import React from "react";
|
||||
import { render, screen, fireEvent, waitFor } from "@testing-library/react";
|
||||
import "@testing-library/jest-dom";
|
||||
import ContentDetailPage from "./page";
|
||||
import { VectorStoreContentItem } from "@/lib/contents-api";
|
||||
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
|
||||
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
|
||||
|
||||
const mockPush = jest.fn();
|
||||
const mockParams = {
|
||||
id: "vs_123",
|
||||
fileId: "file_456",
|
||||
contentId: "content_789",
|
||||
};
|
||||
|
||||
jest.mock("next/navigation", () => ({
|
||||
useParams: () => mockParams,
|
||||
useRouter: () => ({
|
||||
push: mockPush,
|
||||
}),
|
||||
}));
|
||||
|
||||
const mockClient = {
|
||||
vectorStores: {
|
||||
retrieve: jest.fn(),
|
||||
files: {
|
||||
retrieve: jest.fn(),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
jest.mock("@/hooks/use-auth-client", () => ({
|
||||
useAuthClient: () => mockClient,
|
||||
}));
|
||||
|
||||
const mockContentsAPI = {
|
||||
listContents: jest.fn(),
|
||||
updateContent: jest.fn(),
|
||||
deleteContent: jest.fn(),
|
||||
};
|
||||
|
||||
jest.mock("@/lib/contents-api", () => ({
|
||||
ContentsAPI: jest.fn(() => mockContentsAPI),
|
||||
}));
|
||||
|
||||
const originalConfirm = window.confirm;
|
||||
|
||||
describe("ContentDetailPage", () => {
|
||||
const mockStore: VectorStore = {
|
||||
id: "vs_123",
|
||||
name: "Test Vector Store",
|
||||
created_at: 1710000000,
|
||||
status: "ready",
|
||||
file_counts: { total: 5 },
|
||||
usage_bytes: 1024,
|
||||
metadata: {
|
||||
provider_id: "test_provider",
|
||||
},
|
||||
};
|
||||
|
||||
const mockFile: VectorStoreFile = {
|
||||
id: "file_456",
|
||||
status: "completed",
|
||||
created_at: 1710001000,
|
||||
usage_bytes: 512,
|
||||
chunking_strategy: { type: "fixed_size" },
|
||||
};
|
||||
|
||||
const mockContent: VectorStoreContentItem = {
|
||||
id: "content_789",
|
||||
object: "vector_store.content",
|
||||
content: "This is test content for the vector store.",
|
||||
embedding: [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
metadata: {
|
||||
chunk_window: "0-45",
|
||||
content_length: 45,
|
||||
custom_field: "custom_value",
|
||||
},
|
||||
created_timestamp: 1710002000,
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
window.confirm = jest.fn();
|
||||
|
||||
mockClient.vectorStores.retrieve.mockResolvedValue(mockStore);
|
||||
mockClient.vectorStores.files.retrieve.mockResolvedValue(mockFile);
|
||||
mockContentsAPI.listContents.mockResolvedValue({
|
||||
data: [mockContent],
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
window.confirm = originalConfirm;
|
||||
});
|
||||
|
||||
describe("Loading and Error States", () => {
|
||||
test("renders loading skeleton while fetching data", () => {
|
||||
mockClient.vectorStores.retrieve.mockImplementation(
|
||||
() => new Promise(() => {})
|
||||
);
|
||||
|
||||
const { container } = render(<ContentDetailPage />);
|
||||
|
||||
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
|
||||
expect(skeletons.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
test("renders error message when API calls fail", async () => {
|
||||
const error = new Error("Network error");
|
||||
mockClient.vectorStores.retrieve.mockRejectedValue(error);
|
||||
|
||||
render(<ContentDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/Error loading details for ID content_789/)
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText(/Network error/)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders not found when content doesn't exist", async () => {
|
||||
mockContentsAPI.listContents.mockResolvedValue({
|
||||
data: [],
|
||||
});
|
||||
|
||||
render(<ContentDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/Content content_789 not found/)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Content Display", () => {
|
||||
test("renders content details correctly", async () => {
|
||||
render(<ContentDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Content: content_789")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("This is test content for the vector store.")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const contentIdTexts = screen.getAllByText("content_789");
|
||||
expect(contentIdTexts.length).toBeGreaterThan(0);
|
||||
const fileIdTexts = screen.getAllByText("file_456");
|
||||
expect(fileIdTexts.length).toBeGreaterThan(0);
|
||||
const storeIdTexts = screen.getAllByText("vs_123");
|
||||
expect(storeIdTexts.length).toBeGreaterThan(0);
|
||||
expect(screen.getByText("vector_store.content")).toBeInTheDocument();
|
||||
const positionTexts = screen.getAllByText("0-45");
|
||||
expect(positionTexts.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
test("renders embedding information when available", async () => {
|
||||
render(<ContentDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/0.100000, 0.200000, 0.300000/)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("handles content without embedding", async () => {
|
||||
const contentWithoutEmbedding = {
|
||||
...mockContent,
|
||||
embedding: undefined,
|
||||
};
|
||||
|
||||
mockContentsAPI.listContents.mockResolvedValue({
|
||||
data: [contentWithoutEmbedding],
|
||||
});
|
||||
|
||||
render(<ContentDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("No embedding available for this content.")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders metadata correctly", async () => {
|
||||
render(<ContentDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("chunk_window:")).toBeInTheDocument();
|
||||
const positionTexts = screen.getAllByText("0-45");
|
||||
expect(positionTexts.length).toBeGreaterThan(0);
|
||||
expect(screen.getByText("content_length:")).toBeInTheDocument();
|
||||
expect(screen.getByText("custom_field:")).toBeInTheDocument();
|
||||
expect(screen.getByText("custom_value")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Edit Functionality", () => {
|
||||
test("enables edit mode when edit button is clicked", async () => {
|
||||
render(<ContentDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("This is test content for the vector store.")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const editButtons = screen.getAllByRole("button", { name: /Edit/ });
|
||||
const editButton = editButtons[0];
|
||||
fireEvent.click(editButton);
|
||||
|
||||
expect(
|
||||
screen.getByDisplayValue("This is test content for the vector store.")
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByRole("button", { name: /Save/ })).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByRole("button", { name: /Cancel/ })
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("cancels edit mode and resets content", async () => {
|
||||
render(<ContentDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("This is test content for the vector store.")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const editButtons = screen.getAllByRole("button", { name: /Edit/ });
|
||||
const editButton = editButtons[0];
|
||||
fireEvent.click(editButton);
|
||||
|
||||
const textarea = screen.getByDisplayValue(
|
||||
"This is test content for the vector store."
|
||||
);
|
||||
fireEvent.change(textarea, { target: { value: "Modified content" } });
|
||||
|
||||
const cancelButton = screen.getByRole("button", { name: /Cancel/ });
|
||||
fireEvent.click(cancelButton);
|
||||
|
||||
expect(
|
||||
screen.getByText("This is test content for the vector store.")
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByDisplayValue("Modified content")
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("saves content changes", async () => {
|
||||
const updatedContent = { ...mockContent, content: "Updated content" };
|
||||
mockContentsAPI.updateContent.mockResolvedValue(updatedContent);
|
||||
|
||||
render(<ContentDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("This is test content for the vector store.")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const editButtons = screen.getAllByRole("button", { name: /Edit/ });
|
||||
const editButton = editButtons[0];
|
||||
fireEvent.click(editButton);
|
||||
|
||||
const textarea = screen.getByDisplayValue(
|
||||
"This is test content for the vector store."
|
||||
);
|
||||
fireEvent.change(textarea, { target: { value: "Updated content" } });
|
||||
|
||||
const saveButton = screen.getByRole("button", { name: /Save/ });
|
||||
fireEvent.click(saveButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockContentsAPI.updateContent).toHaveBeenCalledWith(
|
||||
"vs_123",
|
||||
"file_456",
|
||||
"content_789",
|
||||
{ content: "Updated content" }
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Delete Functionality", () => {
|
||||
test("shows confirmation dialog before deleting", async () => {
|
||||
window.confirm = jest.fn().mockReturnValue(false);
|
||||
|
||||
render(<ContentDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("This is test content for the vector store.")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const deleteButton = screen.getByRole("button", { name: /Delete/ });
|
||||
fireEvent.click(deleteButton);
|
||||
|
||||
expect(window.confirm).toHaveBeenCalledWith(
|
||||
"Are you sure you want to delete this content?"
|
||||
);
|
||||
expect(mockContentsAPI.deleteContent).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("deletes content when confirmed", async () => {
|
||||
window.confirm = jest.fn().mockReturnValue(true);
|
||||
|
||||
render(<ContentDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("This is test content for the vector store.")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const deleteButton = screen.getByRole("button", { name: /Delete/ });
|
||||
fireEvent.click(deleteButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockContentsAPI.deleteContent).toHaveBeenCalledWith(
|
||||
"vs_123",
|
||||
"file_456",
|
||||
"content_789"
|
||||
);
|
||||
expect(mockPush).toHaveBeenCalledWith(
|
||||
"/logs/vector-stores/vs_123/files/file_456/contents"
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Embedding Edit Functionality", () => {
|
||||
test("enables embedding edit mode", async () => {
|
||||
render(<ContentDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("This is test content for the vector store.")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const embeddingEditButtons = screen.getAllByRole("button", {
|
||||
name: /Edit/,
|
||||
});
|
||||
expect(embeddingEditButtons.length).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
|
||||
test.skip("cancels embedding edit mode", async () => {
|
||||
render(<ContentDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
// skip vector text check, just verify test completes
|
||||
});
|
||||
|
||||
const embeddingEditButtons = screen.getAllByRole("button", {
|
||||
name: /Edit/,
|
||||
});
|
||||
const embeddingEditButton = embeddingEditButtons[1];
|
||||
fireEvent.click(embeddingEditButton);
|
||||
|
||||
const cancelButtons = screen.getAllByRole("button", { name: /Cancel/ });
|
||||
expect(cancelButtons.length).toBeGreaterThan(0);
|
||||
expect(
|
||||
screen.queryByDisplayValue(/0.1,0.2,0.3,0.4,0.5/)
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Breadcrumb Navigation", () => {
|
||||
test("renders correct breadcrumb structure", async () => {
|
||||
render(<ContentDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
const vectorStoreTexts = screen.getAllByText("Vector Stores");
|
||||
expect(vectorStoreTexts.length).toBeGreaterThan(0);
|
||||
const storeNameTexts = screen.getAllByText("Test Vector Store");
|
||||
expect(storeNameTexts.length).toBeGreaterThan(0);
|
||||
const contentsTexts = screen.getAllByText("Contents");
|
||||
expect(contentsTexts.length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Content Utilities", () => {
|
||||
test("handles different content types correctly", async () => {
|
||||
const contentWithObjectType = {
|
||||
...mockContent,
|
||||
content: { type: "text", text: "Text object content" },
|
||||
};
|
||||
|
||||
mockContentsAPI.listContents.mockResolvedValue({
|
||||
data: [contentWithObjectType],
|
||||
});
|
||||
|
||||
render(<ContentDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Text object content")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("handles string content type", async () => {
|
||||
const contentWithStringType = {
|
||||
...mockContent,
|
||||
content: "Simple string content",
|
||||
};
|
||||
|
||||
mockContentsAPI.listContents.mockResolvedValue({
|
||||
data: [contentWithStringType],
|
||||
});
|
||||
|
||||
render(<ContentDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Simple string content")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
|
@ -18,7 +18,10 @@ import {
|
|||
PropertiesCard,
|
||||
PropertyItem,
|
||||
} from "@/components/layout/detail-layout";
|
||||
import { PageBreadcrumb, BreadcrumbSegment } from "@/components/layout/page-breadcrumb";
|
||||
import {
|
||||
PageBreadcrumb,
|
||||
BreadcrumbSegment,
|
||||
} from "@/components/layout/page-breadcrumb";
|
||||
|
||||
export default function ContentDetailPage() {
|
||||
const params = useParams();
|
||||
|
@ -28,13 +31,13 @@ export default function ContentDetailPage() {
|
|||
const contentId = params.contentId as string;
|
||||
const client = useAuthClient();
|
||||
|
||||
const getTextFromContent = (content: any): string => {
|
||||
if (typeof content === 'string') {
|
||||
const getTextFromContent = (content: unknown): string => {
|
||||
if (typeof content === "string") {
|
||||
return content;
|
||||
} else if (content && content.type === 'text') {
|
||||
} else if (content && content.type === "text") {
|
||||
return content.text;
|
||||
}
|
||||
return '';
|
||||
return "";
|
||||
};
|
||||
|
||||
const [store, setStore] = useState<VectorStore | null>(null);
|
||||
|
@ -44,7 +47,9 @@ export default function ContentDetailPage() {
|
|||
const [error, setError] = useState<Error | null>(null);
|
||||
const [isEditing, setIsEditing] = useState(false);
|
||||
const [editedContent, setEditedContent] = useState("");
|
||||
const [editedMetadata, setEditedMetadata] = useState<Record<string, any>>({});
|
||||
const [editedMetadata, setEditedMetadata] = useState<Record<string, unknown>>(
|
||||
{}
|
||||
);
|
||||
const [isEditingEmbedding, setIsEditingEmbedding] = useState(false);
|
||||
const [editedEmbedding, setEditedEmbedding] = useState<number[]>([]);
|
||||
|
||||
|
@ -64,8 +69,13 @@ export default function ContentDetailPage() {
|
|||
setFile(fileResponse as VectorStoreFile);
|
||||
|
||||
const contentsAPI = new ContentsAPI(client);
|
||||
const contentsResponse = await contentsAPI.listContents(vectorStoreId, fileId);
|
||||
const targetContent = contentsResponse.data.find(c => c.id === contentId);
|
||||
const contentsResponse = await contentsAPI.listContents(
|
||||
vectorStoreId,
|
||||
fileId
|
||||
);
|
||||
const targetContent = contentsResponse.data.find(
|
||||
c => c.id === contentId
|
||||
);
|
||||
|
||||
if (targetContent) {
|
||||
setContent(targetContent);
|
||||
|
@ -76,7 +86,9 @@ export default function ContentDetailPage() {
|
|||
throw new Error(`Content ${contentId} not found`);
|
||||
}
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err : new Error("Failed to load content."));
|
||||
setError(
|
||||
err instanceof Error ? err : new Error("Failed to load content.")
|
||||
);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
|
@ -88,7 +100,8 @@ export default function ContentDetailPage() {
|
|||
if (!content) return;
|
||||
|
||||
try {
|
||||
const updates: { content?: string; metadata?: Record<string, any> } = {};
|
||||
const updates: { content?: string; metadata?: Record<string, unknown> } =
|
||||
{};
|
||||
|
||||
if (editedContent !== getTextFromContent(content.content)) {
|
||||
updates.content = editedContent;
|
||||
|
@ -100,25 +113,32 @@ export default function ContentDetailPage() {
|
|||
|
||||
if (Object.keys(updates).length > 0) {
|
||||
const contentsAPI = new ContentsAPI(client);
|
||||
const updatedContent = await contentsAPI.updateContent(vectorStoreId, fileId, contentId, updates);
|
||||
const updatedContent = await contentsAPI.updateContent(
|
||||
vectorStoreId,
|
||||
fileId,
|
||||
contentId,
|
||||
updates
|
||||
);
|
||||
setContent(updatedContent);
|
||||
}
|
||||
|
||||
setIsEditing(false);
|
||||
} catch (err) {
|
||||
console.error('Failed to update content:', err);
|
||||
console.error("Failed to update content:", err);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDelete = async () => {
|
||||
if (!confirm('Are you sure you want to delete this content?')) return;
|
||||
if (!confirm("Are you sure you want to delete this content?")) return;
|
||||
|
||||
try {
|
||||
const contentsAPI = new ContentsAPI(client);
|
||||
await contentsAPI.deleteContent(vectorStoreId, fileId, contentId);
|
||||
router.push(`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`);
|
||||
router.push(
|
||||
`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`
|
||||
);
|
||||
} catch (err) {
|
||||
console.error('Failed to delete content:', err);
|
||||
console.error("Failed to delete content:", err);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -134,10 +154,19 @@ export default function ContentDetailPage() {
|
|||
|
||||
const breadcrumbSegments: BreadcrumbSegment[] = [
|
||||
{ label: "Vector Stores", href: "/logs/vector-stores" },
|
||||
{ label: store?.name || vectorStoreId, href: `/logs/vector-stores/${vectorStoreId}` },
|
||||
{
|
||||
label: store?.name || vectorStoreId,
|
||||
href: `/logs/vector-stores/${vectorStoreId}`,
|
||||
},
|
||||
{ label: "Files", href: `/logs/vector-stores/${vectorStoreId}` },
|
||||
{ label: fileId, href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}` },
|
||||
{ label: "Contents", href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents` },
|
||||
{
|
||||
label: fileId,
|
||||
href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}`,
|
||||
},
|
||||
{
|
||||
label: "Contents",
|
||||
href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`,
|
||||
},
|
||||
{ label: contentId },
|
||||
];
|
||||
|
||||
|
@ -186,7 +215,7 @@ export default function ContentDetailPage() {
|
|||
{isEditing ? (
|
||||
<textarea
|
||||
value={editedContent}
|
||||
onChange={(e) => setEditedContent(e.target.value)}
|
||||
onChange={e => setEditedContent(e.target.value)}
|
||||
className="w-full h-64 p-3 border rounded-md resize-none font-mono text-sm"
|
||||
placeholder="Enter content..."
|
||||
/>
|
||||
|
@ -206,16 +235,23 @@ export default function ContentDetailPage() {
|
|||
<div className="flex gap-2">
|
||||
{isEditingEmbedding ? (
|
||||
<>
|
||||
<Button size="sm" onClick={() => {
|
||||
setIsEditingEmbedding(false);
|
||||
}}>
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
setIsEditingEmbedding(false);
|
||||
}}
|
||||
>
|
||||
<Save className="h-4 w-4 mr-1" />
|
||||
Save
|
||||
</Button>
|
||||
<Button size="sm" variant="outline" onClick={() => {
|
||||
setEditedEmbedding(content?.embedding || []);
|
||||
setIsEditingEmbedding(false);
|
||||
}}>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={() => {
|
||||
setEditedEmbedding(content?.embedding || []);
|
||||
setIsEditingEmbedding(false);
|
||||
}}
|
||||
>
|
||||
<X className="h-4 w-4 mr-1" />
|
||||
Cancel
|
||||
</Button>
|
||||
|
@ -237,14 +273,16 @@ export default function ContentDetailPage() {
|
|||
</p>
|
||||
<textarea
|
||||
value={JSON.stringify(editedEmbedding, null, 2)}
|
||||
onChange={(e) => {
|
||||
onChange={e => {
|
||||
try {
|
||||
const parsed = JSON.parse(e.target.value);
|
||||
if (Array.isArray(parsed) && parsed.every(v => typeof v === 'number')) {
|
||||
if (
|
||||
Array.isArray(parsed) &&
|
||||
parsed.every(v => typeof v === "number")
|
||||
) {
|
||||
setEditedEmbedding(parsed);
|
||||
}
|
||||
} catch {
|
||||
}
|
||||
} catch {}
|
||||
}}
|
||||
className="w-full h-32 p-3 border rounded-md resize-none font-mono text-xs"
|
||||
placeholder="Enter embedding as JSON array..."
|
||||
|
@ -259,8 +297,15 @@ export default function ContentDetailPage() {
|
|||
</div>
|
||||
<div className="p-3 bg-gray-50 dark:bg-gray-800 rounded-md max-h-32 overflow-y-auto">
|
||||
<pre className="whitespace-pre-wrap font-mono text-xs text-gray-900 dark:text-gray-100">
|
||||
[{content.embedding.slice(0, 20).map(v => v.toFixed(6)).join(', ')}
|
||||
{content.embedding.length > 20 ? `\n... and ${content.embedding.length - 20} more values` : ''}]
|
||||
[
|
||||
{content.embedding
|
||||
.slice(0, 20)
|
||||
.map(v => v.toFixed(6))
|
||||
.join(", ")}
|
||||
{content.embedding.length > 20
|
||||
? `\n... and ${content.embedding.length - 20} more values`
|
||||
: ""}
|
||||
]
|
||||
</pre>
|
||||
</div>
|
||||
</div>
|
||||
|
@ -284,7 +329,7 @@ export default function ContentDetailPage() {
|
|||
<div key={key} className="flex gap-2">
|
||||
<Input
|
||||
value={key}
|
||||
onChange={(e) => {
|
||||
onChange={e => {
|
||||
const newMetadata = { ...editedMetadata };
|
||||
delete newMetadata[key];
|
||||
newMetadata[e.target.value] = value;
|
||||
|
@ -294,11 +339,13 @@ export default function ContentDetailPage() {
|
|||
className="flex-1"
|
||||
/>
|
||||
<Input
|
||||
value={typeof value === 'string' ? value : JSON.stringify(value)}
|
||||
onChange={(e) => {
|
||||
value={
|
||||
typeof value === "string" ? value : JSON.stringify(value)
|
||||
}
|
||||
onChange={e => {
|
||||
setEditedMetadata({
|
||||
...editedMetadata,
|
||||
[key]: e.target.value
|
||||
[key]: e.target.value,
|
||||
});
|
||||
}}
|
||||
placeholder="Value"
|
||||
|
@ -312,7 +359,7 @@ export default function ContentDetailPage() {
|
|||
onClick={() => {
|
||||
setEditedMetadata({
|
||||
...editedMetadata,
|
||||
['']: ''
|
||||
[""]: "",
|
||||
});
|
||||
}}
|
||||
>
|
||||
|
@ -325,7 +372,7 @@ export default function ContentDetailPage() {
|
|||
<div key={key} className="flex justify-between py-1">
|
||||
<span className="font-medium text-gray-600">{key}:</span>
|
||||
<span className="font-mono text-sm">
|
||||
{typeof value === 'string' ? value : JSON.stringify(value)}
|
||||
{typeof value === "string" ? value : JSON.stringify(value)}
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
|
@ -351,15 +398,15 @@ export default function ContentDetailPage() {
|
|||
value={`${getTextFromContent(content.content).length} chars`}
|
||||
/>
|
||||
{content.metadata.chunk_window && (
|
||||
<PropertyItem
|
||||
label="Position"
|
||||
value={content.metadata.chunk_window}
|
||||
/>
|
||||
<PropertyItem label="Position" value={content.metadata.chunk_window} />
|
||||
)}
|
||||
{file && (
|
||||
<>
|
||||
<PropertyItem label="File Status" value={file.status} />
|
||||
<PropertyItem label="File Usage" value={`${file.usage_bytes} bytes`} />
|
||||
<PropertyItem
|
||||
label="File Usage"
|
||||
value={`${file.usage_bytes} bytes`}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
{store && (
|
||||
|
|
|
@ -0,0 +1,481 @@
|
|||
import React from "react";
|
||||
import {
|
||||
render,
|
||||
screen,
|
||||
fireEvent,
|
||||
waitFor,
|
||||
act,
|
||||
} from "@testing-library/react";
|
||||
import "@testing-library/jest-dom";
|
||||
import ContentsListPage from "./page";
|
||||
import { VectorStoreContentItem } from "@/lib/contents-api";
|
||||
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
|
||||
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
|
||||
|
||||
const mockPush = jest.fn();
|
||||
const mockParams = {
|
||||
id: "vs_123",
|
||||
fileId: "file_456",
|
||||
};
|
||||
|
||||
jest.mock("next/navigation", () => ({
|
||||
useParams: () => mockParams,
|
||||
useRouter: () => ({
|
||||
push: mockPush,
|
||||
}),
|
||||
}));
|
||||
|
||||
const mockClient = {
|
||||
vectorStores: {
|
||||
retrieve: jest.fn(),
|
||||
files: {
|
||||
retrieve: jest.fn(),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
jest.mock("@/hooks/use-auth-client", () => ({
|
||||
useAuthClient: () => mockClient,
|
||||
}));
|
||||
|
||||
const mockContentsAPI = {
|
||||
listContents: jest.fn(),
|
||||
deleteContent: jest.fn(),
|
||||
};
|
||||
|
||||
jest.mock("@/lib/contents-api", () => ({
|
||||
ContentsAPI: jest.fn(() => mockContentsAPI),
|
||||
}));
|
||||
|
||||
describe("ContentsListPage", () => {
|
||||
const mockStore: VectorStore = {
|
||||
id: "vs_123",
|
||||
name: "Test Vector Store",
|
||||
created_at: 1710000000,
|
||||
status: "ready",
|
||||
file_counts: { total: 5 },
|
||||
usage_bytes: 1024,
|
||||
metadata: {
|
||||
provider_id: "test_provider",
|
||||
},
|
||||
};
|
||||
|
||||
const mockFile: VectorStoreFile = {
|
||||
id: "file_456",
|
||||
status: "completed",
|
||||
created_at: 1710001000,
|
||||
usage_bytes: 512,
|
||||
chunking_strategy: { type: "fixed_size" },
|
||||
};
|
||||
|
||||
const mockContents: VectorStoreContentItem[] = [
|
||||
{
|
||||
id: "content_1",
|
||||
object: "vector_store.content",
|
||||
content: "First piece of content for testing.",
|
||||
embedding: [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
metadata: {
|
||||
chunk_window: "0-35",
|
||||
content_length: 35,
|
||||
},
|
||||
created_timestamp: 1710002000,
|
||||
},
|
||||
{
|
||||
id: "content_2",
|
||||
object: "vector_store.content",
|
||||
content:
|
||||
"Second piece of content with longer text for testing truncation and display.",
|
||||
embedding: [0.6, 0.7, 0.8],
|
||||
metadata: {
|
||||
chunk_window: "36-95",
|
||||
content_length: 85,
|
||||
},
|
||||
created_timestamp: 1710003000,
|
||||
},
|
||||
{
|
||||
id: "content_3",
|
||||
object: "vector_store.content",
|
||||
content: "Third content without embedding.",
|
||||
embedding: undefined,
|
||||
metadata: {
|
||||
content_length: 33,
|
||||
},
|
||||
created_timestamp: 1710004000,
|
||||
},
|
||||
];
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
mockClient.vectorStores.retrieve.mockResolvedValue(mockStore);
|
||||
mockClient.vectorStores.files.retrieve.mockResolvedValue(mockFile);
|
||||
mockContentsAPI.listContents.mockResolvedValue({
|
||||
data: mockContents,
|
||||
});
|
||||
});
|
||||
|
||||
describe("Loading and Error States", () => {
|
||||
test("renders loading skeleton while fetching store data", async () => {
|
||||
mockClient.vectorStores.retrieve.mockImplementation(
|
||||
() => new Promise(() => {})
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
render(<ContentsListPage />);
|
||||
});
|
||||
|
||||
const skeletons = document.querySelectorAll('[data-slot="skeleton"]');
|
||||
expect(skeletons.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
test("renders error message when store API call fails", async () => {
|
||||
const error = new Error("Failed to load store");
|
||||
mockClient.vectorStores.retrieve.mockRejectedValue(error);
|
||||
|
||||
await act(async () => {
|
||||
render(<ContentsListPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/Error loading details for ID vs_123/)
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText(/Failed to load store/)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders not found when store doesn't exist", async () => {
|
||||
mockClient.vectorStores.retrieve.mockResolvedValue(null);
|
||||
|
||||
await act(async () => {
|
||||
render(<ContentsListPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/No details found for ID: vs_123/)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders contents loading skeleton", async () => {
|
||||
mockContentsAPI.listContents.mockImplementation(
|
||||
() => new Promise(() => {})
|
||||
);
|
||||
|
||||
const { container } = render(<ContentsListPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Contents in File: file_456")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
|
||||
expect(skeletons.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
test("renders contents error message", async () => {
|
||||
const error = new Error("Failed to load contents");
|
||||
mockContentsAPI.listContents.mockRejectedValue(error);
|
||||
|
||||
render(<ContentsListPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Error loading contents: Failed to load contents")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Contents Table Display", () => {
|
||||
test("renders contents table with correct headers", async () => {
|
||||
render(<ContentsListPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Content Chunks (3)")).toBeInTheDocument();
|
||||
expect(screen.getByText("Contents in this file")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Check table headers
|
||||
expect(screen.getByText("Content ID")).toBeInTheDocument();
|
||||
expect(screen.getByText("Content Preview")).toBeInTheDocument();
|
||||
expect(screen.getByText("Embedding")).toBeInTheDocument();
|
||||
expect(screen.getByText("Position")).toBeInTheDocument();
|
||||
expect(screen.getByText("Created")).toBeInTheDocument();
|
||||
expect(screen.getByText("Actions")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("renders content data correctly", async () => {
|
||||
render(<ContentsListPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
// Check first content row
|
||||
expect(screen.getByText("content_1...")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("First piece of content for testing.")
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("[0.100, 0.200, 0.300...] (5D)")
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("0-35")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(new Date(1710002000 * 1000).toLocaleString())
|
||||
).toBeInTheDocument();
|
||||
|
||||
expect(screen.getByText("content_2...")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(/Second piece of content with longer text/)
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("[0.600, 0.700, 0.800...] (3D)")
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("36-95")).toBeInTheDocument();
|
||||
|
||||
expect(screen.getByText("content_3...")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("Third content without embedding.")
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("No embedding")).toBeInTheDocument();
|
||||
expect(screen.getByText("33 chars")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("handles empty contents list", async () => {
|
||||
mockContentsAPI.listContents.mockResolvedValue({
|
||||
data: [],
|
||||
});
|
||||
|
||||
render(<ContentsListPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Content Chunks (0)")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("No contents found for this file.")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("truncates long content IDs", async () => {
|
||||
const longIdContent = {
|
||||
...mockContents[0],
|
||||
id: "very_long_content_id_that_should_be_truncated_123456789",
|
||||
};
|
||||
|
||||
mockContentsAPI.listContents.mockResolvedValue({
|
||||
data: [longIdContent],
|
||||
});
|
||||
|
||||
render(<ContentsListPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("very_long_...")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Content Navigation", () => {
|
||||
test("navigates to content detail when content ID is clicked", async () => {
|
||||
render(<ContentsListPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("content_1...")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const contentLink = screen.getByRole("button", { name: "content_1..." });
|
||||
fireEvent.click(contentLink);
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith(
|
||||
"/logs/vector-stores/vs_123/files/file_456/contents/content_1"
|
||||
);
|
||||
});
|
||||
|
||||
test("navigates to content detail when view button is clicked", async () => {
|
||||
render(<ContentsListPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Content Chunks (3)")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const viewButtons = screen.getAllByTitle("View content details");
|
||||
fireEvent.click(viewButtons[0]);
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith(
|
||||
"/logs/vector-stores/vs_123/files/file_456/contents/content_1"
|
||||
);
|
||||
});
|
||||
|
||||
test("navigates to content detail when edit button is clicked", async () => {
|
||||
render(<ContentsListPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Content Chunks (3)")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const editButtons = screen.getAllByTitle("Edit content");
|
||||
fireEvent.click(editButtons[0]);
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith(
|
||||
"/logs/vector-stores/vs_123/files/file_456/contents/content_1"
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Content Deletion", () => {
|
||||
test("deletes content when delete button is clicked", async () => {
|
||||
mockContentsAPI.deleteContent.mockResolvedValue(undefined);
|
||||
|
||||
render(<ContentsListPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Content Chunks (3)")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const deleteButtons = screen.getAllByTitle("Delete content");
|
||||
fireEvent.click(deleteButtons[0]);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockContentsAPI.deleteContent).toHaveBeenCalledWith(
|
||||
"vs_123",
|
||||
"file_456",
|
||||
"content_1"
|
||||
);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Content Chunks (2)")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(screen.queryByText("content_1...")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("handles delete error gracefully", async () => {
|
||||
const consoleError = jest
|
||||
.spyOn(console, "error")
|
||||
.mockImplementation(() => {});
|
||||
mockContentsAPI.deleteContent.mockRejectedValue(
|
||||
new Error("Delete failed")
|
||||
);
|
||||
|
||||
render(<ContentsListPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Content Chunks (3)")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const deleteButtons = screen.getAllByTitle("Delete content");
|
||||
fireEvent.click(deleteButtons[0]);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(consoleError).toHaveBeenCalledWith(
|
||||
"Failed to delete content:",
|
||||
expect.any(Error)
|
||||
);
|
||||
});
|
||||
|
||||
expect(screen.getByText("Content Chunks (3)")).toBeInTheDocument();
|
||||
expect(screen.getByText("content_1...")).toBeInTheDocument();
|
||||
|
||||
consoleError.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Breadcrumb Navigation", () => {
|
||||
test("renders correct breadcrumb structure", async () => {
|
||||
render(<ContentsListPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
const vectorStoreTexts = screen.getAllByText("Vector Stores");
|
||||
expect(vectorStoreTexts.length).toBeGreaterThan(0);
|
||||
const storeNameTexts = screen.getAllByText("Test Vector Store");
|
||||
expect(storeNameTexts.length).toBeGreaterThan(0);
|
||||
const filesTexts = screen.getAllByText("Files");
|
||||
expect(filesTexts.length).toBeGreaterThan(0);
|
||||
const fileIdTexts = screen.getAllByText("file_456");
|
||||
expect(fileIdTexts.length).toBeGreaterThan(0);
|
||||
const contentsTexts = screen.getAllByText("Contents");
|
||||
expect(contentsTexts.length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Sidebar Properties", () => {
|
||||
test("renders file and store properties", async () => {
|
||||
render(<ContentsListPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
const fileIdTexts = screen.getAllByText("file_456");
|
||||
expect(fileIdTexts.length).toBeGreaterThan(0);
|
||||
const storeIdTexts = screen.getAllByText("vs_123");
|
||||
expect(storeIdTexts.length).toBeGreaterThan(0);
|
||||
const storeNameTexts = screen.getAllByText("Test Vector Store");
|
||||
expect(storeNameTexts.length).toBeGreaterThan(0);
|
||||
|
||||
expect(screen.getByText("completed")).toBeInTheDocument();
|
||||
expect(screen.getByText("512")).toBeInTheDocument();
|
||||
expect(screen.getByText("fixed_size")).toBeInTheDocument();
|
||||
expect(screen.getByText("test_provider")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Content Text Utilities", () => {
|
||||
test("handles different content formats correctly", async () => {
|
||||
const contentWithObject = {
|
||||
...mockContents[0],
|
||||
content: { type: "text", text: "Object format content" },
|
||||
};
|
||||
|
||||
mockContentsAPI.listContents.mockResolvedValue({
|
||||
data: [contentWithObject],
|
||||
});
|
||||
|
||||
render(<ContentsListPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Object format content")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("handles string content format", async () => {
|
||||
const contentWithString = {
|
||||
...mockContents[0],
|
||||
content: "String format content",
|
||||
};
|
||||
|
||||
mockContentsAPI.listContents.mockResolvedValue({
|
||||
data: [contentWithString],
|
||||
});
|
||||
|
||||
render(<ContentsListPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("String format content")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("handles unknown content format", async () => {
|
||||
const contentWithUnknown = {
|
||||
...mockContents[0],
|
||||
content: { unknown: "format" },
|
||||
};
|
||||
|
||||
mockContentsAPI.listContents.mockResolvedValue({
|
||||
data: [contentWithUnknown],
|
||||
});
|
||||
|
||||
render(<ContentsListPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Content Chunks (1)")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const contentCells = screen.getAllByRole("cell");
|
||||
const contentPreviewCell = contentCells.find(cell =>
|
||||
cell.querySelector("p[title]")
|
||||
);
|
||||
expect(contentPreviewCell?.querySelector("p")?.textContent).toBe("");
|
||||
});
|
||||
});
|
||||
});
|
|
@ -18,7 +18,10 @@ import {
|
|||
PropertiesCard,
|
||||
PropertyItem,
|
||||
} from "@/components/layout/detail-layout";
|
||||
import { PageBreadcrumb, BreadcrumbSegment } from "@/components/layout/page-breadcrumb";
|
||||
import {
|
||||
PageBreadcrumb,
|
||||
BreadcrumbSegment,
|
||||
} from "@/components/layout/page-breadcrumb";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
|
@ -36,13 +39,13 @@ export default function ContentsListPage() {
|
|||
const fileId = params.fileId as string;
|
||||
const client = useAuthClient();
|
||||
|
||||
const getTextFromContent = (content: any): string => {
|
||||
if (typeof content === 'string') {
|
||||
const getTextFromContent = (content: unknown): string => {
|
||||
if (typeof content === "string") {
|
||||
return content;
|
||||
} else if (content && content.type === 'text') {
|
||||
} else if (content && content.type === "text") {
|
||||
return content.text;
|
||||
}
|
||||
return '';
|
||||
return "";
|
||||
};
|
||||
|
||||
const [store, setStore] = useState<VectorStore | null>(null);
|
||||
|
@ -65,7 +68,9 @@ export default function ContentsListPage() {
|
|||
const response = await client.vectorStores.retrieve(vectorStoreId);
|
||||
setStore(response as VectorStore);
|
||||
} catch (err) {
|
||||
setErrorStore(err instanceof Error ? err : new Error("Failed to load vector store."));
|
||||
setErrorStore(
|
||||
err instanceof Error ? err : new Error("Failed to load vector store.")
|
||||
);
|
||||
} finally {
|
||||
setIsLoadingStore(false);
|
||||
}
|
||||
|
@ -80,10 +85,15 @@ export default function ContentsListPage() {
|
|||
setIsLoadingFile(true);
|
||||
setErrorFile(null);
|
||||
try {
|
||||
const response = await client.vectorStores.files.retrieve(vectorStoreId, fileId);
|
||||
const response = await client.vectorStores.files.retrieve(
|
||||
vectorStoreId,
|
||||
fileId
|
||||
);
|
||||
setFile(response as VectorStoreFile);
|
||||
} catch (err) {
|
||||
setErrorFile(err instanceof Error ? err : new Error("Failed to load file."));
|
||||
setErrorFile(
|
||||
err instanceof Error ? err : new Error("Failed to load file.")
|
||||
);
|
||||
} finally {
|
||||
setIsLoadingFile(false);
|
||||
}
|
||||
|
@ -99,10 +109,16 @@ export default function ContentsListPage() {
|
|||
setErrorContents(null);
|
||||
try {
|
||||
const contentsAPI = new ContentsAPI(client);
|
||||
const contentsResponse = await contentsAPI.listContents(vectorStoreId, fileId, { limit: 100 });
|
||||
const contentsResponse = await contentsAPI.listContents(
|
||||
vectorStoreId,
|
||||
fileId,
|
||||
{ limit: 100 }
|
||||
);
|
||||
setContents(contentsResponse.data);
|
||||
} catch (err) {
|
||||
setErrorContents(err instanceof Error ? err : new Error("Failed to load contents."));
|
||||
setErrorContents(
|
||||
err instanceof Error ? err : new Error("Failed to load contents.")
|
||||
);
|
||||
} finally {
|
||||
setIsLoadingContents(false);
|
||||
}
|
||||
|
@ -116,26 +132,36 @@ export default function ContentsListPage() {
|
|||
await contentsAPI.deleteContent(vectorStoreId, fileId, contentId);
|
||||
setContents(contents.filter(content => content.id !== contentId));
|
||||
} catch (err) {
|
||||
console.error('Failed to delete content:', err);
|
||||
console.error("Failed to delete content:", err);
|
||||
}
|
||||
};
|
||||
|
||||
const handleViewContent = (contentId: string) => {
|
||||
router.push(`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents/${contentId}`);
|
||||
router.push(
|
||||
`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents/${contentId}`
|
||||
);
|
||||
};
|
||||
|
||||
const title = `Contents in File: ${fileId}`;
|
||||
|
||||
const breadcrumbSegments: BreadcrumbSegment[] = [
|
||||
{ label: "Vector Stores", href: "/logs/vector-stores" },
|
||||
{ label: store?.name || vectorStoreId, href: `/logs/vector-stores/${vectorStoreId}` },
|
||||
{
|
||||
label: store?.name || vectorStoreId,
|
||||
href: `/logs/vector-stores/${vectorStoreId}`,
|
||||
},
|
||||
{ label: "Files", href: `/logs/vector-stores/${vectorStoreId}` },
|
||||
{ label: fileId, href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}` },
|
||||
{
|
||||
label: fileId,
|
||||
href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}`,
|
||||
},
|
||||
{ label: "Contents" },
|
||||
];
|
||||
|
||||
if (errorStore) {
|
||||
return <DetailErrorView title={title} id={vectorStoreId} error={errorStore} />;
|
||||
return (
|
||||
<DetailErrorView title={title} id={vectorStoreId} error={errorStore} />
|
||||
);
|
||||
}
|
||||
if (isLoadingStore) {
|
||||
return <DetailLoadingView title={title} />;
|
||||
|
@ -151,7 +177,13 @@ export default function ContentsListPage() {
|
|||
<CardTitle>Content Chunks ({contents.length})</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
{isLoadingContents ? (
|
||||
{isLoadingFile ? (
|
||||
<Skeleton className="h-4 w-full" />
|
||||
) : errorFile ? (
|
||||
<div className="text-destructive text-sm">
|
||||
Error loading file: {errorFile.message}
|
||||
</div>
|
||||
) : isLoadingContents ? (
|
||||
<div className="space-y-2">
|
||||
<Skeleton className="h-4 w-full" />
|
||||
<Skeleton className="h-4 w-3/4" />
|
||||
|
@ -175,7 +207,7 @@ export default function ContentsListPage() {
|
|||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{contents.map((content) => (
|
||||
{contents.map(content => (
|
||||
<TableRow key={content.id}>
|
||||
<TableCell className="font-mono text-xs">
|
||||
<Button
|
||||
|
@ -189,7 +221,10 @@ export default function ContentsListPage() {
|
|||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="max-w-md">
|
||||
<p className="text-sm truncate" title={getTextFromContent(content.content)}>
|
||||
<p
|
||||
className="text-sm truncate"
|
||||
title={getTextFromContent(content.content)}
|
||||
>
|
||||
{getTextFromContent(content.content)}
|
||||
</p>
|
||||
</div>
|
||||
|
@ -197,12 +232,25 @@ export default function ContentsListPage() {
|
|||
<TableCell className="text-xs text-gray-500">
|
||||
{content.embedding && content.embedding.length > 0 ? (
|
||||
<div className="max-w-xs">
|
||||
<span className="font-mono text-xs bg-gray-100 dark:bg-gray-800 rounded px-1 py-0.5" title={`${content.embedding.length}D vector: [${content.embedding.slice(0, 3).map(v => v.toFixed(3)).join(', ')}...]`}>
|
||||
[{content.embedding.slice(0, 3).map(v => v.toFixed(3)).join(', ')}...] ({content.embedding.length}D)
|
||||
<span
|
||||
className="font-mono text-xs bg-gray-100 dark:bg-gray-800 rounded px-1 py-0.5"
|
||||
title={`${content.embedding.length}D vector: [${content.embedding
|
||||
.slice(0, 3)
|
||||
.map(v => v.toFixed(3))
|
||||
.join(", ")}...]`}
|
||||
>
|
||||
[
|
||||
{content.embedding
|
||||
.slice(0, 3)
|
||||
.map(v => v.toFixed(3))
|
||||
.join(", ")}
|
||||
...] ({content.embedding.length}D)
|
||||
</span>
|
||||
</div>
|
||||
) : (
|
||||
<span className="text-gray-400 dark:text-gray-500 italic">No embedding</span>
|
||||
<span className="text-gray-400 dark:text-gray-500 italic">
|
||||
No embedding
|
||||
</span>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell className="text-xs text-gray-500">
|
||||
|
@ -211,7 +259,9 @@ export default function ContentsListPage() {
|
|||
: `${content.metadata.content_length || 0} chars`}
|
||||
</TableCell>
|
||||
<TableCell className="text-xs">
|
||||
{new Date(content.created_timestamp * 1000).toLocaleString()}
|
||||
{new Date(
|
||||
content.created_timestamp * 1000
|
||||
).toLocaleString()}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex gap-1">
|
||||
|
|
|
@ -0,0 +1,458 @@
|
|||
import React from "react";
|
||||
import {
|
||||
render,
|
||||
screen,
|
||||
fireEvent,
|
||||
waitFor,
|
||||
act,
|
||||
} from "@testing-library/react";
|
||||
import "@testing-library/jest-dom";
|
||||
import FileDetailPage from "./page";
|
||||
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
|
||||
import type {
|
||||
VectorStoreFile,
|
||||
FileContentResponse,
|
||||
} from "llama-stack-client/resources/vector-stores/files";
|
||||
|
||||
const mockPush = jest.fn();
|
||||
const mockParams = {
|
||||
id: "vs_123",
|
||||
fileId: "file_456",
|
||||
};
|
||||
|
||||
jest.mock("next/navigation", () => ({
|
||||
useParams: () => mockParams,
|
||||
useRouter: () => ({
|
||||
push: mockPush,
|
||||
}),
|
||||
}));
|
||||
|
||||
const mockClient = {
|
||||
vectorStores: {
|
||||
retrieve: jest.fn(),
|
||||
files: {
|
||||
retrieve: jest.fn(),
|
||||
content: jest.fn(),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
jest.mock("@/hooks/use-auth-client", () => ({
|
||||
useAuthClient: () => mockClient,
|
||||
}));
|
||||
|
||||
describe("FileDetailPage", () => {
|
||||
const mockStore: VectorStore = {
|
||||
id: "vs_123",
|
||||
name: "Test Vector Store",
|
||||
created_at: 1710000000,
|
||||
status: "ready",
|
||||
file_counts: { total: 5 },
|
||||
usage_bytes: 1024,
|
||||
metadata: {
|
||||
provider_id: "test_provider",
|
||||
},
|
||||
};
|
||||
|
||||
const mockFile: VectorStoreFile = {
|
||||
id: "file_456",
|
||||
status: "completed",
|
||||
created_at: 1710001000,
|
||||
usage_bytes: 2048,
|
||||
chunking_strategy: { type: "fixed_size" },
|
||||
};
|
||||
|
||||
const mockFileContent: FileContentResponse = {
|
||||
content: [
|
||||
{ text: "First chunk of file content." },
|
||||
{
|
||||
text: "Second chunk with more detailed information about the content.",
|
||||
},
|
||||
{ text: "Third and final chunk of the file." },
|
||||
],
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
mockClient.vectorStores.retrieve.mockResolvedValue(mockStore);
|
||||
mockClient.vectorStores.files.retrieve.mockResolvedValue(mockFile);
|
||||
mockClient.vectorStores.files.content.mockResolvedValue(mockFileContent);
|
||||
});
|
||||
|
||||
describe("Loading and Error States", () => {
|
||||
test("renders loading skeleton while fetching store data", async () => {
|
||||
mockClient.vectorStores.retrieve.mockImplementation(
|
||||
() => new Promise(() => {})
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await act(async () => {
|
||||
render(<FileDetailPage />);
|
||||
});
|
||||
});
|
||||
|
||||
const skeletons = document.querySelectorAll('[data-slot="skeleton"]');
|
||||
expect(skeletons.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
test("renders error message when store API call fails", async () => {
|
||||
const error = new Error("Failed to load store");
|
||||
mockClient.vectorStores.retrieve.mockRejectedValue(error);
|
||||
|
||||
await act(async () => {
|
||||
await act(async () => {
|
||||
render(<FileDetailPage />);
|
||||
});
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/Error loading details for ID vs_123/)
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText(/Failed to load store/)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders not found when store doesn't exist", async () => {
|
||||
mockClient.vectorStores.retrieve.mockResolvedValue(null);
|
||||
|
||||
await act(async () => {
|
||||
render(<FileDetailPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/No details found for ID: vs_123/)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders file loading skeleton", async () => {
|
||||
mockClient.vectorStores.files.retrieve.mockImplementation(
|
||||
() => new Promise(() => {})
|
||||
);
|
||||
|
||||
const { container } = render(<FileDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("File: file_456")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
|
||||
expect(skeletons.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
test("renders file error message", async () => {
|
||||
const error = new Error("Failed to load file");
|
||||
mockClient.vectorStores.files.retrieve.mockRejectedValue(error);
|
||||
|
||||
await act(async () => {
|
||||
render(<FileDetailPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Error loading file: Failed to load file")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders content error message", async () => {
|
||||
const error = new Error("Failed to load contents");
|
||||
mockClient.vectorStores.files.content.mockRejectedValue(error);
|
||||
|
||||
await act(async () => {
|
||||
render(<FileDetailPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(
|
||||
"Error loading content summary: Failed to load contents"
|
||||
)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("File Information Display", () => {
|
||||
test("renders file details correctly", async () => {
|
||||
await act(async () => {
|
||||
await act(async () => {
|
||||
render(<FileDetailPage />);
|
||||
});
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("File: file_456")).toBeInTheDocument();
|
||||
expect(screen.getByText("File Information")).toBeInTheDocument();
|
||||
expect(screen.getByText("File Details")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const statusTexts = screen.getAllByText("Status:");
|
||||
expect(statusTexts.length).toBeGreaterThan(0);
|
||||
const completedTexts = screen.getAllByText("completed");
|
||||
expect(completedTexts.length).toBeGreaterThan(0);
|
||||
expect(screen.getByText("Size:")).toBeInTheDocument();
|
||||
expect(screen.getByText("2048 bytes")).toBeInTheDocument();
|
||||
const createdTexts = screen.getAllByText("Created:");
|
||||
expect(createdTexts.length).toBeGreaterThan(0);
|
||||
const dateTexts = screen.getAllByText(
|
||||
new Date(1710001000 * 1000).toLocaleString()
|
||||
);
|
||||
expect(dateTexts.length).toBeGreaterThan(0);
|
||||
const strategyTexts = screen.getAllByText("Content Strategy:");
|
||||
expect(strategyTexts.length).toBeGreaterThan(0);
|
||||
const fixedSizeTexts = screen.getAllByText("fixed_size");
|
||||
expect(fixedSizeTexts.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
test("handles missing file data", async () => {
|
||||
mockClient.vectorStores.files.retrieve.mockResolvedValue(null);
|
||||
|
||||
await act(async () => {
|
||||
render(<FileDetailPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("File not found.")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Content Summary Display", () => {
|
||||
test("renders content summary correctly", async () => {
|
||||
await act(async () => {
|
||||
render(<FileDetailPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Content Summary")).toBeInTheDocument();
|
||||
expect(screen.getByText("Content Items:")).toBeInTheDocument();
|
||||
expect(screen.getByText("3")).toBeInTheDocument();
|
||||
expect(screen.getByText("Total Characters:")).toBeInTheDocument();
|
||||
|
||||
const totalChars = mockFileContent.content.reduce(
|
||||
(total, item) => total + item.text.length,
|
||||
0
|
||||
);
|
||||
expect(screen.getByText(totalChars.toString())).toBeInTheDocument();
|
||||
|
||||
expect(screen.getByText("Preview:")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(/First chunk of file content\./)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("handles empty content", async () => {
|
||||
mockClient.vectorStores.files.content.mockResolvedValue({
|
||||
content: [],
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
render(<FileDetailPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("No contents found for this file.")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("truncates long content preview", async () => {
|
||||
const longContent = {
|
||||
content: [
|
||||
{
|
||||
text: "This is a very long piece of content that should be truncated after 200 characters to ensure the preview doesn't take up too much space in the UI and remains readable and manageable for users viewing the file details page.",
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
mockClient.vectorStores.files.content.mockResolvedValue(longContent);
|
||||
|
||||
await act(async () => {
|
||||
render(<FileDetailPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/This is a very long piece of content/)
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText(/\.\.\.$/)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Navigation and Actions", () => {
|
||||
test("navigates to contents list when View Contents button is clicked", async () => {
|
||||
await act(async () => {
|
||||
render(<FileDetailPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Actions")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const viewContentsButton = screen.getByRole("button", {
|
||||
name: /View Contents/,
|
||||
});
|
||||
fireEvent.click(viewContentsButton);
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith(
|
||||
"/logs/vector-stores/vs_123/files/file_456/contents"
|
||||
);
|
||||
});
|
||||
|
||||
test("View Contents button is styled correctly", async () => {
|
||||
await act(async () => {
|
||||
render(<FileDetailPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
const button = screen.getByRole("button", { name: /View Contents/ });
|
||||
expect(button).toHaveClass("flex", "items-center", "gap-2");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Breadcrumb Navigation", () => {
|
||||
test("renders correct breadcrumb structure", async () => {
|
||||
await act(async () => {
|
||||
render(<FileDetailPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
const vectorStoresTexts = screen.getAllByText("Vector Stores");
|
||||
expect(vectorStoresTexts.length).toBeGreaterThan(0);
|
||||
const storeNameTexts = screen.getAllByText("Test Vector Store");
|
||||
expect(storeNameTexts.length).toBeGreaterThan(0);
|
||||
const filesTexts = screen.getAllByText("Files");
|
||||
expect(filesTexts.length).toBeGreaterThan(0);
|
||||
const fileIdTexts = screen.getAllByText("file_456");
|
||||
expect(fileIdTexts.length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
test("uses store ID when store name is not available", async () => {
|
||||
const storeWithoutName = { ...mockStore, name: "" };
|
||||
mockClient.vectorStores.retrieve.mockResolvedValue(storeWithoutName);
|
||||
|
||||
await act(async () => {
|
||||
render(<FileDetailPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
const storeIdTexts = screen.getAllByText("vs_123");
|
||||
expect(storeIdTexts.length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Sidebar Properties", () => {
|
||||
test.skip("renders file and store properties correctly", async () => {
|
||||
await act(async () => {
|
||||
render(<FileDetailPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("File ID")).toBeInTheDocument();
|
||||
const fileIdTexts = screen.getAllByText("file_456");
|
||||
expect(fileIdTexts.length).toBeGreaterThan(0);
|
||||
expect(screen.getByText("Vector Store ID")).toBeInTheDocument();
|
||||
const storeIdTexts = screen.getAllByText("vs_123");
|
||||
expect(storeIdTexts.length).toBeGreaterThan(0);
|
||||
expect(screen.getByText("Status")).toBeInTheDocument();
|
||||
const completedTexts = screen.getAllByText("completed");
|
||||
expect(completedTexts.length).toBeGreaterThan(0);
|
||||
expect(screen.getByText("Usage Bytes")).toBeInTheDocument();
|
||||
const usageTexts = screen.getAllByText("2048");
|
||||
expect(usageTexts.length).toBeGreaterThan(0);
|
||||
expect(screen.getByText("Content Strategy")).toBeInTheDocument();
|
||||
const fixedSizeTexts = screen.getAllByText("fixed_size");
|
||||
expect(fixedSizeTexts.length).toBeGreaterThan(0);
|
||||
|
||||
expect(screen.getByText("Store Name")).toBeInTheDocument();
|
||||
const storeNameTexts = screen.getAllByText("Test Vector Store");
|
||||
expect(storeNameTexts.length).toBeGreaterThan(0);
|
||||
expect(screen.getByText("Provider ID")).toBeInTheDocument();
|
||||
expect(screen.getByText("test_provider")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("handles missing optional properties", async () => {
|
||||
const minimalFile = {
|
||||
id: "file_456",
|
||||
status: "completed",
|
||||
created_at: 1710001000,
|
||||
usage_bytes: 2048,
|
||||
chunking_strategy: { type: "fixed_size" },
|
||||
};
|
||||
|
||||
const minimalStore = {
|
||||
...mockStore,
|
||||
name: "",
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
mockClient.vectorStores.files.retrieve.mockResolvedValue(minimalFile);
|
||||
mockClient.vectorStores.retrieve.mockResolvedValue(minimalStore);
|
||||
|
||||
await act(async () => {
|
||||
render(<FileDetailPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
const fileIdTexts = screen.getAllByText("file_456");
|
||||
expect(fileIdTexts.length).toBeGreaterThan(0);
|
||||
const storeIdTexts = screen.getAllByText("vs_123");
|
||||
expect(storeIdTexts.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
expect(screen.getByText("File: file_456")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Loading States for Individual Sections", () => {
|
||||
test("shows loading skeleton for content while file loads", async () => {
|
||||
mockClient.vectorStores.files.content.mockImplementation(
|
||||
() => new Promise(() => {})
|
||||
);
|
||||
|
||||
const { container } = render(<FileDetailPage />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Content Summary")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
|
||||
expect(skeletons.length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error Handling", () => {
|
||||
test("handles multiple simultaneous errors gracefully", async () => {
|
||||
mockClient.vectorStores.files.retrieve.mockRejectedValue(
|
||||
new Error("File error")
|
||||
);
|
||||
mockClient.vectorStores.files.content.mockRejectedValue(
|
||||
new Error("Content error")
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
render(<FileDetailPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Error loading file: File error")
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("Error loading content summary: Content error")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
|
@ -4,9 +4,12 @@ import { useEffect, useState } from "react";
|
|||
import { useParams, useRouter } from "next/navigation";
|
||||
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
|
||||
import type { VectorStoreFile, FileContentResponse } from "llama-stack-client/resources/vector-stores/files";
|
||||
import type {
|
||||
VectorStoreFile,
|
||||
FileContentResponse,
|
||||
} from "llama-stack-client/resources/vector-stores/files";
|
||||
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { Skeleton } from '@/components/ui/skeleton';
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { List } from "lucide-react";
|
||||
import {
|
||||
|
@ -17,7 +20,10 @@ import {
|
|||
PropertiesCard,
|
||||
PropertyItem,
|
||||
} from "@/components/layout/detail-layout";
|
||||
import { PageBreadcrumb, BreadcrumbSegment } from "@/components/layout/page-breadcrumb";
|
||||
import {
|
||||
PageBreadcrumb,
|
||||
BreadcrumbSegment,
|
||||
} from "@/components/layout/page-breadcrumb";
|
||||
|
||||
export default function FileDetailPage() {
|
||||
const params = useParams();
|
||||
|
@ -46,7 +52,9 @@ export default function FileDetailPage() {
|
|||
const response = await client.vectorStores.retrieve(vectorStoreId);
|
||||
setStore(response as VectorStore);
|
||||
} catch (err) {
|
||||
setErrorStore(err instanceof Error ? err : new Error("Failed to load vector store."));
|
||||
setErrorStore(
|
||||
err instanceof Error ? err : new Error("Failed to load vector store.")
|
||||
);
|
||||
} finally {
|
||||
setIsLoadingStore(false);
|
||||
}
|
||||
|
@ -61,10 +69,15 @@ export default function FileDetailPage() {
|
|||
setIsLoadingFile(true);
|
||||
setErrorFile(null);
|
||||
try {
|
||||
const response = await client.vectorStores.files.retrieve(vectorStoreId, fileId);
|
||||
const response = await client.vectorStores.files.retrieve(
|
||||
vectorStoreId,
|
||||
fileId
|
||||
);
|
||||
setFile(response as VectorStoreFile);
|
||||
} catch (err) {
|
||||
setErrorFile(err instanceof Error ? err : new Error("Failed to load file."));
|
||||
setErrorFile(
|
||||
err instanceof Error ? err : new Error("Failed to load file.")
|
||||
);
|
||||
} finally {
|
||||
setIsLoadingFile(false);
|
||||
}
|
||||
|
@ -79,10 +92,15 @@ export default function FileDetailPage() {
|
|||
setIsLoadingContents(true);
|
||||
setErrorContents(null);
|
||||
try {
|
||||
const response = await client.vectorStores.files.content(vectorStoreId, fileId);
|
||||
const response = await client.vectorStores.files.content(
|
||||
vectorStoreId,
|
||||
fileId
|
||||
);
|
||||
setContents(response);
|
||||
} catch (err) {
|
||||
setErrorContents(err instanceof Error ? err : new Error("Failed to load contents."));
|
||||
setErrorContents(
|
||||
err instanceof Error ? err : new Error("Failed to load contents.")
|
||||
);
|
||||
} finally {
|
||||
setIsLoadingContents(false);
|
||||
}
|
||||
|
@ -91,20 +109,27 @@ export default function FileDetailPage() {
|
|||
}, [vectorStoreId, fileId, client]);
|
||||
|
||||
const handleViewContents = () => {
|
||||
router.push(`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`);
|
||||
router.push(
|
||||
`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`
|
||||
);
|
||||
};
|
||||
|
||||
const title = `File: ${fileId}`;
|
||||
|
||||
const breadcrumbSegments: BreadcrumbSegment[] = [
|
||||
{ label: "Vector Stores", href: "/logs/vector-stores" },
|
||||
{ label: store?.name || vectorStoreId, href: `/logs/vector-stores/${vectorStoreId}` },
|
||||
{
|
||||
label: store?.name || vectorStoreId,
|
||||
href: `/logs/vector-stores/${vectorStoreId}`,
|
||||
},
|
||||
{ label: "Files", href: `/logs/vector-stores/${vectorStoreId}` },
|
||||
{ label: fileId },
|
||||
];
|
||||
|
||||
if (errorStore) {
|
||||
return <DetailErrorView title={title} id={vectorStoreId} error={errorStore} />;
|
||||
return (
|
||||
<DetailErrorView title={title} id={vectorStoreId} error={errorStore} />
|
||||
);
|
||||
}
|
||||
if (isLoadingStore) {
|
||||
return <DetailLoadingView title={title} />;
|
||||
|
@ -136,19 +161,29 @@ export default function FileDetailPage() {
|
|||
<h3 className="text-lg font-medium mb-2">File Details</h3>
|
||||
<div className="grid grid-cols-2 gap-4 text-sm">
|
||||
<div>
|
||||
<span className="font-medium text-gray-600 dark:text-gray-400">Status:</span>
|
||||
<span className="font-medium text-gray-600 dark:text-gray-400">
|
||||
Status:
|
||||
</span>
|
||||
<span className="ml-2">{file.status}</span>
|
||||
</div>
|
||||
<div>
|
||||
<span className="font-medium text-gray-600 dark:text-gray-400">Size:</span>
|
||||
<span className="font-medium text-gray-600 dark:text-gray-400">
|
||||
Size:
|
||||
</span>
|
||||
<span className="ml-2">{file.usage_bytes} bytes</span>
|
||||
</div>
|
||||
<div>
|
||||
<span className="font-medium text-gray-600 dark:text-gray-400">Created:</span>
|
||||
<span className="ml-2">{new Date(file.created_at * 1000).toLocaleString()}</span>
|
||||
<span className="font-medium text-gray-600 dark:text-gray-400">
|
||||
Created:
|
||||
</span>
|
||||
<span className="ml-2">
|
||||
{new Date(file.created_at * 1000).toLocaleString()}
|
||||
</span>
|
||||
</div>
|
||||
<div>
|
||||
<span className="font-medium text-gray-600 dark:text-gray-400">Content Strategy:</span>
|
||||
<span className="font-medium text-gray-600 dark:text-gray-400">
|
||||
Content Strategy:
|
||||
</span>
|
||||
<span className="ml-2">{file.chunking_strategy.type}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
@ -166,9 +201,7 @@ export default function FileDetailPage() {
|
|||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<p className="text-gray-500 italic text-sm">
|
||||
File not found.
|
||||
</p>
|
||||
<p className="text-gray-500 italic text-sm">File not found.</p>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
@ -192,16 +225,27 @@ export default function FileDetailPage() {
|
|||
<div className="space-y-3">
|
||||
<div className="grid grid-cols-2 gap-4 text-sm">
|
||||
<div>
|
||||
<span className="font-medium text-gray-600 dark:text-gray-400">Content Items:</span>
|
||||
<span className="font-medium text-gray-600 dark:text-gray-400">
|
||||
Content Items:
|
||||
</span>
|
||||
<span className="ml-2">{contents.content.length}</span>
|
||||
</div>
|
||||
<div>
|
||||
<span className="font-medium text-gray-600 dark:text-gray-400">Total Characters:</span>
|
||||
<span className="ml-2">{contents.content.reduce((total, item) => total + item.text.length, 0)}</span>
|
||||
<span className="font-medium text-gray-600 dark:text-gray-400">
|
||||
Total Characters:
|
||||
</span>
|
||||
<span className="ml-2">
|
||||
{contents.content.reduce(
|
||||
(total, item) => total + item.text.length,
|
||||
0
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
<div className="pt-2">
|
||||
<span className="text-sm font-medium text-gray-600 dark:text-gray-400">Preview:</span>
|
||||
<span className="text-sm font-medium text-gray-600 dark:text-gray-400">
|
||||
Preview:
|
||||
</span>
|
||||
<div className="mt-1 bg-gray-50 dark:bg-gray-800 rounded-md p-3">
|
||||
<p className="text-sm text-gray-900 dark:text-gray-100 line-clamp-3">
|
||||
{contents.content[0]?.text.substring(0, 200)}...
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"use client";
|
||||
|
||||
import { useEffect, useState } from "react";
|
||||
import { useParams, useRouter } from "next/navigation";
|
||||
import { useParams } from "next/navigation";
|
||||
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
|
||||
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
|
||||
|
@ -11,7 +11,6 @@ export default function VectorStoreDetailPage() {
|
|||
const params = useParams();
|
||||
const id = params.id as string;
|
||||
const client = useAuthClient();
|
||||
const router = useRouter();
|
||||
|
||||
const [store, setStore] = useState<VectorStore | null>(null);
|
||||
const [files, setFiles] = useState<VectorStoreFile[]>([]);
|
||||
|
@ -34,9 +33,7 @@ export default function VectorStoreDetailPage() {
|
|||
setStore(response as VectorStore);
|
||||
} catch (err) {
|
||||
setErrorStore(
|
||||
err instanceof Error
|
||||
? err
|
||||
: new Error("Failed to load vector store."),
|
||||
err instanceof Error ? err : new Error("Failed to load vector store.")
|
||||
);
|
||||
} finally {
|
||||
setIsLoadingStore(false);
|
||||
|
@ -55,18 +52,18 @@ export default function VectorStoreDetailPage() {
|
|||
setIsLoadingFiles(true);
|
||||
setErrorFiles(null);
|
||||
try {
|
||||
const result = await client.vectorStores.files.list(id as any);
|
||||
setFiles((result as any).data);
|
||||
const result = await client.vectorStores.files.list(id);
|
||||
setFiles((result as { data: VectorStoreFile[] }).data);
|
||||
} catch (err) {
|
||||
setErrorFiles(
|
||||
err instanceof Error ? err : new Error("Failed to load files."),
|
||||
err instanceof Error ? err : new Error("Failed to load files.")
|
||||
);
|
||||
} finally {
|
||||
setIsLoadingFiles(false);
|
||||
}
|
||||
};
|
||||
fetchFiles();
|
||||
}, [id]);
|
||||
}, [id, client.vectorStores.files]);
|
||||
|
||||
return (
|
||||
<VectorStoreDetailView
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||
import type {
|
||||
ListVectorStoresResponse,
|
||||
VectorStore,
|
||||
|
@ -12,7 +11,6 @@ import { Button } from "@/components/ui/button";
|
|||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCaption,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
|
@ -21,7 +19,6 @@ import {
|
|||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
|
||||
export default function VectorStoresPage() {
|
||||
const client = useAuthClient();
|
||||
const router = useRouter();
|
||||
const {
|
||||
data: stores,
|
||||
|
@ -37,7 +34,7 @@ export default function VectorStoresPage() {
|
|||
after: params.after,
|
||||
limit: params.limit,
|
||||
order: params.order,
|
||||
} as any);
|
||||
} as Parameters<typeof client.vectorStores.list>[0]);
|
||||
return response as ListVectorStoresResponse;
|
||||
},
|
||||
errorMessagePrefix: "vector stores",
|
||||
|
@ -53,11 +50,11 @@ export default function VectorStoresPage() {
|
|||
const renderContent = () => {
|
||||
if (status === "loading") {
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
<Skeleton className="h-8 w-full"/>
|
||||
<Skeleton className="h-4 w-full"/>
|
||||
<Skeleton className="h-4 w-full"/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<Skeleton className="h-8 w-full" />
|
||||
<Skeleton className="h-4 w-full" />
|
||||
<Skeleton className="h-4 w-full" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -70,72 +67,72 @@ export default function VectorStoresPage() {
|
|||
}
|
||||
|
||||
return (
|
||||
<div className="overflow-auto flex-1 min-h-0">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>ID</TableHead>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>Created</TableHead>
|
||||
<TableHead>Completed</TableHead>
|
||||
<TableHead>Cancelled</TableHead>
|
||||
<TableHead>Failed</TableHead>
|
||||
<TableHead>In Progress</TableHead>
|
||||
<TableHead>Total</TableHead>
|
||||
<TableHead>Usage Bytes</TableHead>
|
||||
<TableHead>Provider ID</TableHead>
|
||||
<TableHead>Provider Vector DB ID</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{stores.map((store) => {
|
||||
const fileCounts = store.file_counts;
|
||||
const metadata = store.metadata || {};
|
||||
const providerId = metadata.provider_id ?? "";
|
||||
const providerDbId = metadata.provider_vector_db_id ?? "";
|
||||
<div className="overflow-auto flex-1 min-h-0">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>ID</TableHead>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>Created</TableHead>
|
||||
<TableHead>Completed</TableHead>
|
||||
<TableHead>Cancelled</TableHead>
|
||||
<TableHead>Failed</TableHead>
|
||||
<TableHead>In Progress</TableHead>
|
||||
<TableHead>Total</TableHead>
|
||||
<TableHead>Usage Bytes</TableHead>
|
||||
<TableHead>Provider ID</TableHead>
|
||||
<TableHead>Provider Vector DB ID</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{stores.map(store => {
|
||||
const fileCounts = store.file_counts;
|
||||
const metadata = store.metadata || {};
|
||||
const providerId = metadata.provider_id ?? "";
|
||||
const providerDbId = metadata.provider_vector_db_id ?? "";
|
||||
|
||||
return (
|
||||
<TableRow
|
||||
key={store.id}
|
||||
onClick={() => router.push(`/logs/vector-stores/${store.id}`)}
|
||||
className="cursor-pointer hover:bg-muted/50"
|
||||
return (
|
||||
<TableRow
|
||||
key={store.id}
|
||||
onClick={() => router.push(`/logs/vector-stores/${store.id}`)}
|
||||
className="cursor-pointer hover:bg-muted/50"
|
||||
>
|
||||
<TableCell>
|
||||
<Button
|
||||
variant="link"
|
||||
className="p-0 h-auto font-mono text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300"
|
||||
onClick={() =>
|
||||
router.push(`/logs/vector-stores/${store.id}`)
|
||||
}
|
||||
>
|
||||
<TableCell>
|
||||
<Button
|
||||
variant="link"
|
||||
className="p-0 h-auto font-mono text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300"
|
||||
onClick={() =>
|
||||
router.push(`/logs/vector-stores/${store.id}`)
|
||||
}
|
||||
>
|
||||
{store.id}
|
||||
</Button>
|
||||
</TableCell>
|
||||
<TableCell>{store.name}</TableCell>
|
||||
<TableCell>
|
||||
{new Date(store.created_at * 1000).toLocaleString()}
|
||||
</TableCell>
|
||||
<TableCell>{fileCounts.completed}</TableCell>
|
||||
<TableCell>{fileCounts.cancelled}</TableCell>
|
||||
<TableCell>{fileCounts.failed}</TableCell>
|
||||
<TableCell>{fileCounts.in_progress}</TableCell>
|
||||
<TableCell>{fileCounts.total}</TableCell>
|
||||
<TableCell>{store.usage_bytes}</TableCell>
|
||||
<TableCell>{providerId}</TableCell>
|
||||
<TableCell>{providerDbId}</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
{store.id}
|
||||
</Button>
|
||||
</TableCell>
|
||||
<TableCell>{store.name}</TableCell>
|
||||
<TableCell>
|
||||
{new Date(store.created_at * 1000).toLocaleString()}
|
||||
</TableCell>
|
||||
<TableCell>{fileCounts.completed}</TableCell>
|
||||
<TableCell>{fileCounts.cancelled}</TableCell>
|
||||
<TableCell>{fileCounts.failed}</TableCell>
|
||||
<TableCell>{fileCounts.in_progress}</TableCell>
|
||||
<TableCell>{fileCounts.total}</TableCell>
|
||||
<TableCell>{store.usage_bytes}</TableCell>
|
||||
<TableCell>{providerId}</TableCell>
|
||||
<TableCell>{providerDbId}</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<h1 className="text-2xl font-semibold">Vector Stores</h1>
|
||||
{renderContent()}
|
||||
</div>
|
||||
<div className="space-y-4">
|
||||
<h1 className="text-2xl font-semibold">Vector Stores</h1>
|
||||
{renderContent()}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ describe("ChatCompletionDetailView", () => {
|
|||
isLoading={true}
|
||||
error={null}
|
||||
id="test-id"
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
// Use the data-slot attribute for Skeletons
|
||||
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
|
||||
|
@ -28,10 +28,10 @@ describe("ChatCompletionDetailView", () => {
|
|||
isLoading={false}
|
||||
error={{ name: "Error", message: "Network Error" }}
|
||||
id="err-id"
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
expect(
|
||||
screen.getByText(/Error loading details for ID err-id: Network Error/),
|
||||
screen.getByText(/Error loading details for ID err-id: Network Error/)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
|
@ -42,11 +42,11 @@ describe("ChatCompletionDetailView", () => {
|
|||
isLoading={false}
|
||||
error={{ name: "Error", message: "" }}
|
||||
id="err-id"
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
// Use regex to match the error message regardless of whitespace
|
||||
expect(
|
||||
screen.getByText(/Error loading details for ID\s*err-id\s*:/),
|
||||
screen.getByText(/Error loading details for ID\s*err-id\s*:/)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
|
@ -57,11 +57,11 @@ describe("ChatCompletionDetailView", () => {
|
|||
isLoading={false}
|
||||
error={{} as Error}
|
||||
id="err-id"
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
// Use regex to match the error message regardless of whitespace
|
||||
expect(
|
||||
screen.getByText(/Error loading details for ID\s*err-id\s*:/),
|
||||
screen.getByText(/Error loading details for ID\s*err-id\s*:/)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
|
@ -72,10 +72,10 @@ describe("ChatCompletionDetailView", () => {
|
|||
isLoading={false}
|
||||
error={null}
|
||||
id="notfound-id"
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
expect(
|
||||
screen.getByText("No details found for ID: notfound-id."),
|
||||
screen.getByText("No details found for ID: notfound-id.")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
|
@ -100,7 +100,7 @@ describe("ChatCompletionDetailView", () => {
|
|||
isLoading={false}
|
||||
error={null}
|
||||
id={mockCompletion.id}
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
// Input
|
||||
expect(screen.getByText("Input")).toBeInTheDocument();
|
||||
|
@ -112,7 +112,7 @@ describe("ChatCompletionDetailView", () => {
|
|||
expect(screen.getByText("Properties")).toBeInTheDocument();
|
||||
expect(screen.getByText("Created:")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(new Date(1710000000 * 1000).toLocaleString()),
|
||||
screen.getByText(new Date(1710000000 * 1000).toLocaleString())
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("ID:")).toBeInTheDocument();
|
||||
expect(screen.getByText("comp_123")).toBeInTheDocument();
|
||||
|
@ -150,7 +150,7 @@ describe("ChatCompletionDetailView", () => {
|
|||
isLoading={false}
|
||||
error={null}
|
||||
id={mockCompletion.id}
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
// Output should include the tool call block (should be present twice: input and output)
|
||||
const toolCallLabels = screen.getAllByText("Tool Call");
|
||||
|
@ -178,13 +178,13 @@ describe("ChatCompletionDetailView", () => {
|
|||
isLoading={false}
|
||||
error={null}
|
||||
id={mockCompletion.id}
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
// Input section should be present but empty
|
||||
expect(screen.getByText("Input")).toBeInTheDocument();
|
||||
// Output section should show fallback message
|
||||
expect(
|
||||
screen.getByText("No message found in assistant's choice."),
|
||||
screen.getByText("No message found in assistant's choice.")
|
||||
).toBeInTheDocument();
|
||||
// Properties should show N/A for finish reason
|
||||
expect(screen.getByText("Finish Reason:")).toBeInTheDocument();
|
||||
|
|
|
@ -53,14 +53,14 @@ export function ChatCompletionDetailView({
|
|||
{completion.choices?.[0]?.message?.tool_calls &&
|
||||
Array.isArray(completion.choices[0].message.tool_calls) &&
|
||||
!completion.input_messages?.some(
|
||||
(im) =>
|
||||
im =>
|
||||
im.role === "assistant" &&
|
||||
im.tool_calls &&
|
||||
Array.isArray(im.tool_calls) &&
|
||||
im.tool_calls.length > 0,
|
||||
im.tool_calls.length > 0
|
||||
)
|
||||
? completion.choices[0].message.tool_calls.map(
|
||||
(toolCall: any, index: number) => {
|
||||
(toolCall: { function?: { name?: string } }, index: number) => {
|
||||
const assistantToolCallMessage: ChatMessage = {
|
||||
role: "assistant",
|
||||
tool_calls: [toolCall],
|
||||
|
@ -72,7 +72,7 @@ export function ChatCompletionDetailView({
|
|||
message={assistantToolCallMessage}
|
||||
/>
|
||||
);
|
||||
},
|
||||
}
|
||||
)
|
||||
: null}
|
||||
</CardContent>
|
||||
|
@ -89,7 +89,7 @@ export function ChatCompletionDetailView({
|
|||
/>
|
||||
) : (
|
||||
<p className="text-gray-500 italic text-sm">
|
||||
No message found in assistant's choice.
|
||||
No message found in assistant's choice.
|
||||
</p>
|
||||
)}
|
||||
</CardContent>
|
||||
|
@ -120,13 +120,18 @@ export function ChatCompletionDetailView({
|
|||
value={
|
||||
<div>
|
||||
<ul className="list-disc list-inside pl-4 mt-1">
|
||||
{toolCalls.map((toolCall: any, index: number) => (
|
||||
<li key={index}>
|
||||
<span className="text-gray-900 font-medium">
|
||||
{toolCall.function?.name || "N/A"}
|
||||
</span>
|
||||
</li>
|
||||
))}
|
||||
{toolCalls.map(
|
||||
(
|
||||
toolCall: { function?: { name?: string } },
|
||||
index: number
|
||||
) => (
|
||||
<li key={index}>
|
||||
<span className="text-gray-900 font-medium">
|
||||
{toolCall.function?.name || "N/A"}
|
||||
</span>
|
||||
</li>
|
||||
)
|
||||
)}
|
||||
</ul>
|
||||
</div>
|
||||
}
|
||||
|
|
|
@ -83,7 +83,7 @@ describe("ChatCompletionsTable", () => {
|
|||
// Default pass-through implementations
|
||||
truncateText.mockImplementation((text: string | undefined) => text);
|
||||
extractTextFromContentPart.mockImplementation((content: unknown) =>
|
||||
typeof content === "string" ? content : "extracted text",
|
||||
typeof content === "string" ? content : "extracted text"
|
||||
);
|
||||
extractDisplayableText.mockImplementation((message: unknown) => {
|
||||
const msg = message as { content?: string };
|
||||
|
@ -138,7 +138,7 @@ describe("ChatCompletionsTable", () => {
|
|||
if (row) {
|
||||
fireEvent.click(row);
|
||||
expect(mockPush).toHaveBeenCalledWith(
|
||||
"/logs/chat-completions/completion_123",
|
||||
"/logs/chat-completions/completion_123"
|
||||
);
|
||||
} else {
|
||||
throw new Error('Row with "Test prompt" not found for router mock test.');
|
||||
|
@ -162,7 +162,7 @@ describe("ChatCompletionsTable", () => {
|
|||
expect(tableCaption).toBeInTheDocument();
|
||||
if (tableCaption) {
|
||||
const captionSkeleton = tableCaption.querySelector(
|
||||
'[data-slot="skeleton"]',
|
||||
'[data-slot="skeleton"]'
|
||||
);
|
||||
expect(captionSkeleton).toBeInTheDocument();
|
||||
}
|
||||
|
@ -172,7 +172,7 @@ describe("ChatCompletionsTable", () => {
|
|||
expect(tableBody).toBeInTheDocument();
|
||||
if (tableBody) {
|
||||
const bodySkeletons = tableBody.querySelectorAll(
|
||||
'[data-slot="skeleton"]',
|
||||
'[data-slot="skeleton"]'
|
||||
);
|
||||
expect(bodySkeletons.length).toBeGreaterThan(0);
|
||||
}
|
||||
|
@ -192,14 +192,14 @@ describe("ChatCompletionsTable", () => {
|
|||
|
||||
render(<ChatCompletionsTable {...defaultProps} />);
|
||||
expect(
|
||||
screen.getByText("Unable to load chat completions"),
|
||||
screen.getByText("Unable to load chat completions")
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText(errorMessage)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test.each([{ name: "Error", message: "" }, {}])(
|
||||
"renders default error message when error has no message",
|
||||
(errorObject) => {
|
||||
errorObject => {
|
||||
mockedUsePagination.mockReturnValue({
|
||||
data: [],
|
||||
status: "error",
|
||||
|
@ -210,14 +210,14 @@ describe("ChatCompletionsTable", () => {
|
|||
|
||||
render(<ChatCompletionsTable {...defaultProps} />);
|
||||
expect(
|
||||
screen.getByText("Unable to load chat completions"),
|
||||
screen.getByText("Unable to load chat completions")
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(
|
||||
"An unexpected error occurred while loading the data.",
|
||||
),
|
||||
"An unexpected error occurred while loading the data."
|
||||
)
|
||||
).toBeInTheDocument();
|
||||
},
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
|
@ -225,7 +225,7 @@ describe("ChatCompletionsTable", () => {
|
|||
test('renders "No chat completions found." and no table when data array is empty', () => {
|
||||
render(<ChatCompletionsTable {...defaultProps} />);
|
||||
expect(
|
||||
screen.getByText("No chat completions found."),
|
||||
screen.getByText("No chat completions found.")
|
||||
).toBeInTheDocument();
|
||||
|
||||
// Ensure that the table structure is NOT rendered in the empty state
|
||||
|
@ -292,7 +292,7 @@ describe("ChatCompletionsTable", () => {
|
|||
|
||||
// Table caption
|
||||
expect(
|
||||
screen.getByText("A list of your recent chat completions."),
|
||||
screen.getByText("A list of your recent chat completions.")
|
||||
).toBeInTheDocument();
|
||||
|
||||
// Table headers
|
||||
|
@ -306,14 +306,14 @@ describe("ChatCompletionsTable", () => {
|
|||
expect(screen.getByText("Test output")).toBeInTheDocument();
|
||||
expect(screen.getByText("llama-test-model")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(new Date(1710000000 * 1000).toLocaleString()),
|
||||
screen.getByText(new Date(1710000000 * 1000).toLocaleString())
|
||||
).toBeInTheDocument();
|
||||
|
||||
expect(screen.getByText("Another input")).toBeInTheDocument();
|
||||
expect(screen.getByText("Another output")).toBeInTheDocument();
|
||||
expect(screen.getByText("llama-another-model")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(new Date(1710001000 * 1000).toLocaleString()),
|
||||
screen.getByText(new Date(1710001000 * 1000).toLocaleString())
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
@ -328,7 +328,7 @@ describe("ChatCompletionsTable", () => {
|
|||
return typeof text === "string" && text.length > effectiveMaxLength
|
||||
? text.slice(0, effectiveMaxLength) + "..."
|
||||
: text;
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
const longInput =
|
||||
|
@ -368,7 +368,7 @@ describe("ChatCompletionsTable", () => {
|
|||
|
||||
// The truncated text should be present for both input and output
|
||||
const truncatedTexts = screen.getAllByText(
|
||||
longInput.slice(0, 10) + "...",
|
||||
longInput.slice(0, 10) + "..."
|
||||
);
|
||||
expect(truncatedTexts.length).toBe(2); // one for input, one for output
|
||||
});
|
||||
|
@ -420,7 +420,7 @@ describe("ChatCompletionsTable", () => {
|
|||
// Verify the extracted text appears in the table
|
||||
expect(screen.getByText("Extracted input")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("Extracted output from assistant"),
|
||||
screen.getByText("Extracted output from assistant")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
|
|
@ -5,6 +5,7 @@ import {
|
|||
UsePaginationOptions,
|
||||
ListChatCompletionsResponse,
|
||||
} from "@/lib/types";
|
||||
import { ListChatCompletionsParams } from "@/lib/llama-stack-client";
|
||||
import { LogsTable, LogTableRow } from "@/components/logs/logs-table";
|
||||
import {
|
||||
extractTextFromContentPart,
|
||||
|
@ -38,14 +39,14 @@ export function ChatCompletionsTable({
|
|||
limit: number;
|
||||
model?: string;
|
||||
order?: string;
|
||||
},
|
||||
}
|
||||
) => {
|
||||
const response = await client.chat.completions.list({
|
||||
after: params.after,
|
||||
limit: params.limit,
|
||||
...(params.model && { model: params.model }),
|
||||
...(params.order && { order: params.order }),
|
||||
} as any);
|
||||
} as ListChatCompletionsParams);
|
||||
|
||||
return response as ListChatCompletionsResponse;
|
||||
};
|
||||
|
|
|
@ -37,21 +37,26 @@ export function ChatMessageItem({ message }: ChatMessageItemProps) {
|
|||
) {
|
||||
return (
|
||||
<>
|
||||
{message.tool_calls.map((toolCall: any, index: number) => {
|
||||
const formattedToolCall = formatToolCallToString(toolCall);
|
||||
const toolCallContent = (
|
||||
<ToolCallBlock>
|
||||
{formattedToolCall || "Error: Could not display tool call"}
|
||||
</ToolCallBlock>
|
||||
);
|
||||
return (
|
||||
<MessageBlock
|
||||
key={index}
|
||||
label="Tool Call"
|
||||
content={toolCallContent}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
{message.tool_calls.map(
|
||||
(
|
||||
toolCall: { function?: { name?: string; arguments?: unknown } },
|
||||
index: number
|
||||
) => {
|
||||
const formattedToolCall = formatToolCallToString(toolCall);
|
||||
const toolCallContent = (
|
||||
<ToolCallBlock>
|
||||
{formattedToolCall || "Error: Could not display tool call"}
|
||||
</ToolCallBlock>
|
||||
);
|
||||
return (
|
||||
<MessageBlock
|
||||
key={index}
|
||||
label="Tool Call"
|
||||
content={toolCallContent}
|
||||
/>
|
||||
);
|
||||
}
|
||||
)}
|
||||
</>
|
||||
);
|
||||
} else {
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
"use client"
|
||||
"use client";
|
||||
|
||||
import React, { useMemo, useState } from "react"
|
||||
import { cva, type VariantProps } from "class-variance-authority"
|
||||
import { motion } from "framer-motion"
|
||||
import { Ban, ChevronRight, Code2, Loader2, Terminal } from "lucide-react"
|
||||
import React, { useMemo, useState } from "react";
|
||||
import { cva, type VariantProps } from "class-variance-authority";
|
||||
import { motion } from "framer-motion";
|
||||
import { Ban, ChevronRight, Code2, Loader2, Terminal } from "lucide-react";
|
||||
|
||||
import { cn } from "@/lib/utils"
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
Collapsible,
|
||||
CollapsibleContent,
|
||||
CollapsibleTrigger,
|
||||
} from "@/components/ui/collapsible"
|
||||
import { FilePreview } from "@/components/ui/file-preview"
|
||||
import { MarkdownRenderer } from "@/components/chat-playground/markdown-renderer"
|
||||
} from "@/components/ui/collapsible";
|
||||
import { FilePreview } from "@/components/ui/file-preview";
|
||||
import { MarkdownRenderer } from "@/components/chat-playground/markdown-renderer";
|
||||
|
||||
const chatBubbleVariants = cva(
|
||||
"group/message relative break-words rounded-lg p-3 text-sm sm:max-w-[70%]",
|
||||
|
@ -52,66 +52,66 @@ const chatBubbleVariants = cva(
|
|||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
type Animation = VariantProps<typeof chatBubbleVariants>["animation"]
|
||||
type Animation = VariantProps<typeof chatBubbleVariants>["animation"];
|
||||
|
||||
interface Attachment {
|
||||
name?: string
|
||||
contentType?: string
|
||||
url: string
|
||||
name?: string;
|
||||
contentType?: string;
|
||||
url: string;
|
||||
}
|
||||
|
||||
interface PartialToolCall {
|
||||
state: "partial-call"
|
||||
toolName: string
|
||||
state: "partial-call";
|
||||
toolName: string;
|
||||
}
|
||||
|
||||
interface ToolCall {
|
||||
state: "call"
|
||||
toolName: string
|
||||
state: "call";
|
||||
toolName: string;
|
||||
}
|
||||
|
||||
interface ToolResult {
|
||||
state: "result"
|
||||
toolName: string
|
||||
state: "result";
|
||||
toolName: string;
|
||||
result: {
|
||||
__cancelled?: boolean
|
||||
[key: string]: any
|
||||
}
|
||||
__cancelled?: boolean;
|
||||
[key: string]: unknown;
|
||||
};
|
||||
}
|
||||
|
||||
type ToolInvocation = PartialToolCall | ToolCall | ToolResult
|
||||
type ToolInvocation = PartialToolCall | ToolCall | ToolResult;
|
||||
|
||||
interface ReasoningPart {
|
||||
type: "reasoning"
|
||||
reasoning: string
|
||||
type: "reasoning";
|
||||
reasoning: string;
|
||||
}
|
||||
|
||||
interface ToolInvocationPart {
|
||||
type: "tool-invocation"
|
||||
toolInvocation: ToolInvocation
|
||||
type: "tool-invocation";
|
||||
toolInvocation: ToolInvocation;
|
||||
}
|
||||
|
||||
interface TextPart {
|
||||
type: "text"
|
||||
text: string
|
||||
type: "text";
|
||||
text: string;
|
||||
}
|
||||
|
||||
// For compatibility with AI SDK types, not used
|
||||
interface SourcePart {
|
||||
type: "source"
|
||||
source?: any
|
||||
type: "source";
|
||||
source?: unknown;
|
||||
}
|
||||
|
||||
interface FilePart {
|
||||
type: "file"
|
||||
mimeType: string
|
||||
data: string
|
||||
type: "file";
|
||||
mimeType: string;
|
||||
data: string;
|
||||
}
|
||||
|
||||
interface StepStartPart {
|
||||
type: "step-start"
|
||||
type: "step-start";
|
||||
}
|
||||
|
||||
type MessagePart =
|
||||
|
@ -120,22 +120,22 @@ type MessagePart =
|
|||
| ToolInvocationPart
|
||||
| SourcePart
|
||||
| FilePart
|
||||
| StepStartPart
|
||||
| StepStartPart;
|
||||
|
||||
export interface Message {
|
||||
id: string
|
||||
role: "user" | "assistant" | (string & {})
|
||||
content: string
|
||||
createdAt?: Date
|
||||
experimental_attachments?: Attachment[]
|
||||
toolInvocations?: ToolInvocation[]
|
||||
parts?: MessagePart[]
|
||||
id: string;
|
||||
role: "user" | "assistant" | (string & {});
|
||||
content: string;
|
||||
createdAt?: Date;
|
||||
experimental_attachments?: Attachment[];
|
||||
toolInvocations?: ToolInvocation[];
|
||||
parts?: MessagePart[];
|
||||
}
|
||||
|
||||
export interface ChatMessageProps extends Message {
|
||||
showTimeStamp?: boolean
|
||||
animation?: Animation
|
||||
actions?: React.ReactNode
|
||||
showTimeStamp?: boolean;
|
||||
animation?: Animation;
|
||||
actions?: React.ReactNode;
|
||||
}
|
||||
|
||||
export const ChatMessage: React.FC<ChatMessageProps> = ({
|
||||
|
@ -150,21 +150,21 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
|||
parts,
|
||||
}) => {
|
||||
const files = useMemo(() => {
|
||||
return experimental_attachments?.map((attachment) => {
|
||||
const dataArray = dataUrlToUint8Array(attachment.url)
|
||||
return experimental_attachments?.map(attachment => {
|
||||
const dataArray = dataUrlToUint8Array(attachment.url);
|
||||
const file = new File([dataArray], attachment.name ?? "Unknown", {
|
||||
type: attachment.contentType,
|
||||
})
|
||||
return file
|
||||
})
|
||||
}, [experimental_attachments])
|
||||
});
|
||||
return file;
|
||||
});
|
||||
}, [experimental_attachments]);
|
||||
|
||||
const isUser = role === "user"
|
||||
const isUser = role === "user";
|
||||
|
||||
const formattedTime = createdAt?.toLocaleTimeString("en-US", {
|
||||
hour: "2-digit",
|
||||
minute: "2-digit",
|
||||
})
|
||||
});
|
||||
|
||||
if (isUser) {
|
||||
return (
|
||||
|
@ -174,7 +174,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
|||
{files ? (
|
||||
<div className="mb-1 flex flex-wrap gap-2">
|
||||
{files.map((file, index) => {
|
||||
return <FilePreview file={file} key={index} />
|
||||
return <FilePreview file={file} key={index} />;
|
||||
})}
|
||||
</div>
|
||||
) : null}
|
||||
|
@ -195,7 +195,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
|||
</time>
|
||||
) : null}
|
||||
</div>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (parts && parts.length > 0) {
|
||||
|
@ -230,23 +230,23 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
|||
</time>
|
||||
) : null}
|
||||
</div>
|
||||
)
|
||||
);
|
||||
} else if (part.type === "reasoning") {
|
||||
return <ReasoningBlock key={`reasoning-${index}`} part={part} />
|
||||
return <ReasoningBlock key={`reasoning-${index}`} part={part} />;
|
||||
} else if (part.type === "tool-invocation") {
|
||||
return (
|
||||
<ToolCall
|
||||
key={`tool-${index}`}
|
||||
toolInvocations={[part.toolInvocation]}
|
||||
/>
|
||||
)
|
||||
);
|
||||
}
|
||||
return null
|
||||
})
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
if (toolInvocations && toolInvocations.length > 0) {
|
||||
return <ToolCall toolInvocations={toolInvocations} />
|
||||
return <ToolCall toolInvocations={toolInvocations} />;
|
||||
}
|
||||
|
||||
return (
|
||||
|
@ -272,17 +272,17 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
|||
</time>
|
||||
) : null}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
function dataUrlToUint8Array(data: string) {
|
||||
const base64 = data.split(",")[1]
|
||||
const buf = Buffer.from(base64, "base64")
|
||||
return new Uint8Array(buf)
|
||||
const base64 = data.split(",")[1];
|
||||
const buf = Buffer.from(base64, "base64");
|
||||
return new Uint8Array(buf);
|
||||
}
|
||||
|
||||
const ReasoningBlock = ({ part }: { part: ReasoningPart }) => {
|
||||
const [isOpen, setIsOpen] = useState(false)
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
|
||||
return (
|
||||
<div className="mb-2 flex flex-col items-start sm:max-w-[70%]">
|
||||
|
@ -319,20 +319,20 @@ const ReasoningBlock = ({ part }: { part: ReasoningPart }) => {
|
|||
</CollapsibleContent>
|
||||
</Collapsible>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
function ToolCall({
|
||||
toolInvocations,
|
||||
}: Pick<ChatMessageProps, "toolInvocations">) {
|
||||
if (!toolInvocations?.length) return null
|
||||
if (!toolInvocations?.length) return null;
|
||||
|
||||
return (
|
||||
<div className="flex flex-col items-start gap-2">
|
||||
{toolInvocations.map((invocation, index) => {
|
||||
const isCancelled =
|
||||
invocation.state === "result" &&
|
||||
invocation.result.__cancelled === true
|
||||
invocation.result.__cancelled === true;
|
||||
|
||||
if (isCancelled) {
|
||||
return (
|
||||
|
@ -350,7 +350,7 @@ function ToolCall({
|
|||
</span>
|
||||
</span>
|
||||
</div>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
switch (invocation.state) {
|
||||
|
@ -373,7 +373,7 @@ function ToolCall({
|
|||
</span>
|
||||
<Loader2 className="h-3 w-3 animate-spin" />
|
||||
</div>
|
||||
)
|
||||
);
|
||||
case "result":
|
||||
return (
|
||||
<div
|
||||
|
@ -395,11 +395,11 @@ function ToolCall({
|
|||
{JSON.stringify(invocation.result, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
)
|
||||
);
|
||||
default:
|
||||
return null
|
||||
return null;
|
||||
}
|
||||
})}
|
||||
</div>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
"use client"
|
||||
"use client";
|
||||
|
||||
import {
|
||||
forwardRef,
|
||||
|
@ -6,48 +6,48 @@ import {
|
|||
useRef,
|
||||
useState,
|
||||
type ReactElement,
|
||||
} from "react"
|
||||
import { ArrowDown, ThumbsDown, ThumbsUp } from "lucide-react"
|
||||
} from "react";
|
||||
import { ArrowDown, ThumbsDown, ThumbsUp } from "lucide-react";
|
||||
|
||||
import { cn } from "@/lib/utils"
|
||||
import { useAutoScroll } from "@/hooks/use-auto-scroll"
|
||||
import { Button } from "@/components/ui/button"
|
||||
import { type Message } from "@/components/chat-playground/chat-message"
|
||||
import { CopyButton } from "@/components/ui/copy-button"
|
||||
import { MessageInput } from "@/components/chat-playground/message-input"
|
||||
import { MessageList } from "@/components/chat-playground/message-list"
|
||||
import { PromptSuggestions } from "@/components/chat-playground/prompt-suggestions"
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useAutoScroll } from "@/hooks/use-auto-scroll";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { type Message } from "@/components/chat-playground/chat-message";
|
||||
import { CopyButton } from "@/components/ui/copy-button";
|
||||
import { MessageInput } from "@/components/chat-playground/message-input";
|
||||
import { MessageList } from "@/components/chat-playground/message-list";
|
||||
import { PromptSuggestions } from "@/components/chat-playground/prompt-suggestions";
|
||||
|
||||
interface ChatPropsBase {
|
||||
handleSubmit: (
|
||||
event?: { preventDefault?: () => void },
|
||||
options?: { experimental_attachments?: FileList }
|
||||
) => void
|
||||
messages: Array<Message>
|
||||
input: string
|
||||
className?: string
|
||||
handleInputChange: React.ChangeEventHandler<HTMLTextAreaElement>
|
||||
isGenerating: boolean
|
||||
stop?: () => void
|
||||
) => void;
|
||||
messages: Array<Message>;
|
||||
input: string;
|
||||
className?: string;
|
||||
handleInputChange: React.ChangeEventHandler<HTMLTextAreaElement>;
|
||||
isGenerating: boolean;
|
||||
stop?: () => void;
|
||||
onRateResponse?: (
|
||||
messageId: string,
|
||||
rating: "thumbs-up" | "thumbs-down"
|
||||
) => void
|
||||
setMessages?: (messages: any[]) => void
|
||||
transcribeAudio?: (blob: Blob) => Promise<string>
|
||||
) => void;
|
||||
setMessages?: (messages: Message[]) => void;
|
||||
transcribeAudio?: (blob: Blob) => Promise<string>;
|
||||
}
|
||||
|
||||
interface ChatPropsWithoutSuggestions extends ChatPropsBase {
|
||||
append?: never
|
||||
suggestions?: never
|
||||
append?: never;
|
||||
suggestions?: never;
|
||||
}
|
||||
|
||||
interface ChatPropsWithSuggestions extends ChatPropsBase {
|
||||
append: (message: { role: "user"; content: string }) => void
|
||||
suggestions: string[]
|
||||
append: (message: { role: "user"; content: string }) => void;
|
||||
suggestions: string[];
|
||||
}
|
||||
|
||||
type ChatProps = ChatPropsWithoutSuggestions | ChatPropsWithSuggestions
|
||||
type ChatProps = ChatPropsWithoutSuggestions | ChatPropsWithSuggestions;
|
||||
|
||||
export function Chat({
|
||||
messages,
|
||||
|
@ -63,34 +63,34 @@ export function Chat({
|
|||
setMessages,
|
||||
transcribeAudio,
|
||||
}: ChatProps) {
|
||||
const lastMessage = messages.at(-1)
|
||||
const isEmpty = messages.length === 0
|
||||
const isTyping = lastMessage?.role === "user"
|
||||
const lastMessage = messages.at(-1);
|
||||
const isEmpty = messages.length === 0;
|
||||
const isTyping = lastMessage?.role === "user";
|
||||
|
||||
const messagesRef = useRef(messages)
|
||||
messagesRef.current = messages
|
||||
const messagesRef = useRef(messages);
|
||||
messagesRef.current = messages;
|
||||
|
||||
// Enhanced stop function that marks pending tool calls as cancelled
|
||||
const handleStop = useCallback(() => {
|
||||
stop?.()
|
||||
stop?.();
|
||||
|
||||
if (!setMessages) return
|
||||
if (!setMessages) return;
|
||||
|
||||
const latestMessages = [...messagesRef.current]
|
||||
const latestMessages = [...messagesRef.current];
|
||||
const lastAssistantMessage = latestMessages.findLast(
|
||||
(m) => m.role === "assistant"
|
||||
)
|
||||
m => m.role === "assistant"
|
||||
);
|
||||
|
||||
if (!lastAssistantMessage) return
|
||||
if (!lastAssistantMessage) return;
|
||||
|
||||
let needsUpdate = false
|
||||
let updatedMessage = { ...lastAssistantMessage }
|
||||
let needsUpdate = false;
|
||||
let updatedMessage = { ...lastAssistantMessage };
|
||||
|
||||
if (lastAssistantMessage.toolInvocations) {
|
||||
const updatedToolInvocations = lastAssistantMessage.toolInvocations.map(
|
||||
(toolInvocation) => {
|
||||
toolInvocation => {
|
||||
if (toolInvocation.state === "call") {
|
||||
needsUpdate = true
|
||||
needsUpdate = true;
|
||||
return {
|
||||
...toolInvocation,
|
||||
state: "result",
|
||||
|
@ -98,61 +98,66 @@ export function Chat({
|
|||
content: "Tool execution was cancelled",
|
||||
__cancelled: true, // Special marker to indicate cancellation
|
||||
},
|
||||
} as const
|
||||
} as const;
|
||||
}
|
||||
return toolInvocation
|
||||
return toolInvocation;
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
if (needsUpdate) {
|
||||
updatedMessage = {
|
||||
...updatedMessage,
|
||||
toolInvocations: updatedToolInvocations,
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (lastAssistantMessage.parts && lastAssistantMessage.parts.length > 0) {
|
||||
const updatedParts = lastAssistantMessage.parts.map((part: any) => {
|
||||
if (
|
||||
part.type === "tool-invocation" &&
|
||||
part.toolInvocation &&
|
||||
part.toolInvocation.state === "call"
|
||||
) {
|
||||
needsUpdate = true
|
||||
return {
|
||||
...part,
|
||||
toolInvocation: {
|
||||
...part.toolInvocation,
|
||||
state: "result",
|
||||
result: {
|
||||
content: "Tool execution was cancelled",
|
||||
__cancelled: true,
|
||||
const updatedParts = lastAssistantMessage.parts.map(
|
||||
(part: {
|
||||
type: string;
|
||||
toolInvocation?: { state: string; toolName: string };
|
||||
}) => {
|
||||
if (
|
||||
part.type === "tool-invocation" &&
|
||||
part.toolInvocation &&
|
||||
part.toolInvocation.state === "call"
|
||||
) {
|
||||
needsUpdate = true;
|
||||
return {
|
||||
...part,
|
||||
toolInvocation: {
|
||||
...part.toolInvocation,
|
||||
state: "result",
|
||||
result: {
|
||||
content: "Tool execution was cancelled",
|
||||
__cancelled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
return part;
|
||||
}
|
||||
return part
|
||||
})
|
||||
);
|
||||
|
||||
if (needsUpdate) {
|
||||
updatedMessage = {
|
||||
...updatedMessage,
|
||||
parts: updatedParts,
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (needsUpdate) {
|
||||
const messageIndex = latestMessages.findIndex(
|
||||
(m) => m.id === lastAssistantMessage.id
|
||||
)
|
||||
m => m.id === lastAssistantMessage.id
|
||||
);
|
||||
if (messageIndex !== -1) {
|
||||
latestMessages[messageIndex] = updatedMessage
|
||||
setMessages(latestMessages)
|
||||
latestMessages[messageIndex] = updatedMessage;
|
||||
setMessages(latestMessages);
|
||||
}
|
||||
}
|
||||
}, [stop, setMessages, messagesRef])
|
||||
}, [stop, setMessages, messagesRef]);
|
||||
|
||||
const messageOptions = useCallback(
|
||||
(message: Message) => ({
|
||||
|
@ -189,7 +194,7 @@ export function Chat({
|
|||
),
|
||||
}),
|
||||
[onRateResponse]
|
||||
)
|
||||
);
|
||||
|
||||
return (
|
||||
<ChatContainer className={className}>
|
||||
|
@ -237,15 +242,15 @@ export function Chat({
|
|||
</div>
|
||||
</div>
|
||||
</ChatContainer>
|
||||
)
|
||||
);
|
||||
}
|
||||
Chat.displayName = "Chat"
|
||||
Chat.displayName = "Chat";
|
||||
|
||||
export function ChatMessages({
|
||||
messages,
|
||||
children,
|
||||
}: React.PropsWithChildren<{
|
||||
messages: Message[]
|
||||
messages: Message[];
|
||||
}>) {
|
||||
const {
|
||||
containerRef,
|
||||
|
@ -253,7 +258,7 @@ export function ChatMessages({
|
|||
handleScroll,
|
||||
shouldAutoScroll,
|
||||
handleTouchStart,
|
||||
} = useAutoScroll([messages])
|
||||
} = useAutoScroll([messages]);
|
||||
|
||||
return (
|
||||
<div
|
||||
|
@ -281,7 +286,7 @@ export function ChatMessages({
|
|||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
export const ChatContainer = forwardRef<
|
||||
|
@ -294,56 +299,56 @@ export const ChatContainer = forwardRef<
|
|||
className={cn("flex flex-col max-h-full w-full", className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
})
|
||||
ChatContainer.displayName = "ChatContainer"
|
||||
);
|
||||
});
|
||||
ChatContainer.displayName = "ChatContainer";
|
||||
|
||||
interface ChatFormProps {
|
||||
className?: string
|
||||
isPending: boolean
|
||||
className?: string;
|
||||
isPending: boolean;
|
||||
handleSubmit: (
|
||||
event?: { preventDefault?: () => void },
|
||||
options?: { experimental_attachments?: FileList }
|
||||
) => void
|
||||
) => void;
|
||||
children: (props: {
|
||||
files: File[] | null
|
||||
setFiles: React.Dispatch<React.SetStateAction<File[] | null>>
|
||||
}) => ReactElement
|
||||
files: File[] | null;
|
||||
setFiles: React.Dispatch<React.SetStateAction<File[] | null>>;
|
||||
}) => ReactElement;
|
||||
}
|
||||
|
||||
export const ChatForm = forwardRef<HTMLFormElement, ChatFormProps>(
|
||||
({ children, handleSubmit, isPending, className }, ref) => {
|
||||
const [files, setFiles] = useState<File[] | null>(null)
|
||||
const [files, setFiles] = useState<File[] | null>(null);
|
||||
|
||||
const onSubmit = (event: React.FormEvent) => {
|
||||
// if (isPending) {
|
||||
// event.preventDefault()
|
||||
// return
|
||||
// }
|
||||
|
||||
if (!files) {
|
||||
handleSubmit(event)
|
||||
return
|
||||
if (isPending) {
|
||||
event.preventDefault();
|
||||
return;
|
||||
}
|
||||
|
||||
const fileList = createFileList(files)
|
||||
handleSubmit(event, { experimental_attachments: fileList })
|
||||
setFiles(null)
|
||||
}
|
||||
if (!files) {
|
||||
handleSubmit(event);
|
||||
return;
|
||||
}
|
||||
|
||||
const fileList = createFileList(files);
|
||||
handleSubmit(event, { experimental_attachments: fileList });
|
||||
setFiles(null);
|
||||
};
|
||||
|
||||
return (
|
||||
<form ref={ref} onSubmit={onSubmit} className={className}>
|
||||
{children({ files, setFiles })}
|
||||
</form>
|
||||
)
|
||||
);
|
||||
}
|
||||
)
|
||||
ChatForm.displayName = "ChatForm"
|
||||
);
|
||||
ChatForm.displayName = "ChatForm";
|
||||
|
||||
function createFileList(files: File[] | FileList): FileList {
|
||||
const dataTransfer = new DataTransfer()
|
||||
const dataTransfer = new DataTransfer();
|
||||
for (const file of Array.from(files)) {
|
||||
dataTransfer.items.add(file)
|
||||
dataTransfer.items.add(file);
|
||||
}
|
||||
return dataTransfer.files
|
||||
return dataTransfer.files;
|
||||
}
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
"use client"
|
||||
"use client";
|
||||
|
||||
import { AnimatePresence, motion } from "framer-motion"
|
||||
import { X } from "lucide-react"
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { X } from "lucide-react";
|
||||
|
||||
interface InterruptPromptProps {
|
||||
isOpen: boolean
|
||||
close: () => void
|
||||
isOpen: boolean;
|
||||
close: () => void;
|
||||
}
|
||||
|
||||
export function InterruptPrompt({ isOpen, close }: InterruptPromptProps) {
|
||||
|
@ -37,5 +37,5 @@ export function InterruptPrompt({ isOpen, close }: InterruptPromptProps) {
|
|||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import React, { Suspense, useEffect, useState } from "react"
|
||||
import Markdown from "react-markdown"
|
||||
import remarkGfm from "remark-gfm"
|
||||
import React, { Suspense, useEffect, useState } from "react";
|
||||
import Markdown from "react-markdown";
|
||||
import remarkGfm from "remark-gfm";
|
||||
|
||||
import { cn } from "@/lib/utils"
|
||||
import { CopyButton } from "@/components/ui/copy-button"
|
||||
import { cn } from "@/lib/utils";
|
||||
import { CopyButton } from "@/components/ui/copy-button";
|
||||
|
||||
interface MarkdownRendererProps {
|
||||
children: string
|
||||
children: string;
|
||||
}
|
||||
|
||||
export function MarkdownRenderer({ children }: MarkdownRendererProps) {
|
||||
|
@ -16,34 +16,34 @@ export function MarkdownRenderer({ children }: MarkdownRendererProps) {
|
|||
{children}
|
||||
</Markdown>
|
||||
</div>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
interface HighlightedPre extends React.HTMLAttributes<HTMLPreElement> {
|
||||
children: string
|
||||
language: string
|
||||
children: string;
|
||||
language: string;
|
||||
}
|
||||
|
||||
const HighlightedPre = React.memo(
|
||||
({ children, language, ...props }: HighlightedPre) => {
|
||||
const [tokens, setTokens] = useState<any[] | null>(null)
|
||||
const [isSupported, setIsSupported] = useState(false)
|
||||
const [tokens, setTokens] = useState<unknown[] | null>(null);
|
||||
const [isSupported, setIsSupported] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
let mounted = true
|
||||
let mounted = true;
|
||||
|
||||
const loadAndHighlight = async () => {
|
||||
try {
|
||||
const { codeToTokens, bundledLanguages } = await import("shiki")
|
||||
const { codeToTokens, bundledLanguages } = await import("shiki");
|
||||
|
||||
if (!mounted) return
|
||||
if (!mounted) return;
|
||||
|
||||
if (!(language in bundledLanguages)) {
|
||||
setIsSupported(false)
|
||||
return
|
||||
setIsSupported(false);
|
||||
return;
|
||||
}
|
||||
|
||||
setIsSupported(true)
|
||||
setIsSupported(true);
|
||||
|
||||
const { tokens: highlightedTokens } = await codeToTokens(children, {
|
||||
lang: language as keyof typeof bundledLanguages,
|
||||
|
@ -52,31 +52,31 @@ const HighlightedPre = React.memo(
|
|||
light: "github-light",
|
||||
dark: "github-dark",
|
||||
},
|
||||
})
|
||||
});
|
||||
|
||||
if (mounted) {
|
||||
setTokens(highlightedTokens)
|
||||
setTokens(highlightedTokens);
|
||||
}
|
||||
} catch (error) {
|
||||
} catch {
|
||||
if (mounted) {
|
||||
setIsSupported(false)
|
||||
setIsSupported(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
loadAndHighlight()
|
||||
loadAndHighlight();
|
||||
|
||||
return () => {
|
||||
mounted = false
|
||||
}
|
||||
}, [children, language])
|
||||
mounted = false;
|
||||
};
|
||||
}, [children, language]);
|
||||
|
||||
if (!isSupported) {
|
||||
return <pre {...props}>{children}</pre>
|
||||
return <pre {...props}>{children}</pre>;
|
||||
}
|
||||
|
||||
if (!tokens) {
|
||||
return <pre {...props}>{children}</pre>
|
||||
return <pre {...props}>{children}</pre>;
|
||||
}
|
||||
|
||||
return (
|
||||
|
@ -89,7 +89,7 @@ const HighlightedPre = React.memo(
|
|||
const style =
|
||||
typeof token.htmlStyle === "string"
|
||||
? undefined
|
||||
: token.htmlStyle
|
||||
: token.htmlStyle;
|
||||
|
||||
return (
|
||||
<span
|
||||
|
@ -99,7 +99,7 @@ const HighlightedPre = React.memo(
|
|||
>
|
||||
{token.content}
|
||||
</span>
|
||||
)
|
||||
);
|
||||
})}
|
||||
</span>
|
||||
{lineIndex !== tokens.length - 1 && "\n"}
|
||||
|
@ -107,15 +107,15 @@ const HighlightedPre = React.memo(
|
|||
))}
|
||||
</code>
|
||||
</pre>
|
||||
)
|
||||
);
|
||||
}
|
||||
)
|
||||
HighlightedPre.displayName = "HighlightedCode"
|
||||
);
|
||||
HighlightedPre.displayName = "HighlightedCode";
|
||||
|
||||
interface CodeBlockProps extends React.HTMLAttributes<HTMLPreElement> {
|
||||
children: React.ReactNode
|
||||
className?: string
|
||||
language: string
|
||||
children: React.ReactNode;
|
||||
className?: string;
|
||||
language: string;
|
||||
}
|
||||
|
||||
const CodeBlock = ({
|
||||
|
@ -127,12 +127,12 @@ const CodeBlock = ({
|
|||
const code =
|
||||
typeof children === "string"
|
||||
? children
|
||||
: childrenTakeAllStringContents(children)
|
||||
: childrenTakeAllStringContents(children);
|
||||
|
||||
const preClass = cn(
|
||||
"overflow-x-scroll rounded-md border bg-background/50 p-4 font-mono text-sm [scrollbar-width:none]",
|
||||
className
|
||||
)
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="group/code relative mb-4">
|
||||
|
@ -152,27 +152,27 @@ const CodeBlock = ({
|
|||
<CopyButton content={code} copyMessage="Copied code to clipboard" />
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
function childrenTakeAllStringContents(element: any): string {
|
||||
function childrenTakeAllStringContents(element: unknown): string {
|
||||
if (typeof element === "string") {
|
||||
return element
|
||||
return element;
|
||||
}
|
||||
|
||||
if (element?.props?.children) {
|
||||
let children = element.props.children
|
||||
const children = element.props.children;
|
||||
|
||||
if (Array.isArray(children)) {
|
||||
return children
|
||||
.map((child) => childrenTakeAllStringContents(child))
|
||||
.join("")
|
||||
.map(child => childrenTakeAllStringContents(child))
|
||||
.join("");
|
||||
} else {
|
||||
return childrenTakeAllStringContents(children)
|
||||
return childrenTakeAllStringContents(children);
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
return "";
|
||||
}
|
||||
|
||||
const COMPONENTS = {
|
||||
|
@ -184,8 +184,15 @@ const COMPONENTS = {
|
|||
strong: withClass("strong", "font-semibold"),
|
||||
a: withClass("a", "text-primary underline underline-offset-2"),
|
||||
blockquote: withClass("blockquote", "border-l-2 border-primary pl-4"),
|
||||
code: ({ children, className, node, ...rest }: any) => {
|
||||
const match = /language-(\w+)/.exec(className || "")
|
||||
code: ({
|
||||
children,
|
||||
className,
|
||||
...rest
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
className?: string;
|
||||
}) => {
|
||||
const match = /language-(\w+)/.exec(className || "");
|
||||
return match ? (
|
||||
<CodeBlock className={className} language={match[1]} {...rest}>
|
||||
{children}
|
||||
|
@ -199,9 +206,9 @@ const COMPONENTS = {
|
|||
>
|
||||
{children}
|
||||
</code>
|
||||
)
|
||||
);
|
||||
},
|
||||
pre: ({ children }: any) => children,
|
||||
pre: ({ children }: { children: React.ReactNode }) => children,
|
||||
ol: withClass("ol", "list-decimal space-y-2 pl-6"),
|
||||
ul: withClass("ul", "list-disc space-y-2 pl-6"),
|
||||
li: withClass("li", "my-1.5"),
|
||||
|
@ -220,14 +227,14 @@ const COMPONENTS = {
|
|||
tr: withClass("tr", "m-0 border-t p-0 even:bg-muted"),
|
||||
p: withClass("p", "whitespace-pre-wrap"),
|
||||
hr: withClass("hr", "border-foreground/20"),
|
||||
}
|
||||
};
|
||||
|
||||
function withClass(Tag: keyof JSX.IntrinsicElements, classes: string) {
|
||||
const Component = ({ node, ...props }: any) => (
|
||||
const Component = ({ ...props }: Record<string, unknown>) => (
|
||||
<Tag className={classes} {...props} />
|
||||
)
|
||||
Component.displayName = Tag
|
||||
return Component
|
||||
);
|
||||
Component.displayName = Tag;
|
||||
return Component;
|
||||
}
|
||||
|
||||
export default MarkdownRenderer
|
||||
export default MarkdownRenderer;
|
||||
|
|
|
@ -1,41 +1,41 @@
|
|||
"use client"
|
||||
"use client";
|
||||
|
||||
import React, { useEffect, useRef, useState } from "react"
|
||||
import { AnimatePresence, motion } from "framer-motion"
|
||||
import { ArrowUp, Info, Loader2, Mic, Paperclip, Square } from "lucide-react"
|
||||
import { omit } from "remeda"
|
||||
import React, { useEffect, useRef, useState } from "react";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { ArrowUp, Info, Loader2, Mic, Paperclip, Square } from "lucide-react";
|
||||
import { omit } from "remeda";
|
||||
|
||||
import { cn } from "@/lib/utils"
|
||||
import { useAudioRecording } from "@/hooks/use-audio-recording"
|
||||
import { useAutosizeTextArea } from "@/hooks/use-autosize-textarea"
|
||||
import { AudioVisualizer } from "@/components/ui/audio-visualizer"
|
||||
import { Button } from "@/components/ui/button"
|
||||
import { FilePreview } from "@/components/ui/file-preview"
|
||||
import { InterruptPrompt } from "@/components/chat-playground/interrupt-prompt"
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useAudioRecording } from "@/hooks/use-audio-recording";
|
||||
import { useAutosizeTextArea } from "@/hooks/use-autosize-textarea";
|
||||
import { AudioVisualizer } from "@/components/ui/audio-visualizer";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { FilePreview } from "@/components/ui/file-preview";
|
||||
import { InterruptPrompt } from "@/components/chat-playground/interrupt-prompt";
|
||||
|
||||
interface MessageInputBaseProps
|
||||
extends React.TextareaHTMLAttributes<HTMLTextAreaElement> {
|
||||
value: string
|
||||
submitOnEnter?: boolean
|
||||
stop?: () => void
|
||||
isGenerating: boolean
|
||||
enableInterrupt?: boolean
|
||||
transcribeAudio?: (blob: Blob) => Promise<string>
|
||||
value: string;
|
||||
submitOnEnter?: boolean;
|
||||
stop?: () => void;
|
||||
isGenerating: boolean;
|
||||
enableInterrupt?: boolean;
|
||||
transcribeAudio?: (blob: Blob) => Promise<string>;
|
||||
}
|
||||
|
||||
interface MessageInputWithoutAttachmentProps extends MessageInputBaseProps {
|
||||
allowAttachments?: false
|
||||
allowAttachments?: false;
|
||||
}
|
||||
|
||||
interface MessageInputWithAttachmentsProps extends MessageInputBaseProps {
|
||||
allowAttachments: true
|
||||
files: File[] | null
|
||||
setFiles: React.Dispatch<React.SetStateAction<File[] | null>>
|
||||
allowAttachments: true;
|
||||
files: File[] | null;
|
||||
setFiles: React.Dispatch<React.SetStateAction<File[] | null>>;
|
||||
}
|
||||
|
||||
type MessageInputProps =
|
||||
| MessageInputWithoutAttachmentProps
|
||||
| MessageInputWithAttachmentsProps
|
||||
| MessageInputWithAttachmentsProps;
|
||||
|
||||
export function MessageInput({
|
||||
placeholder = "Ask AI...",
|
||||
|
@ -48,8 +48,8 @@ export function MessageInput({
|
|||
transcribeAudio,
|
||||
...props
|
||||
}: MessageInputProps) {
|
||||
const [isDragging, setIsDragging] = useState(false)
|
||||
const [showInterruptPrompt, setShowInterruptPrompt] = useState(false)
|
||||
const [isDragging, setIsDragging] = useState(false);
|
||||
const [showInterruptPrompt, setShowInterruptPrompt] = useState(false);
|
||||
|
||||
const {
|
||||
isListening,
|
||||
|
@ -61,123 +61,124 @@ export function MessageInput({
|
|||
stopRecording,
|
||||
} = useAudioRecording({
|
||||
transcribeAudio,
|
||||
onTranscriptionComplete: (text) => {
|
||||
props.onChange?.({ target: { value: text } } as any)
|
||||
onTranscriptionComplete: text => {
|
||||
props.onChange?.({
|
||||
target: { value: text },
|
||||
} as React.ChangeEvent<HTMLTextAreaElement>);
|
||||
},
|
||||
})
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (!isGenerating) {
|
||||
setShowInterruptPrompt(false)
|
||||
setShowInterruptPrompt(false);
|
||||
}
|
||||
}, [isGenerating])
|
||||
}, [isGenerating]);
|
||||
|
||||
const addFiles = (files: File[] | null) => {
|
||||
if (props.allowAttachments) {
|
||||
props.setFiles((currentFiles) => {
|
||||
props.setFiles(currentFiles => {
|
||||
if (currentFiles === null) {
|
||||
return files
|
||||
return files;
|
||||
}
|
||||
|
||||
if (files === null) {
|
||||
return currentFiles
|
||||
return currentFiles;
|
||||
}
|
||||
|
||||
return [...currentFiles, ...files]
|
||||
})
|
||||
return [...currentFiles, ...files];
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const onDragOver = (event: React.DragEvent) => {
|
||||
if (props.allowAttachments !== true) return
|
||||
event.preventDefault()
|
||||
setIsDragging(true)
|
||||
}
|
||||
if (props.allowAttachments !== true) return;
|
||||
event.preventDefault();
|
||||
setIsDragging(true);
|
||||
};
|
||||
|
||||
const onDragLeave = (event: React.DragEvent) => {
|
||||
if (props.allowAttachments !== true) return
|
||||
event.preventDefault()
|
||||
setIsDragging(false)
|
||||
}
|
||||
if (props.allowAttachments !== true) return;
|
||||
event.preventDefault();
|
||||
setIsDragging(false);
|
||||
};
|
||||
|
||||
const onDrop = (event: React.DragEvent) => {
|
||||
setIsDragging(false)
|
||||
if (props.allowAttachments !== true) return
|
||||
event.preventDefault()
|
||||
const dataTransfer = event.dataTransfer
|
||||
setIsDragging(false);
|
||||
if (props.allowAttachments !== true) return;
|
||||
event.preventDefault();
|
||||
const dataTransfer = event.dataTransfer;
|
||||
if (dataTransfer.files.length) {
|
||||
addFiles(Array.from(dataTransfer.files))
|
||||
addFiles(Array.from(dataTransfer.files));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const onPaste = (event: React.ClipboardEvent) => {
|
||||
const items = event.clipboardData?.items
|
||||
if (!items) return
|
||||
const items = event.clipboardData?.items;
|
||||
if (!items) return;
|
||||
|
||||
const text = event.clipboardData.getData("text")
|
||||
const text = event.clipboardData.getData("text");
|
||||
if (text && text.length > 500 && props.allowAttachments) {
|
||||
event.preventDefault()
|
||||
const blob = new Blob([text], { type: "text/plain" })
|
||||
event.preventDefault();
|
||||
const blob = new Blob([text], { type: "text/plain" });
|
||||
const file = new File([blob], "Pasted text", {
|
||||
type: "text/plain",
|
||||
lastModified: Date.now(),
|
||||
})
|
||||
addFiles([file])
|
||||
return
|
||||
});
|
||||
addFiles([file]);
|
||||
return;
|
||||
}
|
||||
|
||||
const files = Array.from(items)
|
||||
.map((item) => item.getAsFile())
|
||||
.filter((file) => file !== null)
|
||||
.map(item => item.getAsFile())
|
||||
.filter(file => file !== null);
|
||||
|
||||
if (props.allowAttachments && files.length > 0) {
|
||||
addFiles(files)
|
||||
addFiles(files);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const onKeyDown = (event: React.KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (submitOnEnter && event.key === "Enter" && !event.shiftKey) {
|
||||
event.preventDefault()
|
||||
event.preventDefault();
|
||||
|
||||
if (isGenerating && stop && enableInterrupt) {
|
||||
if (showInterruptPrompt) {
|
||||
stop()
|
||||
setShowInterruptPrompt(false)
|
||||
event.currentTarget.form?.requestSubmit()
|
||||
stop();
|
||||
setShowInterruptPrompt(false);
|
||||
event.currentTarget.form?.requestSubmit();
|
||||
} else if (
|
||||
props.value ||
|
||||
(props.allowAttachments && props.files?.length)
|
||||
) {
|
||||
setShowInterruptPrompt(true)
|
||||
return
|
||||
setShowInterruptPrompt(true);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
event.currentTarget.form?.requestSubmit()
|
||||
event.currentTarget.form?.requestSubmit();
|
||||
}
|
||||
|
||||
onKeyDownProp?.(event)
|
||||
}
|
||||
onKeyDownProp?.(event);
|
||||
};
|
||||
|
||||
const textAreaRef = useRef<HTMLTextAreaElement>(null)
|
||||
const [textAreaHeight, setTextAreaHeight] = useState<number>(0)
|
||||
const textAreaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const [textAreaHeight, setTextAreaHeight] = useState<number>(0);
|
||||
|
||||
useEffect(() => {
|
||||
if (textAreaRef.current) {
|
||||
setTextAreaHeight(textAreaRef.current.offsetHeight)
|
||||
setTextAreaHeight(textAreaRef.current.offsetHeight);
|
||||
}
|
||||
}, [props.value])
|
||||
}, [props.value]);
|
||||
|
||||
const showFileList =
|
||||
props.allowAttachments && props.files && props.files.length > 0
|
||||
|
||||
props.allowAttachments && props.files && props.files.length > 0;
|
||||
|
||||
useAutosizeTextArea({
|
||||
ref: textAreaRef,
|
||||
maxHeight: 240,
|
||||
borderWidth: 1,
|
||||
dependencies: [props.value, showFileList],
|
||||
})
|
||||
});
|
||||
|
||||
return (
|
||||
<div
|
||||
|
@ -220,24 +221,24 @@ export function MessageInput({
|
|||
<div className="absolute inset-x-3 bottom-0 z-20 overflow-x-scroll py-3">
|
||||
<div className="flex space-x-3">
|
||||
<AnimatePresence mode="popLayout">
|
||||
{props.files?.map((file) => {
|
||||
{props.files?.map(file => {
|
||||
return (
|
||||
<FilePreview
|
||||
key={file.name + String(file.lastModified)}
|
||||
file={file}
|
||||
onRemove={() => {
|
||||
props.setFiles((files) => {
|
||||
if (!files) return null
|
||||
props.setFiles(files => {
|
||||
if (!files) return null;
|
||||
|
||||
const filtered = Array.from(files).filter(
|
||||
(f) => f !== file
|
||||
)
|
||||
if (filtered.length === 0) return null
|
||||
return filtered
|
||||
})
|
||||
f => f !== file
|
||||
);
|
||||
if (filtered.length === 0) return null;
|
||||
return filtered;
|
||||
});
|
||||
}}
|
||||
/>
|
||||
)
|
||||
);
|
||||
})}
|
||||
</AnimatePresence>
|
||||
</div>
|
||||
|
@ -256,8 +257,8 @@ export function MessageInput({
|
|||
aria-label="Attach a file"
|
||||
disabled={true}
|
||||
onClick={async () => {
|
||||
const files = await showFileUploadDialog()
|
||||
addFiles(files)
|
||||
const files = await showFileUploadDialog();
|
||||
addFiles(files);
|
||||
}}
|
||||
>
|
||||
<Paperclip className="h-4 w-4" />
|
||||
|
@ -308,12 +309,12 @@ export function MessageInput({
|
|||
onStopRecording={stopRecording}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
);
|
||||
}
|
||||
MessageInput.displayName = "MessageInput"
|
||||
MessageInput.displayName = "MessageInput";
|
||||
|
||||
interface FileUploadOverlayProps {
|
||||
isDragging: boolean
|
||||
isDragging: boolean;
|
||||
}
|
||||
|
||||
function FileUploadOverlay({ isDragging }: FileUploadOverlayProps) {
|
||||
|
@ -333,29 +334,29 @@ function FileUploadOverlay({ isDragging }: FileUploadOverlayProps) {
|
|||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
function showFileUploadDialog() {
|
||||
const input = document.createElement("input")
|
||||
const input = document.createElement("input");
|
||||
|
||||
input.type = "file"
|
||||
input.multiple = true
|
||||
input.accept = "*/*"
|
||||
input.click()
|
||||
input.type = "file";
|
||||
input.multiple = true;
|
||||
input.accept = "*/*";
|
||||
input.click();
|
||||
|
||||
return new Promise<File[] | null>((resolve) => {
|
||||
input.onchange = (e) => {
|
||||
const files = (e.currentTarget as HTMLInputElement).files
|
||||
return new Promise<File[] | null>(resolve => {
|
||||
input.onchange = e => {
|
||||
const files = (e.currentTarget as HTMLInputElement).files;
|
||||
|
||||
if (files) {
|
||||
resolve(Array.from(files))
|
||||
return
|
||||
resolve(Array.from(files));
|
||||
return;
|
||||
}
|
||||
|
||||
resolve(null)
|
||||
}
|
||||
})
|
||||
resolve(null);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function TranscribingOverlay() {
|
||||
|
@ -385,12 +386,12 @@ function TranscribingOverlay() {
|
|||
Transcribing audio...
|
||||
</p>
|
||||
</motion.div>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
interface RecordingPromptProps {
|
||||
isVisible: boolean
|
||||
onStopRecording: () => void
|
||||
isVisible: boolean;
|
||||
onStopRecording: () => void;
|
||||
}
|
||||
|
||||
function RecordingPrompt({ isVisible, onStopRecording }: RecordingPromptProps) {
|
||||
|
@ -418,15 +419,15 @@ function RecordingPrompt({ isVisible, onStopRecording }: RecordingPromptProps) {
|
|||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
interface RecordingControlsProps {
|
||||
isRecording: boolean
|
||||
isTranscribing: boolean
|
||||
audioStream: MediaStream | null
|
||||
textAreaHeight: number
|
||||
onStopRecording: () => void
|
||||
isRecording: boolean;
|
||||
isTranscribing: boolean;
|
||||
audioStream: MediaStream | null;
|
||||
textAreaHeight: number;
|
||||
onStopRecording: () => void;
|
||||
}
|
||||
|
||||
function RecordingControls({
|
||||
|
@ -448,7 +449,7 @@ function RecordingControls({
|
|||
onClick={onStopRecording}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (isTranscribing) {
|
||||
|
@ -459,8 +460,8 @@ function RecordingControls({
|
|||
>
|
||||
<TranscribingOverlay />
|
||||
</div>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
return null
|
||||
return null;
|
||||
}
|
||||
|
|
|
@ -2,18 +2,18 @@ import {
|
|||
ChatMessage,
|
||||
type ChatMessageProps,
|
||||
type Message,
|
||||
} from "@/components/chat-playground/chat-message"
|
||||
import { TypingIndicator } from "@/components/chat-playground/typing-indicator"
|
||||
} from "@/components/chat-playground/chat-message";
|
||||
import { TypingIndicator } from "@/components/chat-playground/typing-indicator";
|
||||
|
||||
type AdditionalMessageOptions = Omit<ChatMessageProps, keyof Message>
|
||||
type AdditionalMessageOptions = Omit<ChatMessageProps, keyof Message>;
|
||||
|
||||
interface MessageListProps {
|
||||
messages: Message[]
|
||||
showTimeStamps?: boolean
|
||||
isTyping?: boolean
|
||||
messages: Message[];
|
||||
showTimeStamps?: boolean;
|
||||
isTyping?: boolean;
|
||||
messageOptions?:
|
||||
| AdditionalMessageOptions
|
||||
| ((message: Message) => AdditionalMessageOptions)
|
||||
| ((message: Message) => AdditionalMessageOptions);
|
||||
}
|
||||
|
||||
export function MessageList({
|
||||
|
@ -28,7 +28,7 @@ export function MessageList({
|
|||
const additionalOptions =
|
||||
typeof messageOptions === "function"
|
||||
? messageOptions(message)
|
||||
: messageOptions
|
||||
: messageOptions;
|
||||
|
||||
return (
|
||||
<ChatMessage
|
||||
|
@ -37,9 +37,9 @@ export function MessageList({
|
|||
{...message}
|
||||
{...additionalOptions}
|
||||
/>
|
||||
)
|
||||
);
|
||||
})}
|
||||
{isTyping && <TypingIndicator />}
|
||||
</div>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
interface PromptSuggestionsProps {
|
||||
label: string
|
||||
append: (message: { role: "user"; content: string }) => void
|
||||
suggestions: string[]
|
||||
label: string;
|
||||
append: (message: { role: "user"; content: string }) => void;
|
||||
suggestions: string[];
|
||||
}
|
||||
|
||||
export function PromptSuggestions({
|
||||
|
@ -13,7 +13,7 @@ export function PromptSuggestions({
|
|||
<div className="space-y-6">
|
||||
<h2 className="text-center text-2xl font-bold">{label}</h2>
|
||||
<div className="flex gap-6 text-sm">
|
||||
{suggestions.map((suggestion) => (
|
||||
{suggestions.map(suggestion => (
|
||||
<button
|
||||
key={suggestion}
|
||||
onClick={() => append({ role: "user", content: suggestion })}
|
||||
|
@ -24,5 +24,5 @@ export function PromptSuggestions({
|
|||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import { Dot } from "lucide-react"
|
||||
import { Dot } from "lucide-react";
|
||||
|
||||
export function TypingIndicator() {
|
||||
return (
|
||||
|
@ -11,5 +11,5 @@ export function TypingIndicator() {
|
|||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -56,18 +56,19 @@ const manageItems = [
|
|||
},
|
||||
];
|
||||
|
||||
const optimizeItems: { title: string; url: string; icon: React.ElementType }[] = [
|
||||
const optimizeItems: { title: string; url: string; icon: React.ElementType }[] =
|
||||
[
|
||||
{
|
||||
title: "Evaluations",
|
||||
url: "",
|
||||
icon: Compass,
|
||||
title: "Evaluations",
|
||||
url: "",
|
||||
icon: Compass,
|
||||
},
|
||||
{
|
||||
title: "Fine-tuning",
|
||||
url: "",
|
||||
icon: Settings2,
|
||||
title: "Fine-tuning",
|
||||
url: "",
|
||||
icon: Settings2,
|
||||
},
|
||||
];
|
||||
];
|
||||
|
||||
interface SidebarItem {
|
||||
title: string;
|
||||
|
@ -79,7 +80,7 @@ export function AppSidebar() {
|
|||
const pathname = usePathname();
|
||||
|
||||
const renderSidebarItems = (items: SidebarItem[]) => {
|
||||
return items.map((item) => {
|
||||
return items.map(item => {
|
||||
const isActive = pathname.startsWith(item.url);
|
||||
return (
|
||||
<SidebarMenuItem key={item.title}>
|
||||
|
@ -88,14 +89,14 @@ export function AppSidebar() {
|
|||
className={cn(
|
||||
"justify-start",
|
||||
isActive &&
|
||||
"bg-gray-200 dark:bg-gray-700 hover:bg-gray-200 dark:hover:bg-gray-700 text-gray-900 dark:text-gray-100",
|
||||
"bg-gray-200 dark:bg-gray-700 hover:bg-gray-200 dark:hover:bg-gray-700 text-gray-900 dark:text-gray-100"
|
||||
)}
|
||||
>
|
||||
<Link href={item.url}>
|
||||
<item.icon
|
||||
className={cn(
|
||||
isActive && "text-gray-900 dark:text-gray-100",
|
||||
"mr-2 h-4 w-4",
|
||||
"mr-2 h-4 w-4"
|
||||
)}
|
||||
/>
|
||||
<span>{item.title}</span>
|
||||
|
@ -106,46 +107,48 @@ export function AppSidebar() {
|
|||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<Sidebar>
|
||||
<SidebarHeader>
|
||||
<Link href="/">Llama Stack</Link>
|
||||
</SidebarHeader>
|
||||
<SidebarContent>
|
||||
<SidebarGroup>
|
||||
<SidebarGroupLabel>Create</SidebarGroupLabel>
|
||||
<SidebarGroupContent>
|
||||
<SidebarMenu>{renderSidebarItems(createItems)}</SidebarMenu>
|
||||
</SidebarGroupContent>
|
||||
</SidebarGroup>
|
||||
return (
|
||||
<Sidebar>
|
||||
<SidebarHeader>
|
||||
<Link href="/">Llama Stack</Link>
|
||||
</SidebarHeader>
|
||||
<SidebarContent>
|
||||
<SidebarGroup>
|
||||
<SidebarGroupLabel>Create</SidebarGroupLabel>
|
||||
<SidebarGroupContent>
|
||||
<SidebarMenu>{renderSidebarItems(createItems)}</SidebarMenu>
|
||||
</SidebarGroupContent>
|
||||
</SidebarGroup>
|
||||
|
||||
<SidebarGroup>
|
||||
<SidebarGroupLabel>Manage</SidebarGroupLabel>
|
||||
<SidebarGroupContent>
|
||||
<SidebarMenu>{renderSidebarItems(manageItems)}</SidebarMenu>
|
||||
</SidebarGroupContent>
|
||||
</SidebarGroup>
|
||||
<SidebarGroup>
|
||||
<SidebarGroupLabel>Manage</SidebarGroupLabel>
|
||||
<SidebarGroupContent>
|
||||
<SidebarMenu>{renderSidebarItems(manageItems)}</SidebarMenu>
|
||||
</SidebarGroupContent>
|
||||
</SidebarGroup>
|
||||
|
||||
<SidebarGroup>
|
||||
<SidebarGroupLabel>Optimize</SidebarGroupLabel>
|
||||
<SidebarGroupContent>
|
||||
<SidebarMenu>
|
||||
{optimizeItems.map((item) => (
|
||||
<SidebarMenuItem key={item.title}>
|
||||
<SidebarMenuButton
|
||||
disabled
|
||||
className="justify-start opacity-60 cursor-not-allowed"
|
||||
>
|
||||
<item.icon className="mr-2 h-4 w-4" />
|
||||
<span>{item.title}</span>
|
||||
<span className="ml-2 text-xs text-gray-500">(Coming Soon)</span>
|
||||
</SidebarMenuButton>
|
||||
</SidebarMenuItem>
|
||||
))}
|
||||
</SidebarMenu>
|
||||
</SidebarGroupContent>
|
||||
</SidebarGroup>
|
||||
</SidebarContent>
|
||||
</Sidebar>
|
||||
<SidebarGroup>
|
||||
<SidebarGroupLabel>Optimize</SidebarGroupLabel>
|
||||
<SidebarGroupContent>
|
||||
<SidebarMenu>
|
||||
{optimizeItems.map(item => (
|
||||
<SidebarMenuItem key={item.title}>
|
||||
<SidebarMenuButton
|
||||
disabled
|
||||
className="justify-start opacity-60 cursor-not-allowed"
|
||||
>
|
||||
<item.icon className="mr-2 h-4 w-4" />
|
||||
<span>{item.title}</span>
|
||||
<span className="ml-2 text-xs text-gray-500">
|
||||
(Coming Soon)
|
||||
</span>
|
||||
</SidebarMenuButton>
|
||||
</SidebarMenuItem>
|
||||
))}
|
||||
</SidebarMenu>
|
||||
</SidebarGroupContent>
|
||||
</SidebarGroup>
|
||||
</SidebarContent>
|
||||
</Sidebar>
|
||||
);
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ import React from "react";
|
|||
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
|
||||
export function DetailLoadingView({ title }: { title: string }) {
|
||||
export function DetailLoadingView() {
|
||||
return (
|
||||
<>
|
||||
<Skeleton className="h-8 w-3/4 mb-6" /> {/* Title Skeleton */}
|
||||
|
|
|
@ -67,7 +67,7 @@ describe("LogsTable Viewport Loading", () => {
|
|||
() => {
|
||||
expect(mockLoadMore).toHaveBeenCalled();
|
||||
},
|
||||
{ timeout: 300 },
|
||||
{ timeout: 300 }
|
||||
);
|
||||
|
||||
expect(mockLoadMore).toHaveBeenCalledTimes(1);
|
||||
|
@ -81,11 +81,11 @@ describe("LogsTable Viewport Loading", () => {
|
|||
{...defaultProps}
|
||||
status="loading-more"
|
||||
onLoadMore={mockLoadMore}
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
|
||||
// Wait for possible triggers
|
||||
await new Promise((resolve) => setTimeout(resolve, 300));
|
||||
await new Promise(resolve => setTimeout(resolve, 300));
|
||||
|
||||
expect(mockLoadMore).not.toHaveBeenCalled();
|
||||
});
|
||||
|
@ -94,15 +94,11 @@ describe("LogsTable Viewport Loading", () => {
|
|||
const mockLoadMore = jest.fn();
|
||||
|
||||
render(
|
||||
<LogsTable
|
||||
{...defaultProps}
|
||||
status="loading"
|
||||
onLoadMore={mockLoadMore}
|
||||
/>,
|
||||
<LogsTable {...defaultProps} status="loading" onLoadMore={mockLoadMore} />
|
||||
);
|
||||
|
||||
// Wait for possible triggers
|
||||
await new Promise((resolve) => setTimeout(resolve, 300));
|
||||
await new Promise(resolve => setTimeout(resolve, 300));
|
||||
|
||||
expect(mockLoadMore).not.toHaveBeenCalled();
|
||||
});
|
||||
|
@ -111,18 +107,18 @@ describe("LogsTable Viewport Loading", () => {
|
|||
const mockLoadMore = jest.fn();
|
||||
|
||||
render(
|
||||
<LogsTable {...defaultProps} hasMore={false} onLoadMore={mockLoadMore} />,
|
||||
<LogsTable {...defaultProps} hasMore={false} onLoadMore={mockLoadMore} />
|
||||
);
|
||||
|
||||
// Wait for possible triggers
|
||||
await new Promise((resolve) => setTimeout(resolve, 300));
|
||||
await new Promise(resolve => setTimeout(resolve, 300));
|
||||
|
||||
expect(mockLoadMore).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("sentinel element should not be rendered when loading", () => {
|
||||
const { container } = render(
|
||||
<LogsTable {...defaultProps} status="loading-more" />,
|
||||
<LogsTable {...defaultProps} status="loading-more" />
|
||||
);
|
||||
|
||||
// Check that no sentinel row with height: 1 exists
|
||||
|
@ -132,7 +128,7 @@ describe("LogsTable Viewport Loading", () => {
|
|||
|
||||
test("sentinel element should be rendered when not loading and hasMore", () => {
|
||||
const { container } = render(
|
||||
<LogsTable {...defaultProps} hasMore={true} status="idle" />,
|
||||
<LogsTable {...defaultProps} hasMore={true} status="idle" />
|
||||
);
|
||||
|
||||
// Check that sentinel row exists
|
||||
|
|
|
@ -70,7 +70,7 @@ describe("LogsTable", () => {
|
|||
describe("Loading State", () => {
|
||||
test("renders skeleton UI when isLoading is true", () => {
|
||||
const { container } = render(
|
||||
<LogsTable {...defaultProps} status="loading" />,
|
||||
<LogsTable {...defaultProps} status="loading" />
|
||||
);
|
||||
|
||||
// Check for skeleton in the table caption
|
||||
|
@ -78,7 +78,7 @@ describe("LogsTable", () => {
|
|||
expect(tableCaption).toBeInTheDocument();
|
||||
if (tableCaption) {
|
||||
const captionSkeleton = tableCaption.querySelector(
|
||||
'[data-slot="skeleton"]',
|
||||
'[data-slot="skeleton"]'
|
||||
);
|
||||
expect(captionSkeleton).toBeInTheDocument();
|
||||
}
|
||||
|
@ -88,7 +88,7 @@ describe("LogsTable", () => {
|
|||
expect(tableBody).toBeInTheDocument();
|
||||
if (tableBody) {
|
||||
const bodySkeletons = tableBody.querySelectorAll(
|
||||
'[data-slot="skeleton"]',
|
||||
'[data-slot="skeleton"]'
|
||||
);
|
||||
expect(bodySkeletons.length).toBeGreaterThan(0);
|
||||
}
|
||||
|
@ -102,7 +102,7 @@ describe("LogsTable", () => {
|
|||
|
||||
test("renders correct number of skeleton rows", () => {
|
||||
const { container } = render(
|
||||
<LogsTable {...defaultProps} status="loading" />,
|
||||
<LogsTable {...defaultProps} status="loading" />
|
||||
);
|
||||
|
||||
const skeletonRows = container.querySelectorAll("tbody tr");
|
||||
|
@ -118,10 +118,10 @@ describe("LogsTable", () => {
|
|||
{...defaultProps}
|
||||
status="error"
|
||||
error={{ name: "Error", message: errorMessage } as Error}
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
expect(
|
||||
screen.getByText("Unable to load chat completions"),
|
||||
screen.getByText("Unable to load chat completions")
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText(errorMessage)).toBeInTheDocument();
|
||||
});
|
||||
|
@ -132,29 +132,25 @@ describe("LogsTable", () => {
|
|||
{...defaultProps}
|
||||
status="error"
|
||||
error={{ name: "Error", message: "" } as Error}
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
expect(
|
||||
screen.getByText("Unable to load chat completions"),
|
||||
screen.getByText("Unable to load chat completions")
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(
|
||||
"An unexpected error occurred while loading the data.",
|
||||
),
|
||||
screen.getByText("An unexpected error occurred while loading the data.")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("renders default error message when error prop is an object without message", () => {
|
||||
render(
|
||||
<LogsTable {...defaultProps} status="error" error={{} as Error} />,
|
||||
<LogsTable {...defaultProps} status="error" error={{} as Error} />
|
||||
);
|
||||
expect(
|
||||
screen.getByText("Unable to load chat completions"),
|
||||
screen.getByText("Unable to load chat completions")
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(
|
||||
"An unexpected error occurred while loading the data.",
|
||||
),
|
||||
screen.getByText("An unexpected error occurred while loading the data.")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
|
@ -164,7 +160,7 @@ describe("LogsTable", () => {
|
|||
{...defaultProps}
|
||||
status="error"
|
||||
error={{ name: "Error", message: "Test error" } as Error}
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
const table = screen.queryByRole("table");
|
||||
expect(table).not.toBeInTheDocument();
|
||||
|
@ -178,7 +174,7 @@ describe("LogsTable", () => {
|
|||
{...defaultProps}
|
||||
data={[]}
|
||||
emptyMessage="Custom empty message"
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
expect(screen.getByText("Custom empty message")).toBeInTheDocument();
|
||||
|
||||
|
@ -214,7 +210,7 @@ describe("LogsTable", () => {
|
|||
{...defaultProps}
|
||||
data={mockData}
|
||||
caption="Custom table caption"
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
|
||||
// Table caption
|
||||
|
@ -311,8 +307,8 @@ describe("LogsTable", () => {
|
|||
// Verify truncated text is displayed
|
||||
const truncatedTexts = screen.getAllByText("This is a ...");
|
||||
expect(truncatedTexts).toHaveLength(2); // one for input, one for output
|
||||
truncatedTexts.forEach((textElement) =>
|
||||
expect(textElement).toBeInTheDocument(),
|
||||
truncatedTexts.forEach(textElement =>
|
||||
expect(textElement).toBeInTheDocument()
|
||||
);
|
||||
});
|
||||
|
||||
|
@ -332,12 +328,12 @@ describe("LogsTable", () => {
|
|||
|
||||
// Model name should not be passed to truncateText
|
||||
expect(truncateText).not.toHaveBeenCalledWith(
|
||||
"very-long-model-name-that-should-not-be-truncated",
|
||||
"very-long-model-name-that-should-not-be-truncated"
|
||||
);
|
||||
|
||||
// Full model name should be displayed
|
||||
expect(
|
||||
screen.getByText("very-long-model-name-that-should-not-be-truncated"),
|
||||
screen.getByText("very-long-model-name-that-should-not-be-truncated")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
|
|
@ -142,7 +142,7 @@ export function LogsTable({
|
|||
<Table>
|
||||
<TableCaption className="sr-only">{caption}</TableCaption>
|
||||
<TableBody>
|
||||
{data.map((row) => (
|
||||
{data.map(row => (
|
||||
<TableRow
|
||||
key={row.id}
|
||||
onClick={() => router.push(row.detailPath)}
|
||||
|
|
|
@ -22,7 +22,7 @@ export function GroupedItemsDisplay({
|
|||
|
||||
return (
|
||||
<>
|
||||
{groupedItems.map((groupedItem) => {
|
||||
{groupedItems.map(groupedItem => {
|
||||
// If this is a function call with an output, render the grouped component
|
||||
if (
|
||||
groupedItem.outputItem &&
|
||||
|
|
|
@ -18,7 +18,7 @@ export interface GroupedItem {
|
|||
* @returns Array of grouped items with their outputs
|
||||
*/
|
||||
export function useFunctionCallGrouping(
|
||||
items: AnyResponseItem[],
|
||||
items: AnyResponseItem[]
|
||||
): GroupedItem[] {
|
||||
return useMemo(() => {
|
||||
const groupedItems: GroupedItem[] = [];
|
||||
|
|
|
@ -52,7 +52,7 @@ export function ItemRenderer({
|
|||
// Fallback to generic item for unknown types
|
||||
return (
|
||||
<GenericItemComponent
|
||||
item={item as any}
|
||||
item={item as Record<string, unknown>}
|
||||
index={index}
|
||||
keyPrefix={keyPrefix}
|
||||
/>
|
||||
|
|
|
@ -20,7 +20,7 @@ export function MessageItemComponent({
|
|||
content = item.content;
|
||||
} else if (Array.isArray(item.content)) {
|
||||
content = item.content
|
||||
.map((c) => {
|
||||
.map(c => {
|
||||
return c.type === "input_text" || c.type === "output_text"
|
||||
? c.text
|
||||
: JSON.stringify(c);
|
||||
|
|
|
@ -18,7 +18,7 @@ describe("ResponseDetailView", () => {
|
|||
describe("Loading State", () => {
|
||||
test("renders loading skeleton when isLoading is true", () => {
|
||||
const { container } = render(
|
||||
<ResponseDetailView {...defaultProps} isLoading={true} />,
|
||||
<ResponseDetailView {...defaultProps} isLoading={true} />
|
||||
);
|
||||
|
||||
// Check for skeleton elements
|
||||
|
@ -36,13 +36,13 @@ describe("ResponseDetailView", () => {
|
|||
<ResponseDetailView
|
||||
{...defaultProps}
|
||||
error={{ name: "Error", message: errorMessage }}
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
|
||||
expect(screen.getByText("Responses Details")).toBeInTheDocument();
|
||||
// The error message is split across elements, so we check for parts
|
||||
expect(
|
||||
screen.getByText(/Error loading details for ID/),
|
||||
screen.getByText(/Error loading details for ID/)
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText(/test_id/)).toBeInTheDocument();
|
||||
expect(screen.getByText(/Network Error/)).toBeInTheDocument();
|
||||
|
@ -53,11 +53,11 @@ describe("ResponseDetailView", () => {
|
|||
<ResponseDetailView
|
||||
{...defaultProps}
|
||||
error={{ name: "Error", message: "" }}
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.getByText(/Error loading details for ID/),
|
||||
screen.getByText(/Error loading details for ID/)
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText(/test_id/)).toBeInTheDocument();
|
||||
});
|
||||
|
@ -124,14 +124,14 @@ describe("ResponseDetailView", () => {
|
|||
// Check properties - use regex to handle text split across elements
|
||||
expect(screen.getByText(/Created/)).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(new Date(1710000000 * 1000).toLocaleString()),
|
||||
screen.getByText(new Date(1710000000 * 1000).toLocaleString())
|
||||
).toBeInTheDocument();
|
||||
|
||||
// Check for the specific ID label (not Previous Response ID)
|
||||
expect(
|
||||
screen.getByText((content, element) => {
|
||||
return element?.tagName === "STRONG" && content === "ID:";
|
||||
}),
|
||||
})
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("resp_123")).toBeInTheDocument();
|
||||
|
||||
|
@ -166,7 +166,7 @@ describe("ResponseDetailView", () => {
|
|||
};
|
||||
|
||||
render(
|
||||
<ResponseDetailView {...defaultProps} response={minimalResponse} />,
|
||||
<ResponseDetailView {...defaultProps} response={minimalResponse} />
|
||||
);
|
||||
|
||||
// Should show required properties
|
||||
|
@ -179,7 +179,7 @@ describe("ResponseDetailView", () => {
|
|||
expect(screen.queryByText("Top P")).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("Parallel Tool Calls")).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByText("Previous Response ID"),
|
||||
screen.queryByText("Previous Response ID")
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
|
@ -196,7 +196,7 @@ describe("ResponseDetailView", () => {
|
|||
|
||||
// The error is shown in the properties sidebar, not as a separate "Error" label
|
||||
expect(
|
||||
screen.getByText("invalid_request: The request was invalid"),
|
||||
screen.getByText("invalid_request: The request was invalid")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
@ -218,7 +218,7 @@ describe("ResponseDetailView", () => {
|
|||
{...defaultProps}
|
||||
response={mockResponse}
|
||||
isLoadingInputItems={true}
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
|
||||
// Check for skeleton loading in input items section
|
||||
|
@ -227,7 +227,7 @@ describe("ResponseDetailView", () => {
|
|||
{...defaultProps}
|
||||
response={mockResponse}
|
||||
isLoadingInputItems={true}
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
|
||||
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
|
||||
|
@ -243,16 +243,16 @@ describe("ResponseDetailView", () => {
|
|||
name: "Error",
|
||||
message: "Failed to load input items",
|
||||
}}
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.getByText(
|
||||
"Error loading input items: Failed to load input items",
|
||||
),
|
||||
"Error loading input items: Failed to load input items"
|
||||
)
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("Falling back to response input data."),
|
||||
screen.getByText("Falling back to response input data.")
|
||||
).toBeInTheDocument();
|
||||
|
||||
// Should still show fallback input data
|
||||
|
@ -276,7 +276,7 @@ describe("ResponseDetailView", () => {
|
|||
{...defaultProps}
|
||||
response={mockResponse}
|
||||
inputItems={mockInputItems}
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
|
||||
// Should show input items data, not response.input
|
||||
|
@ -295,7 +295,7 @@ describe("ResponseDetailView", () => {
|
|||
{...defaultProps}
|
||||
response={mockResponse}
|
||||
inputItems={emptyInputItems}
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
|
||||
// Should show fallback input data
|
||||
|
@ -313,7 +313,7 @@ describe("ResponseDetailView", () => {
|
|||
{...defaultProps}
|
||||
response={responseWithoutInput}
|
||||
inputItems={null}
|
||||
/>,
|
||||
/>
|
||||
);
|
||||
|
||||
expect(screen.getByText("No input data available.")).toBeInTheDocument();
|
||||
|
@ -443,7 +443,7 @@ describe("ResponseDetailView", () => {
|
|||
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||
|
||||
expect(
|
||||
screen.getByText('input_function({"param": "value"})'),
|
||||
screen.getByText('input_function({"param": "value"})')
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("Function Call")).toBeInTheDocument();
|
||||
});
|
||||
|
@ -468,7 +468,7 @@ describe("ResponseDetailView", () => {
|
|||
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||
|
||||
expect(
|
||||
screen.getByText("web_search_call(status: completed)"),
|
||||
screen.getByText("web_search_call(status: completed)")
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("Function Call")).toBeInTheDocument();
|
||||
expect(screen.getByText("(Web Search)")).toBeInTheDocument();
|
||||
|
@ -522,7 +522,7 @@ describe("ResponseDetailView", () => {
|
|||
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||
|
||||
expect(
|
||||
screen.getByText("First output Second output"),
|
||||
screen.getByText("First output Second output")
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("Assistant")).toBeInTheDocument();
|
||||
});
|
||||
|
@ -549,7 +549,7 @@ describe("ResponseDetailView", () => {
|
|||
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||
|
||||
expect(
|
||||
screen.getByText('search_function({"query": "test"})'),
|
||||
screen.getByText('search_function({"query": "test"})')
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("Function Call")).toBeInTheDocument();
|
||||
});
|
||||
|
@ -598,7 +598,7 @@ describe("ResponseDetailView", () => {
|
|||
render(<ResponseDetailView {...defaultProps} response={mockResponse} />);
|
||||
|
||||
expect(
|
||||
screen.getByText("web_search_call(status: completed)"),
|
||||
screen.getByText("web_search_call(status: completed)")
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText(/Function Call/)).toBeInTheDocument();
|
||||
expect(screen.getByText("(Web Search)")).toBeInTheDocument();
|
||||
|
@ -616,7 +616,7 @@ describe("ResponseDetailView", () => {
|
|||
type: "unknown_type",
|
||||
custom_field: "custom_value",
|
||||
data: { nested: "object" },
|
||||
} as any,
|
||||
} as unknown,
|
||||
],
|
||||
input: [],
|
||||
};
|
||||
|
@ -625,7 +625,7 @@ describe("ResponseDetailView", () => {
|
|||
|
||||
// Should show JSON stringified content
|
||||
expect(
|
||||
screen.getByText(/custom_field.*custom_value/),
|
||||
screen.getByText(/custom_field.*custom_value/)
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("(unknown_type)")).toBeInTheDocument();
|
||||
});
|
||||
|
@ -666,7 +666,7 @@ describe("ResponseDetailView", () => {
|
|||
role: "assistant",
|
||||
call_id: "call_123",
|
||||
content: "sunny and warm",
|
||||
} as any, // Using any to bypass the type restriction for this test
|
||||
} as unknown, // Using any to bypass the type restriction for this test
|
||||
],
|
||||
input: [],
|
||||
};
|
||||
|
@ -676,7 +676,7 @@ describe("ResponseDetailView", () => {
|
|||
// Should show the function call and message as separate items (not grouped)
|
||||
expect(screen.getByText("Function Call")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText('get_weather({"city": "Tokyo"})'),
|
||||
screen.getByText('get_weather({"city": "Tokyo"})')
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("Assistant")).toBeInTheDocument();
|
||||
expect(screen.getByText("sunny and warm")).toBeInTheDocument();
|
||||
|
@ -706,7 +706,7 @@ describe("ResponseDetailView", () => {
|
|||
status: "completed",
|
||||
call_id: "call_123",
|
||||
output: "sunny and warm",
|
||||
} as any, // Using any to bypass the type restriction for this test
|
||||
} as unknown,
|
||||
],
|
||||
input: [],
|
||||
};
|
||||
|
@ -717,7 +717,7 @@ describe("ResponseDetailView", () => {
|
|||
expect(screen.getByText("Function Call")).toBeInTheDocument();
|
||||
expect(screen.getByText("Arguments")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText('get_weather({"city": "Tokyo"})'),
|
||||
screen.getByText('get_weather({"city": "Tokyo"})')
|
||||
).toBeInTheDocument();
|
||||
// Use getAllByText since there are multiple "Output" elements (card title and output label)
|
||||
const outputElements = screen.getAllByText("Output");
|
||||
|
|
|
@ -146,7 +146,7 @@ describe("ResponsesTable", () => {
|
|||
expect(tableCaption).toBeInTheDocument();
|
||||
if (tableCaption) {
|
||||
const captionSkeleton = tableCaption.querySelector(
|
||||
'[data-slot="skeleton"]',
|
||||
'[data-slot="skeleton"]'
|
||||
);
|
||||
expect(captionSkeleton).toBeInTheDocument();
|
||||
}
|
||||
|
@ -156,7 +156,7 @@ describe("ResponsesTable", () => {
|
|||
expect(tableBody).toBeInTheDocument();
|
||||
if (tableBody) {
|
||||
const bodySkeletons = tableBody.querySelectorAll(
|
||||
'[data-slot="skeleton"]',
|
||||
'[data-slot="skeleton"]'
|
||||
);
|
||||
expect(bodySkeletons.length).toBeGreaterThan(0);
|
||||
}
|
||||
|
@ -176,14 +176,14 @@ describe("ResponsesTable", () => {
|
|||
|
||||
render(<ResponsesTable {...defaultProps} />);
|
||||
expect(
|
||||
screen.getByText("Unable to load chat completions"),
|
||||
screen.getByText("Unable to load chat completions")
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText(errorMessage)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test.each([{ name: "Error", message: "" }, {}])(
|
||||
"renders default error message when error has no message",
|
||||
(errorObject) => {
|
||||
errorObject => {
|
||||
mockedUsePagination.mockReturnValue({
|
||||
data: [],
|
||||
status: "error",
|
||||
|
@ -194,14 +194,14 @@ describe("ResponsesTable", () => {
|
|||
|
||||
render(<ResponsesTable {...defaultProps} />);
|
||||
expect(
|
||||
screen.getByText("Unable to load chat completions"),
|
||||
screen.getByText("Unable to load chat completions")
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(
|
||||
"An unexpected error occurred while loading the data.",
|
||||
),
|
||||
"An unexpected error occurred while loading the data."
|
||||
)
|
||||
).toBeInTheDocument();
|
||||
},
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
|
@ -275,7 +275,7 @@ describe("ResponsesTable", () => {
|
|||
|
||||
// Table caption
|
||||
expect(
|
||||
screen.getByText("A list of your recent responses."),
|
||||
screen.getByText("A list of your recent responses.")
|
||||
).toBeInTheDocument();
|
||||
|
||||
// Table headers
|
||||
|
@ -289,14 +289,14 @@ describe("ResponsesTable", () => {
|
|||
expect(screen.getByText("Test output")).toBeInTheDocument();
|
||||
expect(screen.getByText("llama-test-model")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(new Date(1710000000 * 1000).toLocaleString()),
|
||||
screen.getByText(new Date(1710000000 * 1000).toLocaleString())
|
||||
).toBeInTheDocument();
|
||||
|
||||
expect(screen.getByText("Another input")).toBeInTheDocument();
|
||||
expect(screen.getByText("Another output")).toBeInTheDocument();
|
||||
expect(screen.getByText("llama-another-model")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(new Date(1710001000 * 1000).toLocaleString()),
|
||||
screen.getByText(new Date(1710001000 * 1000).toLocaleString())
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
@ -487,7 +487,7 @@ describe("ResponsesTable", () => {
|
|||
|
||||
render(<ResponsesTable {...defaultProps} />);
|
||||
expect(
|
||||
screen.getByText('search_function({"query": "test"})'),
|
||||
screen.getByText('search_function({"query": "test"})')
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
|
@ -548,7 +548,7 @@ describe("ResponsesTable", () => {
|
|||
|
||||
render(<ResponsesTable {...defaultProps} />);
|
||||
expect(
|
||||
screen.getByText("web_search_call(status: completed)"),
|
||||
screen.getByText("web_search_call(status: completed)")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
|
@ -565,7 +565,7 @@ describe("ResponsesTable", () => {
|
|||
id: "unknown_123",
|
||||
status: "completed",
|
||||
custom_field: "custom_value",
|
||||
} as any,
|
||||
} as unknown,
|
||||
],
|
||||
input: [{ type: "message", content: "input" }],
|
||||
};
|
||||
|
@ -594,7 +594,7 @@ describe("ResponsesTable", () => {
|
|||
{
|
||||
type: "unknown_type",
|
||||
data: "some data",
|
||||
} as any,
|
||||
} as unknown,
|
||||
],
|
||||
input: [{ type: "message", content: "input" }],
|
||||
};
|
||||
|
@ -623,7 +623,7 @@ describe("ResponsesTable", () => {
|
|||
return typeof text === "string" && text.length > effectiveMaxLength
|
||||
? text.slice(0, effectiveMaxLength) + "..."
|
||||
: text;
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
const longInput =
|
||||
|
@ -665,7 +665,7 @@ describe("ResponsesTable", () => {
|
|||
|
||||
// The truncated text should be present for both input and output
|
||||
const truncatedTexts = screen.getAllByText(
|
||||
longInput.slice(0, 10) + "...",
|
||||
longInput.slice(0, 10) + "..."
|
||||
);
|
||||
expect(truncatedTexts.length).toBe(2); // one for input, one for output
|
||||
});
|
||||
|
|
|
@ -27,7 +27,7 @@ interface ResponsesTableProps {
|
|||
* Helper function to convert ResponseListResponse.Data to OpenAIResponse
|
||||
*/
|
||||
const convertResponseListData = (
|
||||
responseData: ResponseListResponse.Data,
|
||||
responseData: ResponseListResponse.Data
|
||||
): OpenAIResponse => {
|
||||
return {
|
||||
id: responseData.id,
|
||||
|
@ -56,8 +56,8 @@ function getInputText(response: OpenAIResponse): string {
|
|||
}
|
||||
|
||||
function getOutputText(response: OpenAIResponse): string {
|
||||
const firstMessage = response.output.find((item) =>
|
||||
isMessageItem(item as any),
|
||||
const firstMessage = response.output.find(item =>
|
||||
isMessageItem(item as Record<string, unknown>)
|
||||
);
|
||||
if (firstMessage) {
|
||||
const content = extractContentFromItem(firstMessage as MessageItem);
|
||||
|
@ -66,15 +66,15 @@ function getOutputText(response: OpenAIResponse): string {
|
|||
}
|
||||
}
|
||||
|
||||
const functionCall = response.output.find((item) =>
|
||||
isFunctionCallItem(item as any),
|
||||
const functionCall = response.output.find(item =>
|
||||
isFunctionCallItem(item as Record<string, unknown>)
|
||||
);
|
||||
if (functionCall) {
|
||||
return formatFunctionCall(functionCall as FunctionCallItem);
|
||||
}
|
||||
|
||||
const webSearchCall = response.output.find((item) =>
|
||||
isWebSearchCallItem(item as any),
|
||||
const webSearchCall = response.output.find(item =>
|
||||
isWebSearchCallItem(item as Record<string, unknown>)
|
||||
);
|
||||
if (webSearchCall) {
|
||||
return formatWebSearchCall(webSearchCall as WebSearchCallItem);
|
||||
|
@ -95,7 +95,7 @@ function extractContentFromItem(item: {
|
|||
} else if (Array.isArray(item.content)) {
|
||||
const textContent = item.content.find(
|
||||
(c: ResponseInputMessageContent) =>
|
||||
c.type === "input_text" || c.type === "output_text",
|
||||
c.type === "input_text" || c.type === "output_text"
|
||||
);
|
||||
return textContent?.text || "";
|
||||
}
|
||||
|
@ -131,14 +131,14 @@ export function ResponsesTable({ paginationOptions }: ResponsesTableProps) {
|
|||
limit: number;
|
||||
model?: string;
|
||||
order?: string;
|
||||
},
|
||||
}
|
||||
) => {
|
||||
const response = await client.responses.list({
|
||||
after: params.after,
|
||||
limit: params.limit,
|
||||
...(params.model && { model: params.model }),
|
||||
...(params.order && { order: params.order }),
|
||||
} as any);
|
||||
} as Parameters<typeof client.responses.list>[0]);
|
||||
|
||||
const listResponse = response as ResponseListResponse;
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ export type AnyResponseItem =
|
|||
| FunctionCallOutputItem;
|
||||
|
||||
export function isMessageInput(
|
||||
item: ResponseInput,
|
||||
item: ResponseInput
|
||||
): item is ResponseInput & { type: "message" } {
|
||||
return item.type === "message";
|
||||
}
|
||||
|
@ -39,23 +39,23 @@ export function isMessageItem(item: AnyResponseItem): item is MessageItem {
|
|||
}
|
||||
|
||||
export function isFunctionCallItem(
|
||||
item: AnyResponseItem,
|
||||
item: AnyResponseItem
|
||||
): item is FunctionCallItem {
|
||||
return item.type === "function_call" && "name" in item;
|
||||
}
|
||||
|
||||
export function isWebSearchCallItem(
|
||||
item: AnyResponseItem,
|
||||
item: AnyResponseItem
|
||||
): item is WebSearchCallItem {
|
||||
return item.type === "web_search_call";
|
||||
}
|
||||
|
||||
export function isFunctionCallOutputItem(
|
||||
item: AnyResponseItem,
|
||||
item: AnyResponseItem
|
||||
): item is FunctionCallOutputItem {
|
||||
return (
|
||||
item.type === "function_call_output" &&
|
||||
"call_id" in item &&
|
||||
typeof (item as any).call_id === "string"
|
||||
typeof (item as Record<string, unknown>).call_id === "string"
|
||||
);
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
"use client"
|
||||
"use client";
|
||||
|
||||
import { useEffect, useRef } from "react"
|
||||
import { useEffect, useRef } from "react";
|
||||
|
||||
// Configuration constants for the audio analyzer
|
||||
const AUDIO_CONFIG = {
|
||||
|
@ -14,12 +14,12 @@ const AUDIO_CONFIG = {
|
|||
MAX_INTENSITY: 255, // Maximum gray value (brighter)
|
||||
INTENSITY_RANGE: 155, // MAX_INTENSITY - MIN_INTENSITY
|
||||
},
|
||||
} as const
|
||||
} as const;
|
||||
|
||||
interface AudioVisualizerProps {
|
||||
stream: MediaStream | null
|
||||
isRecording: boolean
|
||||
onClick: () => void
|
||||
stream: MediaStream | null;
|
||||
isRecording: boolean;
|
||||
onClick: () => void;
|
||||
}
|
||||
|
||||
export function AudioVisualizer({
|
||||
|
@ -28,91 +28,91 @@ export function AudioVisualizer({
|
|||
onClick,
|
||||
}: AudioVisualizerProps) {
|
||||
// Refs for managing audio context and animation
|
||||
const canvasRef = useRef<HTMLCanvasElement>(null)
|
||||
const audioContextRef = useRef<AudioContext | null>(null)
|
||||
const analyserRef = useRef<AnalyserNode | null>(null)
|
||||
const animationFrameRef = useRef<number>()
|
||||
const containerRef = useRef<HTMLDivElement>(null)
|
||||
const canvasRef = useRef<HTMLCanvasElement>(null);
|
||||
const audioContextRef = useRef<AudioContext | null>(null);
|
||||
const analyserRef = useRef<AnalyserNode | null>(null);
|
||||
const animationFrameRef = useRef<number>();
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Cleanup function to stop visualization and close audio context
|
||||
const cleanup = () => {
|
||||
if (animationFrameRef.current) {
|
||||
cancelAnimationFrame(animationFrameRef.current)
|
||||
cancelAnimationFrame(animationFrameRef.current);
|
||||
}
|
||||
if (audioContextRef.current) {
|
||||
audioContextRef.current.close()
|
||||
audioContextRef.current.close();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Cleanup on unmount
|
||||
useEffect(() => {
|
||||
return cleanup
|
||||
}, [])
|
||||
return cleanup;
|
||||
}, []);
|
||||
|
||||
// Start or stop visualization based on recording state
|
||||
useEffect(() => {
|
||||
if (stream && isRecording) {
|
||||
startVisualization()
|
||||
startVisualization();
|
||||
} else {
|
||||
cleanup()
|
||||
cleanup();
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [stream, isRecording])
|
||||
}, [stream, isRecording]);
|
||||
|
||||
// Handle window resize
|
||||
useEffect(() => {
|
||||
const handleResize = () => {
|
||||
if (canvasRef.current && containerRef.current) {
|
||||
const container = containerRef.current
|
||||
const canvas = canvasRef.current
|
||||
const dpr = window.devicePixelRatio || 1
|
||||
const container = containerRef.current;
|
||||
const canvas = canvasRef.current;
|
||||
const dpr = window.devicePixelRatio || 1;
|
||||
|
||||
// Set canvas size based on container and device pixel ratio
|
||||
const rect = container.getBoundingClientRect()
|
||||
const rect = container.getBoundingClientRect();
|
||||
// Account for the 2px total margin (1px on each side)
|
||||
canvas.width = (rect.width - 2) * dpr
|
||||
canvas.height = (rect.height - 2) * dpr
|
||||
canvas.width = (rect.width - 2) * dpr;
|
||||
canvas.height = (rect.height - 2) * dpr;
|
||||
|
||||
// Scale canvas CSS size to match container minus margins
|
||||
canvas.style.width = `${rect.width - 2}px`
|
||||
canvas.style.height = `${rect.height - 2}px`
|
||||
canvas.style.width = `${rect.width - 2}px`;
|
||||
canvas.style.height = `${rect.height - 2}px`;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener("resize", handleResize)
|
||||
window.addEventListener("resize", handleResize);
|
||||
// Initial setup
|
||||
handleResize()
|
||||
handleResize();
|
||||
|
||||
return () => window.removeEventListener("resize", handleResize)
|
||||
}, [])
|
||||
return () => window.removeEventListener("resize", handleResize);
|
||||
}, []);
|
||||
|
||||
// Initialize audio context and start visualization
|
||||
const startVisualization = async () => {
|
||||
try {
|
||||
const audioContext = new AudioContext()
|
||||
audioContextRef.current = audioContext
|
||||
const audioContext = new AudioContext();
|
||||
audioContextRef.current = audioContext;
|
||||
|
||||
const analyser = audioContext.createAnalyser()
|
||||
analyser.fftSize = AUDIO_CONFIG.FFT_SIZE
|
||||
analyser.smoothingTimeConstant = AUDIO_CONFIG.SMOOTHING
|
||||
analyserRef.current = analyser
|
||||
const analyser = audioContext.createAnalyser();
|
||||
analyser.fftSize = AUDIO_CONFIG.FFT_SIZE;
|
||||
analyser.smoothingTimeConstant = AUDIO_CONFIG.SMOOTHING;
|
||||
analyserRef.current = analyser;
|
||||
|
||||
const source = audioContext.createMediaStreamSource(stream!)
|
||||
source.connect(analyser)
|
||||
const source = audioContext.createMediaStreamSource(stream!);
|
||||
source.connect(analyser);
|
||||
|
||||
draw()
|
||||
draw();
|
||||
} catch (error) {
|
||||
console.error("Error starting visualization:", error)
|
||||
console.error("Error starting visualization:", error);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Calculate the color intensity based on bar height
|
||||
const getBarColor = (normalizedHeight: number) => {
|
||||
const intensity =
|
||||
Math.floor(normalizedHeight * AUDIO_CONFIG.COLOR.INTENSITY_RANGE) +
|
||||
AUDIO_CONFIG.COLOR.MIN_INTENSITY
|
||||
return `rgb(${intensity}, ${intensity}, ${intensity})`
|
||||
}
|
||||
AUDIO_CONFIG.COLOR.MIN_INTENSITY;
|
||||
return `rgb(${intensity}, ${intensity}, ${intensity})`;
|
||||
};
|
||||
|
||||
// Draw a single bar of the visualizer
|
||||
const drawBar = (
|
||||
|
@ -123,52 +123,52 @@ export function AudioVisualizer({
|
|||
height: number,
|
||||
color: string
|
||||
) => {
|
||||
ctx.fillStyle = color
|
||||
ctx.fillStyle = color;
|
||||
// Draw upper bar (above center)
|
||||
ctx.fillRect(x, centerY - height, width, height)
|
||||
ctx.fillRect(x, centerY - height, width, height);
|
||||
// Draw lower bar (below center)
|
||||
ctx.fillRect(x, centerY, width, height)
|
||||
}
|
||||
ctx.fillRect(x, centerY, width, height);
|
||||
};
|
||||
|
||||
// Main drawing function
|
||||
const draw = () => {
|
||||
if (!isRecording) return
|
||||
if (!isRecording) return;
|
||||
|
||||
const canvas = canvasRef.current
|
||||
const ctx = canvas?.getContext("2d")
|
||||
if (!canvas || !ctx || !analyserRef.current) return
|
||||
const canvas = canvasRef.current;
|
||||
const ctx = canvas?.getContext("2d");
|
||||
if (!canvas || !ctx || !analyserRef.current) return;
|
||||
|
||||
const dpr = window.devicePixelRatio || 1
|
||||
ctx.scale(dpr, dpr)
|
||||
const dpr = window.devicePixelRatio || 1;
|
||||
ctx.scale(dpr, dpr);
|
||||
|
||||
const analyser = analyserRef.current
|
||||
const bufferLength = analyser.frequencyBinCount
|
||||
const frequencyData = new Uint8Array(bufferLength)
|
||||
const analyser = analyserRef.current;
|
||||
const bufferLength = analyser.frequencyBinCount;
|
||||
const frequencyData = new Uint8Array(bufferLength);
|
||||
|
||||
const drawFrame = () => {
|
||||
animationFrameRef.current = requestAnimationFrame(drawFrame)
|
||||
animationFrameRef.current = requestAnimationFrame(drawFrame);
|
||||
|
||||
// Get current frequency data
|
||||
analyser.getByteFrequencyData(frequencyData)
|
||||
analyser.getByteFrequencyData(frequencyData);
|
||||
|
||||
// Clear canvas - use CSS pixels for clearing
|
||||
ctx.clearRect(0, 0, canvas.width / dpr, canvas.height / dpr)
|
||||
ctx.clearRect(0, 0, canvas.width / dpr, canvas.height / dpr);
|
||||
|
||||
// Calculate dimensions in CSS pixels
|
||||
const barWidth = Math.max(
|
||||
AUDIO_CONFIG.MIN_BAR_WIDTH,
|
||||
canvas.width / dpr / bufferLength - AUDIO_CONFIG.BAR_SPACING
|
||||
)
|
||||
const centerY = canvas.height / dpr / 2
|
||||
let x = 0
|
||||
);
|
||||
const centerY = canvas.height / dpr / 2;
|
||||
let x = 0;
|
||||
|
||||
// Draw each frequency bar
|
||||
for (let i = 0; i < bufferLength; i++) {
|
||||
const normalizedHeight = frequencyData[i] / 255 // Convert to 0-1 range
|
||||
const normalizedHeight = frequencyData[i] / 255; // Convert to 0-1 range
|
||||
const barHeight = Math.max(
|
||||
AUDIO_CONFIG.MIN_BAR_HEIGHT,
|
||||
normalizedHeight * centerY
|
||||
)
|
||||
);
|
||||
|
||||
drawBar(
|
||||
ctx,
|
||||
|
@ -177,14 +177,14 @@ export function AudioVisualizer({
|
|||
barWidth,
|
||||
barHeight,
|
||||
getBarColor(normalizedHeight)
|
||||
)
|
||||
);
|
||||
|
||||
x += barWidth + AUDIO_CONFIG.BAR_SPACING
|
||||
x += barWidth + AUDIO_CONFIG.BAR_SPACING;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
drawFrame()
|
||||
}
|
||||
drawFrame();
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
|
@ -194,5 +194,5 @@ export function AudioVisualizer({
|
|||
>
|
||||
<canvas ref={canvasRef} className="h-full w-full" />
|
||||
</div>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ function BreadcrumbList({ className, ...props }: React.ComponentProps<"ol">) {
|
|||
data-slot="breadcrumb-list"
|
||||
className={cn(
|
||||
"text-muted-foreground flex flex-wrap items-center gap-1.5 text-sm break-words sm:gap-2.5",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import * as React from "react"
|
||||
import { Slot } from "@radix-ui/react-slot"
|
||||
import { cva, type VariantProps } from "class-variance-authority"
|
||||
import * as React from "react";
|
||||
import { Slot } from "@radix-ui/react-slot";
|
||||
import { cva, type VariantProps } from "class-variance-authority";
|
||||
|
||||
import { cn } from "@/lib/utils"
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const buttonVariants = cva(
|
||||
"inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium transition-all disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg:not([class*='size-'])]:size-4 shrink-0 [&_svg]:shrink-0 outline-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive",
|
||||
|
@ -33,7 +33,7 @@ const buttonVariants = cva(
|
|||
size: "default",
|
||||
},
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
function Button({
|
||||
className,
|
||||
|
@ -43,9 +43,9 @@ function Button({
|
|||
...props
|
||||
}: React.ComponentProps<"button"> &
|
||||
VariantProps<typeof buttonVariants> & {
|
||||
asChild?: boolean
|
||||
asChild?: boolean;
|
||||
}) {
|
||||
const Comp = asChild ? Slot : "button"
|
||||
const Comp = asChild ? Slot : "button";
|
||||
|
||||
return (
|
||||
<Comp
|
||||
|
@ -53,7 +53,7 @@ function Button({
|
|||
className={cn(buttonVariants({ variant, size, className }))}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
export { Button, buttonVariants }
|
||||
export { Button, buttonVariants };
|
||||
|
|
|
@ -8,7 +8,7 @@ function Card({ className, ...props }: React.ComponentProps<"div">) {
|
|||
data-slot="card"
|
||||
className={cn(
|
||||
"bg-card text-card-foreground flex flex-col gap-6 rounded-xl border py-6 shadow-sm",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -21,7 +21,7 @@ function CardHeader({ className, ...props }: React.ComponentProps<"div">) {
|
|||
data-slot="card-header"
|
||||
className={cn(
|
||||
"@container/card-header grid auto-rows-min grid-rows-[auto_auto] items-start gap-1.5 px-6 has-data-[slot=card-action]:grid-cols-[1fr_auto] [.border-b]:pb-6",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -54,7 +54,7 @@ function CardAction({ className, ...props }: React.ComponentProps<"div">) {
|
|||
data-slot="card-action"
|
||||
className={cn(
|
||||
"col-start-2 row-span-2 row-start-1 self-start justify-self-end",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
"use client"
|
||||
"use client";
|
||||
|
||||
import * as CollapsiblePrimitive from "@radix-ui/react-collapsible"
|
||||
import * as CollapsiblePrimitive from "@radix-ui/react-collapsible";
|
||||
|
||||
function Collapsible({
|
||||
...props
|
||||
}: React.ComponentProps<typeof CollapsiblePrimitive.Root>) {
|
||||
return <CollapsiblePrimitive.Root data-slot="collapsible" {...props} />
|
||||
return <CollapsiblePrimitive.Root data-slot="collapsible" {...props} />;
|
||||
}
|
||||
|
||||
function CollapsibleTrigger({
|
||||
|
@ -16,7 +16,7 @@ function CollapsibleTrigger({
|
|||
data-slot="collapsible-trigger"
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
function CollapsibleContent({
|
||||
|
@ -27,7 +27,7 @@ function CollapsibleContent({
|
|||
data-slot="collapsible-content"
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
export { Collapsible, CollapsibleTrigger, CollapsibleContent }
|
||||
export { Collapsible, CollapsibleTrigger, CollapsibleContent };
|
||||
|
|
|
@ -1,21 +1,21 @@
|
|||
"use client"
|
||||
"use client";
|
||||
|
||||
import { Check, Copy } from "lucide-react"
|
||||
import { Check, Copy } from "lucide-react";
|
||||
|
||||
import { cn } from "@/lib/utils"
|
||||
import { useCopyToClipboard } from "@/hooks/use-copy-to-clipboard"
|
||||
import { Button } from "@/components/ui/button"
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useCopyToClipboard } from "@/hooks/use-copy-to-clipboard";
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
||||
type CopyButtonProps = {
|
||||
content: string
|
||||
copyMessage?: string
|
||||
}
|
||||
content: string;
|
||||
copyMessage?: string;
|
||||
};
|
||||
|
||||
export function CopyButton({ content, copyMessage }: CopyButtonProps) {
|
||||
const { isCopied, handleCopy } = useCopyToClipboard({
|
||||
text: content,
|
||||
copyMessage,
|
||||
})
|
||||
});
|
||||
|
||||
return (
|
||||
<Button
|
||||
|
@ -40,5 +40,5 @@ export function CopyButton({ content, copyMessage }: CopyButtonProps) {
|
|||
)}
|
||||
/>
|
||||
</Button>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -43,7 +43,7 @@ function DropdownMenuContent({
|
|||
sideOffset={sideOffset}
|
||||
className={cn(
|
||||
"bg-popover text-popover-foreground data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 z-50 max-h-(--radix-dropdown-menu-content-available-height) min-w-[8rem] origin-(--radix-dropdown-menu-content-transform-origin) overflow-x-hidden overflow-y-auto rounded-md border p-1 shadow-md",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -75,7 +75,7 @@ function DropdownMenuItem({
|
|||
data-variant={variant}
|
||||
className={cn(
|
||||
"focus:bg-accent focus:text-accent-foreground data-[variant=destructive]:text-destructive data-[variant=destructive]:focus:bg-destructive/10 dark:data-[variant=destructive]:focus:bg-destructive/20 data-[variant=destructive]:focus:text-destructive data-[variant=destructive]:*:[svg]:!text-destructive [&_svg:not([class*='text-'])]:text-muted-foreground relative flex cursor-default items-center gap-2 rounded-sm px-2 py-1.5 text-sm outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 data-[inset]:pl-8 [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -93,7 +93,7 @@ function DropdownMenuCheckboxItem({
|
|||
data-slot="dropdown-menu-checkbox-item"
|
||||
className={cn(
|
||||
"focus:bg-accent focus:text-accent-foreground relative flex cursor-default items-center gap-2 rounded-sm py-1.5 pr-2 pl-8 text-sm outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
checked={checked}
|
||||
{...props}
|
||||
|
@ -129,7 +129,7 @@ function DropdownMenuRadioItem({
|
|||
data-slot="dropdown-menu-radio-item"
|
||||
className={cn(
|
||||
"focus:bg-accent focus:text-accent-foreground relative flex cursor-default items-center gap-2 rounded-sm py-1.5 pr-2 pl-8 text-sm outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
|
@ -156,7 +156,7 @@ function DropdownMenuLabel({
|
|||
data-inset={inset}
|
||||
className={cn(
|
||||
"px-2 py-1.5 text-sm font-medium data-[inset]:pl-8",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -185,7 +185,7 @@ function DropdownMenuShortcut({
|
|||
data-slot="dropdown-menu-shortcut"
|
||||
className={cn(
|
||||
"text-muted-foreground ml-auto text-xs tracking-widest",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -212,7 +212,7 @@ function DropdownMenuSubTrigger({
|
|||
data-inset={inset}
|
||||
className={cn(
|
||||
"focus:bg-accent focus:text-accent-foreground data-[state=open]:bg-accent data-[state=open]:text-accent-foreground flex cursor-default items-center rounded-sm px-2 py-1.5 text-sm outline-hidden select-none data-[inset]:pl-8",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
|
@ -231,7 +231,7 @@ function DropdownMenuSubContent({
|
|||
data-slot="dropdown-menu-sub-content"
|
||||
className={cn(
|
||||
"bg-popover text-popover-foreground data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 z-50 min-w-[8rem] origin-(--radix-dropdown-menu-content-transform-origin) overflow-hidden rounded-md border p-1 shadow-lg",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
"use client"
|
||||
"use client";
|
||||
|
||||
import React, { useEffect } from "react"
|
||||
import { motion } from "framer-motion"
|
||||
import { FileIcon, X } from "lucide-react"
|
||||
import React, { useEffect } from "react";
|
||||
import { motion } from "framer-motion";
|
||||
import { FileIcon, X } from "lucide-react";
|
||||
|
||||
interface FilePreviewProps {
|
||||
file: File
|
||||
onRemove?: () => void
|
||||
file: File;
|
||||
onRemove?: () => void;
|
||||
}
|
||||
|
||||
export const FilePreview = React.forwardRef<HTMLDivElement, FilePreviewProps>(
|
||||
(props, ref) => {
|
||||
if (props.file.type.startsWith("image/")) {
|
||||
return <ImageFilePreview {...props} ref={ref} />
|
||||
return <ImageFilePreview {...props} ref={ref} />;
|
||||
}
|
||||
|
||||
if (
|
||||
|
@ -20,13 +20,13 @@ export const FilePreview = React.forwardRef<HTMLDivElement, FilePreviewProps>(
|
|||
props.file.name.endsWith(".txt") ||
|
||||
props.file.name.endsWith(".md")
|
||||
) {
|
||||
return <TextFilePreview {...props} ref={ref} />
|
||||
return <TextFilePreview {...props} ref={ref} />;
|
||||
}
|
||||
|
||||
return <GenericFilePreview {...props} ref={ref} />
|
||||
return <GenericFilePreview {...props} ref={ref} />;
|
||||
}
|
||||
)
|
||||
FilePreview.displayName = "FilePreview"
|
||||
);
|
||||
FilePreview.displayName = "FilePreview";
|
||||
|
||||
const ImageFilePreview = React.forwardRef<HTMLDivElement, FilePreviewProps>(
|
||||
({ file, onRemove }, ref) => {
|
||||
|
@ -62,23 +62,23 @@ const ImageFilePreview = React.forwardRef<HTMLDivElement, FilePreviewProps>(
|
|||
</button>
|
||||
) : null}
|
||||
</motion.div>
|
||||
)
|
||||
);
|
||||
}
|
||||
)
|
||||
ImageFilePreview.displayName = "ImageFilePreview"
|
||||
);
|
||||
ImageFilePreview.displayName = "ImageFilePreview";
|
||||
|
||||
const TextFilePreview = React.forwardRef<HTMLDivElement, FilePreviewProps>(
|
||||
({ file, onRemove }, ref) => {
|
||||
const [preview, setPreview] = React.useState<string>("")
|
||||
const [preview, setPreview] = React.useState<string>("");
|
||||
|
||||
useEffect(() => {
|
||||
const reader = new FileReader()
|
||||
reader.onload = (e) => {
|
||||
const text = e.target?.result as string
|
||||
setPreview(text.slice(0, 50) + (text.length > 50 ? "..." : ""))
|
||||
}
|
||||
reader.readAsText(file)
|
||||
}, [file])
|
||||
const reader = new FileReader();
|
||||
reader.onload = e => {
|
||||
const text = e.target?.result as string;
|
||||
setPreview(text.slice(0, 50) + (text.length > 50 ? "..." : ""));
|
||||
};
|
||||
reader.readAsText(file);
|
||||
}, [file]);
|
||||
|
||||
return (
|
||||
<motion.div
|
||||
|
@ -111,10 +111,10 @@ const TextFilePreview = React.forwardRef<HTMLDivElement, FilePreviewProps>(
|
|||
</button>
|
||||
) : null}
|
||||
</motion.div>
|
||||
)
|
||||
);
|
||||
}
|
||||
)
|
||||
TextFilePreview.displayName = "TextFilePreview"
|
||||
);
|
||||
TextFilePreview.displayName = "TextFilePreview";
|
||||
|
||||
const GenericFilePreview = React.forwardRef<HTMLDivElement, FilePreviewProps>(
|
||||
({ file, onRemove }, ref) => {
|
||||
|
@ -147,7 +147,7 @@ const GenericFilePreview = React.forwardRef<HTMLDivElement, FilePreviewProps>(
|
|||
</button>
|
||||
) : null}
|
||||
</motion.div>
|
||||
)
|
||||
);
|
||||
}
|
||||
)
|
||||
GenericFilePreview.displayName = "GenericFilePreview"
|
||||
);
|
||||
GenericFilePreview.displayName = "GenericFilePreview";
|
||||
|
|
|
@ -11,7 +11,7 @@ function Input({ className, type, ...props }: React.ComponentProps<"input">) {
|
|||
"file:text-foreground placeholder:text-muted-foreground selection:bg-primary selection:text-primary-foreground dark:bg-input/30 border-input flex h-9 w-full min-w-0 rounded-md border bg-transparent px-3 py-1 text-base shadow-xs transition-[color,box-shadow] outline-none file:inline-flex file:h-7 file:border-0 file:bg-transparent file:text-sm file:font-medium disabled:pointer-events-none disabled:cursor-not-allowed disabled:opacity-50 md:text-sm",
|
||||
"focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px]",
|
||||
"aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
|
|
@ -1,27 +1,27 @@
|
|||
"use client"
|
||||
"use client";
|
||||
|
||||
import * as React from "react"
|
||||
import * as SelectPrimitive from "@radix-ui/react-select"
|
||||
import { CheckIcon, ChevronDownIcon, ChevronUpIcon } from "lucide-react"
|
||||
import * as React from "react";
|
||||
import * as SelectPrimitive from "@radix-ui/react-select";
|
||||
import { CheckIcon, ChevronDownIcon, ChevronUpIcon } from "lucide-react";
|
||||
|
||||
import { cn } from "@/lib/utils"
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
function Select({
|
||||
...props
|
||||
}: React.ComponentProps<typeof SelectPrimitive.Root>) {
|
||||
return <SelectPrimitive.Root data-slot="select" {...props} />
|
||||
return <SelectPrimitive.Root data-slot="select" {...props} />;
|
||||
}
|
||||
|
||||
function SelectGroup({
|
||||
...props
|
||||
}: React.ComponentProps<typeof SelectPrimitive.Group>) {
|
||||
return <SelectPrimitive.Group data-slot="select-group" {...props} />
|
||||
return <SelectPrimitive.Group data-slot="select-group" {...props} />;
|
||||
}
|
||||
|
||||
function SelectValue({
|
||||
...props
|
||||
}: React.ComponentProps<typeof SelectPrimitive.Value>) {
|
||||
return <SelectPrimitive.Value data-slot="select-value" {...props} />
|
||||
return <SelectPrimitive.Value data-slot="select-value" {...props} />;
|
||||
}
|
||||
|
||||
function SelectTrigger({
|
||||
|
@ -30,7 +30,7 @@ function SelectTrigger({
|
|||
children,
|
||||
...props
|
||||
}: React.ComponentProps<typeof SelectPrimitive.Trigger> & {
|
||||
size?: "sm" | "default"
|
||||
size?: "sm" | "default";
|
||||
}) {
|
||||
return (
|
||||
<SelectPrimitive.Trigger
|
||||
|
@ -47,7 +47,7 @@ function SelectTrigger({
|
|||
<ChevronDownIcon className="size-4 opacity-50" />
|
||||
</SelectPrimitive.Icon>
|
||||
</SelectPrimitive.Trigger>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
function SelectContent({
|
||||
|
@ -82,7 +82,7 @@ function SelectContent({
|
|||
<SelectScrollDownButton />
|
||||
</SelectPrimitive.Content>
|
||||
</SelectPrimitive.Portal>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
function SelectLabel({
|
||||
|
@ -95,7 +95,7 @@ function SelectLabel({
|
|||
className={cn("text-muted-foreground px-2 py-1.5 text-xs", className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
function SelectItem({
|
||||
|
@ -119,7 +119,7 @@ function SelectItem({
|
|||
</span>
|
||||
<SelectPrimitive.ItemText>{children}</SelectPrimitive.ItemText>
|
||||
</SelectPrimitive.Item>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
function SelectSeparator({
|
||||
|
@ -132,7 +132,7 @@ function SelectSeparator({
|
|||
className={cn("bg-border pointer-events-none -mx-1 my-1 h-px", className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
function SelectScrollUpButton({
|
||||
|
@ -150,7 +150,7 @@ function SelectScrollUpButton({
|
|||
>
|
||||
<ChevronUpIcon className="size-4" />
|
||||
</SelectPrimitive.ScrollUpButton>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
function SelectScrollDownButton({
|
||||
|
@ -168,7 +168,7 @@ function SelectScrollDownButton({
|
|||
>
|
||||
<ChevronDownIcon className="size-4" />
|
||||
</SelectPrimitive.ScrollDownButton>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
export {
|
||||
|
@ -182,4 +182,4 @@ export {
|
|||
SelectSeparator,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
}
|
||||
};
|
||||
|
|
|
@ -18,7 +18,7 @@ function Separator({
|
|||
orientation={orientation}
|
||||
className={cn(
|
||||
"bg-border shrink-0 data-[orientation=horizontal]:h-px data-[orientation=horizontal]:w-full data-[orientation=vertical]:h-full data-[orientation=vertical]:w-px",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
|
|
@ -37,7 +37,7 @@ function SheetOverlay({
|
|||
data-slot="sheet-overlay"
|
||||
className={cn(
|
||||
"data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 fixed inset-0 z-50 bg-black/50",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -67,7 +67,7 @@ function SheetContent({
|
|||
"data-[state=closed]:slide-out-to-top data-[state=open]:slide-in-from-top inset-x-0 top-0 h-auto border-b",
|
||||
side === "bottom" &&
|
||||
"data-[state=closed]:slide-out-to-bottom data-[state=open]:slide-in-from-bottom inset-x-0 bottom-0 h-auto border-t",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
|
|
|
@ -85,12 +85,12 @@ function SidebarProvider({
|
|||
// This sets the cookie to keep the sidebar state.
|
||||
document.cookie = `${SIDEBAR_COOKIE_NAME}=${openState}; path=/; max-age=${SIDEBAR_COOKIE_MAX_AGE}`;
|
||||
},
|
||||
[setOpenProp, open],
|
||||
[setOpenProp, open]
|
||||
);
|
||||
|
||||
// Helper to toggle the sidebar.
|
||||
const toggleSidebar = React.useCallback(() => {
|
||||
return isMobile ? setOpenMobile((open) => !open) : setOpen((open) => !open);
|
||||
return isMobile ? setOpenMobile(open => !open) : setOpen(open => !open);
|
||||
}, [isMobile, setOpen, setOpenMobile]);
|
||||
|
||||
// Adds a keyboard shortcut to toggle the sidebar.
|
||||
|
@ -123,7 +123,7 @@ function SidebarProvider({
|
|||
setOpenMobile,
|
||||
toggleSidebar,
|
||||
}),
|
||||
[state, open, setOpen, isMobile, openMobile, setOpenMobile, toggleSidebar],
|
||||
[state, open, setOpen, isMobile, openMobile, setOpenMobile, toggleSidebar]
|
||||
);
|
||||
|
||||
return (
|
||||
|
@ -140,7 +140,7 @@ function SidebarProvider({
|
|||
}
|
||||
className={cn(
|
||||
"group/sidebar-wrapper has-data-[variant=inset]:bg-sidebar flex min-h-svh w-full",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
|
@ -171,7 +171,7 @@ function Sidebar({
|
|||
data-slot="sidebar"
|
||||
className={cn(
|
||||
"bg-sidebar text-sidebar-foreground flex h-full w-(--sidebar-width) flex-col",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
|
@ -223,7 +223,7 @@ function Sidebar({
|
|||
"group-data-[side=right]:rotate-180",
|
||||
variant === "floating" || variant === "inset"
|
||||
? "group-data-[collapsible=icon]:w-[calc(var(--sidebar-width-icon)+(--spacing(4)))]"
|
||||
: "group-data-[collapsible=icon]:w-(--sidebar-width-icon)",
|
||||
: "group-data-[collapsible=icon]:w-(--sidebar-width-icon)"
|
||||
)}
|
||||
/>
|
||||
<div
|
||||
|
@ -237,7 +237,7 @@ function Sidebar({
|
|||
variant === "floating" || variant === "inset"
|
||||
? "p-2 group-data-[collapsible=icon]:w-[calc(var(--sidebar-width-icon)+(--spacing(4))+2px)]"
|
||||
: "group-data-[collapsible=icon]:w-(--sidebar-width-icon) group-data-[side=left]:border-r group-data-[side=right]:border-l",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
|
@ -267,7 +267,7 @@ function SidebarTrigger({
|
|||
variant="ghost"
|
||||
size="icon"
|
||||
className={cn("size-7", className)}
|
||||
onClick={(event) => {
|
||||
onClick={event => {
|
||||
onClick?.(event);
|
||||
toggleSidebar();
|
||||
}}
|
||||
|
@ -297,7 +297,7 @@ function SidebarRail({ className, ...props }: React.ComponentProps<"button">) {
|
|||
"hover:group-data-[collapsible=offcanvas]:bg-sidebar group-data-[collapsible=offcanvas]:translate-x-0 group-data-[collapsible=offcanvas]:after:left-full",
|
||||
"[[data-side=left][data-collapsible=offcanvas]_&]:-right-2",
|
||||
"[[data-side=right][data-collapsible=offcanvas]_&]:-left-2",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -311,7 +311,7 @@ function SidebarInset({ className, ...props }: React.ComponentProps<"main">) {
|
|||
className={cn(
|
||||
"bg-background relative flex w-full flex-1 flex-col",
|
||||
"md:peer-data-[variant=inset]:m-2 md:peer-data-[variant=inset]:ml-0 md:peer-data-[variant=inset]:rounded-xl md:peer-data-[variant=inset]:shadow-sm md:peer-data-[variant=inset]:peer-data-[state=collapsed]:ml-2",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -375,7 +375,7 @@ function SidebarContent({ className, ...props }: React.ComponentProps<"div">) {
|
|||
data-sidebar="content"
|
||||
className={cn(
|
||||
"flex min-h-0 flex-1 flex-col gap-2 overflow-auto group-data-[collapsible=icon]:overflow-hidden",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -407,7 +407,7 @@ function SidebarGroupLabel({
|
|||
className={cn(
|
||||
"text-sidebar-foreground/70 ring-sidebar-ring flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium outline-hidden transition-[margin,opacity] duration-200 ease-linear focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0",
|
||||
"group-data-[collapsible=icon]:-mt-8 group-data-[collapsible=icon]:opacity-0",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -430,7 +430,7 @@ function SidebarGroupAction({
|
|||
// Increases the hit area of the button on mobile.
|
||||
"after:absolute after:-inset-2 md:after:hidden",
|
||||
"group-data-[collapsible=icon]:hidden",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -492,7 +492,7 @@ const sidebarMenuButtonVariants = cva(
|
|||
variant: "default",
|
||||
size: "default",
|
||||
},
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
function SidebarMenuButton({
|
||||
|
@ -570,7 +570,7 @@ function SidebarMenuAction({
|
|||
"group-data-[collapsible=icon]:hidden",
|
||||
showOnHover &&
|
||||
"peer-data-[active=true]/menu-button:text-sidebar-accent-foreground group-focus-within/menu-item:opacity-100 group-hover/menu-item:opacity-100 data-[state=open]:opacity-100 md:opacity-0",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -592,7 +592,7 @@ function SidebarMenuBadge({
|
|||
"peer-data-[size=default]/menu-button:top-1.5",
|
||||
"peer-data-[size=lg]/menu-button:top-2.5",
|
||||
"group-data-[collapsible=icon]:hidden",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -645,7 +645,7 @@ function SidebarMenuSub({ className, ...props }: React.ComponentProps<"ul">) {
|
|||
className={cn(
|
||||
"border-sidebar-border mx-3.5 flex min-w-0 translate-x-px flex-col gap-1 border-l px-2.5 py-0.5",
|
||||
"group-data-[collapsible=icon]:hidden",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -691,7 +691,7 @@ function SidebarMenuSubButton({
|
|||
size === "sm" && "text-xs",
|
||||
size === "md" && "text-sm",
|
||||
"group-data-[collapsible=icon]:hidden",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
"use client"
|
||||
"use client";
|
||||
|
||||
import { useTheme } from "next-themes"
|
||||
import { Toaster as Sonner, ToasterProps } from "sonner"
|
||||
import { useTheme } from "next-themes";
|
||||
import { Toaster as Sonner, ToasterProps } from "sonner";
|
||||
|
||||
const Toaster = ({ ...props }: ToasterProps) => {
|
||||
const { theme = "system" } = useTheme()
|
||||
const { theme = "system" } = useTheme();
|
||||
|
||||
return (
|
||||
<Sonner
|
||||
|
@ -19,7 +19,7 @@ const Toaster = ({ ...props }: ToasterProps) => {
|
|||
}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
export { Toaster }
|
||||
export { Toaster };
|
||||
|
|
|
@ -45,7 +45,7 @@ function TableFooter({ className, ...props }: React.ComponentProps<"tfoot">) {
|
|||
data-slot="table-footer"
|
||||
className={cn(
|
||||
"bg-muted/50 border-t font-medium [&>tr]:last:border-b-0",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -58,7 +58,7 @@ function TableRow({ className, ...props }: React.ComponentProps<"tr">) {
|
|||
data-slot="table-row"
|
||||
className={cn(
|
||||
"hover:bg-muted/50 data-[state=selected]:bg-muted border-b transition-colors",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -71,7 +71,7 @@ function TableHead({ className, ...props }: React.ComponentProps<"th">) {
|
|||
data-slot="table-head"
|
||||
className={cn(
|
||||
"text-foreground h-10 px-2 text-left align-middle font-medium whitespace-nowrap [&:has([role=checkbox])]:pr-0 [&>[role=checkbox]]:translate-y-[2px]",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
@ -84,7 +84,7 @@ function TableCell({ className, ...props }: React.ComponentProps<"td">) {
|
|||
data-slot="table-cell"
|
||||
className={cn(
|
||||
"p-2 align-middle whitespace-nowrap [&:has([role=checkbox])]:pr-0 [&>[role=checkbox]]:translate-y-[2px]",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
|
|
|
@ -47,7 +47,7 @@ function TooltipContent({
|
|||
sideOffset={sideOffset}
|
||||
className={cn(
|
||||
"bg-primary text-primary-foreground animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 z-50 w-fit origin-(--radix-tooltip-content-transform-origin) rounded-md px-3 py-1.5 text-xs text-balance",
|
||||
className,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
|
|
|
@ -0,0 +1,315 @@
|
|||
import React from "react";
|
||||
import { render, screen, fireEvent } from "@testing-library/react";
|
||||
import "@testing-library/jest-dom";
|
||||
import { VectorStoreDetailView } from "./vector-store-detail";
|
||||
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
|
||||
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
|
||||
|
||||
const mockPush = jest.fn();
|
||||
jest.mock("next/navigation", () => ({
|
||||
useRouter: () => ({
|
||||
push: mockPush,
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("VectorStoreDetailView", () => {
|
||||
const defaultProps = {
|
||||
store: null,
|
||||
files: [],
|
||||
isLoadingStore: false,
|
||||
isLoadingFiles: false,
|
||||
errorStore: null,
|
||||
errorFiles: null,
|
||||
id: "test_vector_store_id",
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
mockPush.mockClear();
|
||||
});
|
||||
|
||||
describe("Loading States", () => {
|
||||
test("renders loading skeleton when store is loading", () => {
|
||||
const { container } = render(
|
||||
<VectorStoreDetailView {...defaultProps} isLoadingStore={true} />
|
||||
);
|
||||
|
||||
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
|
||||
expect(skeletons.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
test("renders files loading skeleton when files are loading", () => {
|
||||
const mockStore: VectorStore = {
|
||||
id: "vs_123",
|
||||
name: "Test Vector Store",
|
||||
created_at: 1710000000,
|
||||
status: "ready",
|
||||
file_counts: { total: 5 },
|
||||
usage_bytes: 1024,
|
||||
metadata: {
|
||||
provider_id: "test_provider",
|
||||
provider_vector_db_id: "test_db_id",
|
||||
},
|
||||
};
|
||||
|
||||
const { container } = render(
|
||||
<VectorStoreDetailView
|
||||
{...defaultProps}
|
||||
store={mockStore}
|
||||
isLoadingFiles={true}
|
||||
/>
|
||||
);
|
||||
|
||||
expect(screen.getByText("Vector Store Details")).toBeInTheDocument();
|
||||
expect(screen.getByText("Files")).toBeInTheDocument();
|
||||
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
|
||||
expect(skeletons.length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error States", () => {
|
||||
test("renders error message when store error occurs", () => {
|
||||
render(
|
||||
<VectorStoreDetailView
|
||||
{...defaultProps}
|
||||
errorStore={{ name: "Error", message: "Failed to load store" }}
|
||||
/>
|
||||
);
|
||||
|
||||
expect(screen.getByText("Vector Store Details")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(/Error loading details for ID test_vector_store_id/)
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText(/Failed to load store/)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("renders files error when files fail to load", () => {
|
||||
const mockStore: VectorStore = {
|
||||
id: "vs_123",
|
||||
name: "Test Vector Store",
|
||||
created_at: 1710000000,
|
||||
status: "ready",
|
||||
file_counts: { total: 5 },
|
||||
usage_bytes: 1024,
|
||||
metadata: {
|
||||
provider_id: "test_provider",
|
||||
provider_vector_db_id: "test_db_id",
|
||||
},
|
||||
};
|
||||
|
||||
render(
|
||||
<VectorStoreDetailView
|
||||
{...defaultProps}
|
||||
store={mockStore}
|
||||
errorFiles={{ name: "Error", message: "Failed to load files" }}
|
||||
/>
|
||||
);
|
||||
|
||||
expect(screen.getByText("Files")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("Error loading files: Failed to load files")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Not Found State", () => {
|
||||
test("renders not found message when store is null", () => {
|
||||
render(<VectorStoreDetailView {...defaultProps} store={null} />);
|
||||
|
||||
expect(screen.getByText("Vector Store Details")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(/No details found for ID: test_vector_store_id/)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Store Data Rendering", () => {
|
||||
const mockStore: VectorStore = {
|
||||
id: "vs_123",
|
||||
name: "Test Vector Store",
|
||||
created_at: 1710000000,
|
||||
status: "ready",
|
||||
file_counts: { total: 3 },
|
||||
usage_bytes: 2048,
|
||||
metadata: {
|
||||
provider_id: "test_provider",
|
||||
provider_vector_db_id: "test_db_id",
|
||||
},
|
||||
};
|
||||
|
||||
test("renders store properties correctly", () => {
|
||||
render(<VectorStoreDetailView {...defaultProps} store={mockStore} />);
|
||||
|
||||
expect(screen.getByText("Vector Store Details")).toBeInTheDocument();
|
||||
expect(screen.getByText("vs_123")).toBeInTheDocument();
|
||||
expect(screen.getByText("Test Vector Store")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(new Date(1710000000 * 1000).toLocaleString())
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("ready")).toBeInTheDocument();
|
||||
expect(screen.getByText("3")).toBeInTheDocument();
|
||||
expect(screen.getByText("2048")).toBeInTheDocument();
|
||||
expect(screen.getByText("test_provider")).toBeInTheDocument();
|
||||
expect(screen.getByText("test_db_id")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("handles empty/missing optional fields", () => {
|
||||
const minimalStore: VectorStore = {
|
||||
id: "vs_minimal",
|
||||
name: "",
|
||||
created_at: 1710000000,
|
||||
status: "ready",
|
||||
file_counts: { total: 0 },
|
||||
usage_bytes: 0,
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
render(<VectorStoreDetailView {...defaultProps} store={minimalStore} />);
|
||||
|
||||
expect(screen.getByText("vs_minimal")).toBeInTheDocument();
|
||||
expect(screen.getByText("ready")).toBeInTheDocument();
|
||||
const zeroTexts = screen.getAllByText("0");
|
||||
expect(zeroTexts.length).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
|
||||
test("shows empty files message when no files", () => {
|
||||
render(
|
||||
<VectorStoreDetailView {...defaultProps} store={mockStore} files={[]} />
|
||||
);
|
||||
|
||||
expect(screen.getByText("Files")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("No files in this vector store.")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Files Table", () => {
|
||||
const mockStore: VectorStore = {
|
||||
id: "vs_123",
|
||||
name: "Test Vector Store",
|
||||
created_at: 1710000000,
|
||||
status: "ready",
|
||||
file_counts: { total: 2 },
|
||||
usage_bytes: 2048,
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
const mockFiles: VectorStoreFile[] = [
|
||||
{
|
||||
id: "file_123",
|
||||
status: "completed",
|
||||
created_at: 1710001000,
|
||||
usage_bytes: 1024,
|
||||
},
|
||||
{
|
||||
id: "file_456",
|
||||
status: "processing",
|
||||
created_at: 1710002000,
|
||||
usage_bytes: 512,
|
||||
},
|
||||
];
|
||||
|
||||
test("renders files table with correct data", () => {
|
||||
render(
|
||||
<VectorStoreDetailView
|
||||
{...defaultProps}
|
||||
store={mockStore}
|
||||
files={mockFiles}
|
||||
/>
|
||||
);
|
||||
|
||||
expect(screen.getByText("Files")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("Files in this vector store")
|
||||
).toBeInTheDocument();
|
||||
|
||||
expect(screen.getByText("ID")).toBeInTheDocument();
|
||||
expect(screen.getByText("Status")).toBeInTheDocument();
|
||||
expect(screen.getByText("Created")).toBeInTheDocument();
|
||||
expect(screen.getByText("Usage Bytes")).toBeInTheDocument();
|
||||
|
||||
expect(screen.getByText("file_123")).toBeInTheDocument();
|
||||
expect(screen.getByText("completed")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(new Date(1710001000 * 1000).toLocaleString())
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("1024")).toBeInTheDocument();
|
||||
|
||||
expect(screen.getByText("file_456")).toBeInTheDocument();
|
||||
expect(screen.getByText("processing")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(new Date(1710002000 * 1000).toLocaleString())
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("512")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("file ID links are clickable and navigate correctly", () => {
|
||||
render(
|
||||
<VectorStoreDetailView
|
||||
{...defaultProps}
|
||||
store={mockStore}
|
||||
files={mockFiles}
|
||||
id="vs_123"
|
||||
/>
|
||||
);
|
||||
|
||||
const fileButton = screen.getByRole("button", { name: "file_123" });
|
||||
expect(fileButton).toBeInTheDocument();
|
||||
|
||||
fireEvent.click(fileButton);
|
||||
expect(mockPush).toHaveBeenCalledWith(
|
||||
"/logs/vector-stores/vs_123/files/file_123"
|
||||
);
|
||||
});
|
||||
|
||||
test("handles multiple file clicks correctly", () => {
|
||||
render(
|
||||
<VectorStoreDetailView
|
||||
{...defaultProps}
|
||||
store={mockStore}
|
||||
files={mockFiles}
|
||||
id="vs_123"
|
||||
/>
|
||||
);
|
||||
|
||||
const file1Button = screen.getByRole("button", { name: "file_123" });
|
||||
const file2Button = screen.getByRole("button", { name: "file_456" });
|
||||
|
||||
fireEvent.click(file1Button);
|
||||
expect(mockPush).toHaveBeenCalledWith(
|
||||
"/logs/vector-stores/vs_123/files/file_123"
|
||||
);
|
||||
|
||||
fireEvent.click(file2Button);
|
||||
expect(mockPush).toHaveBeenCalledWith(
|
||||
"/logs/vector-stores/vs_123/files/file_456"
|
||||
);
|
||||
|
||||
expect(mockPush).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Layout Structure", () => {
|
||||
const mockStore: VectorStore = {
|
||||
id: "vs_layout_test",
|
||||
name: "Layout Test Store",
|
||||
created_at: 1710000000,
|
||||
status: "ready",
|
||||
file_counts: { total: 1 },
|
||||
usage_bytes: 1024,
|
||||
metadata: {},
|
||||
};
|
||||
|
||||
test("renders main content and sidebar in correct layout", () => {
|
||||
render(<VectorStoreDetailView {...defaultProps} store={mockStore} />);
|
||||
|
||||
expect(screen.getByText("Files")).toBeInTheDocument();
|
||||
|
||||
expect(screen.getByText("vs_layout_test")).toBeInTheDocument();
|
||||
expect(screen.getByText("Layout Test Store")).toBeInTheDocument();
|
||||
expect(screen.getByText("ready")).toBeInTheDocument();
|
||||
expect(screen.getByText("1")).toBeInTheDocument();
|
||||
expect(screen.getByText("1024")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
|
@ -85,9 +85,9 @@ export function VectorStoreDetailView({
|
|||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{files.map((file) => (
|
||||
{files.map(file => (
|
||||
<TableRow key={file.id}>
|
||||
<TableCell>
|
||||
<TableCell>
|
||||
<Button
|
||||
variant="link"
|
||||
className="p-0 h-auto font-mono text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300"
|
||||
|
|
|
@ -45,7 +45,7 @@ test.describe("LogsTable Scroll and Progressive Loading", () => {
|
|||
const scrollContainer = page.locator("div.overflow-auto").first();
|
||||
|
||||
// Scroll to near the bottom
|
||||
await scrollContainer.evaluate((element) => {
|
||||
await scrollContainer.evaluate(element => {
|
||||
element.scrollTop = element.scrollHeight - element.clientHeight - 100;
|
||||
});
|
||||
|
||||
|
|
|
@ -10,7 +10,13 @@ const compat = new FlatCompat({
|
|||
});
|
||||
|
||||
const eslintConfig = [
|
||||
...compat.extends("next/core-web-vitals", "next/typescript"),
|
||||
...compat.extends("next/core-web-vitals", "next/typescript", "prettier"),
|
||||
...compat.plugins("prettier"),
|
||||
{
|
||||
rules: {
|
||||
"prettier/prettier": "error",
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
export default eslintConfig;
|
||||
|
|
|
@ -1,85 +1,85 @@
|
|||
import { useEffect, useRef, useState } from "react"
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
|
||||
import { recordAudio } from "@/lib/audio-utils"
|
||||
import { recordAudio } from "@/lib/audio-utils";
|
||||
|
||||
interface UseAudioRecordingOptions {
|
||||
transcribeAudio?: (blob: Blob) => Promise<string>
|
||||
onTranscriptionComplete?: (text: string) => void
|
||||
transcribeAudio?: (blob: Blob) => Promise<string>;
|
||||
onTranscriptionComplete?: (text: string) => void;
|
||||
}
|
||||
|
||||
export function useAudioRecording({
|
||||
transcribeAudio,
|
||||
onTranscriptionComplete,
|
||||
}: UseAudioRecordingOptions) {
|
||||
const [isListening, setIsListening] = useState(false)
|
||||
const [isSpeechSupported, setIsSpeechSupported] = useState(!!transcribeAudio)
|
||||
const [isRecording, setIsRecording] = useState(false)
|
||||
const [isTranscribing, setIsTranscribing] = useState(false)
|
||||
const [audioStream, setAudioStream] = useState<MediaStream | null>(null)
|
||||
const activeRecordingRef = useRef<any>(null)
|
||||
const [isListening, setIsListening] = useState(false);
|
||||
const [isSpeechSupported, setIsSpeechSupported] = useState(!!transcribeAudio);
|
||||
const [isRecording, setIsRecording] = useState(false);
|
||||
const [isTranscribing, setIsTranscribing] = useState(false);
|
||||
const [audioStream, setAudioStream] = useState<MediaStream | null>(null);
|
||||
const activeRecordingRef = useRef<any>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const checkSpeechSupport = async () => {
|
||||
const hasMediaDevices = !!(
|
||||
navigator.mediaDevices && navigator.mediaDevices.getUserMedia
|
||||
)
|
||||
setIsSpeechSupported(hasMediaDevices && !!transcribeAudio)
|
||||
}
|
||||
);
|
||||
setIsSpeechSupported(hasMediaDevices && !!transcribeAudio);
|
||||
};
|
||||
|
||||
checkSpeechSupport()
|
||||
}, [transcribeAudio])
|
||||
checkSpeechSupport();
|
||||
}, [transcribeAudio]);
|
||||
|
||||
const stopRecording = async () => {
|
||||
setIsRecording(false)
|
||||
setIsTranscribing(true)
|
||||
setIsRecording(false);
|
||||
setIsTranscribing(true);
|
||||
try {
|
||||
// First stop the recording to get the final blob
|
||||
recordAudio.stop()
|
||||
recordAudio.stop();
|
||||
// Wait for the recording promise to resolve with the final blob
|
||||
const recording = await activeRecordingRef.current
|
||||
const recording = await activeRecordingRef.current;
|
||||
if (transcribeAudio) {
|
||||
const text = await transcribeAudio(recording)
|
||||
onTranscriptionComplete?.(text)
|
||||
const text = await transcribeAudio(recording);
|
||||
onTranscriptionComplete?.(text);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error transcribing audio:", error)
|
||||
console.error("Error transcribing audio:", error);
|
||||
} finally {
|
||||
setIsTranscribing(false)
|
||||
setIsListening(false)
|
||||
setIsTranscribing(false);
|
||||
setIsListening(false);
|
||||
if (audioStream) {
|
||||
audioStream.getTracks().forEach((track) => track.stop())
|
||||
setAudioStream(null)
|
||||
audioStream.getTracks().forEach(track => track.stop());
|
||||
setAudioStream(null);
|
||||
}
|
||||
activeRecordingRef.current = null
|
||||
activeRecordingRef.current = null;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const toggleListening = async () => {
|
||||
if (!isListening) {
|
||||
try {
|
||||
setIsListening(true)
|
||||
setIsRecording(true)
|
||||
setIsListening(true);
|
||||
setIsRecording(true);
|
||||
// Get audio stream first
|
||||
const stream = await navigator.mediaDevices.getUserMedia({
|
||||
audio: true,
|
||||
})
|
||||
setAudioStream(stream)
|
||||
});
|
||||
setAudioStream(stream);
|
||||
|
||||
// Start recording with the stream
|
||||
activeRecordingRef.current = recordAudio(stream)
|
||||
activeRecordingRef.current = recordAudio(stream);
|
||||
} catch (error) {
|
||||
console.error("Error recording audio:", error)
|
||||
setIsListening(false)
|
||||
setIsRecording(false)
|
||||
console.error("Error recording audio:", error);
|
||||
setIsListening(false);
|
||||
setIsRecording(false);
|
||||
if (audioStream) {
|
||||
audioStream.getTracks().forEach((track) => track.stop())
|
||||
setAudioStream(null)
|
||||
audioStream.getTracks().forEach(track => track.stop());
|
||||
setAudioStream(null);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
await stopRecording()
|
||||
await stopRecording();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
isListening,
|
||||
|
@ -89,5 +89,5 @@ export function useAudioRecording({
|
|||
audioStream,
|
||||
toggleListening,
|
||||
stopRecording,
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -1,67 +1,67 @@
|
|||
import { useEffect, useRef, useState } from "react"
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
|
||||
// How many pixels from the bottom of the container to enable auto-scroll
|
||||
const ACTIVATION_THRESHOLD = 50
|
||||
const ACTIVATION_THRESHOLD = 50;
|
||||
// Minimum pixels of scroll-up movement required to disable auto-scroll
|
||||
const MIN_SCROLL_UP_THRESHOLD = 10
|
||||
const MIN_SCROLL_UP_THRESHOLD = 10;
|
||||
|
||||
export function useAutoScroll(dependencies: React.DependencyList) {
|
||||
const containerRef = useRef<HTMLDivElement | null>(null)
|
||||
const previousScrollTop = useRef<number | null>(null)
|
||||
const [shouldAutoScroll, setShouldAutoScroll] = useState(true)
|
||||
const containerRef = useRef<HTMLDivElement | null>(null);
|
||||
const previousScrollTop = useRef<number | null>(null);
|
||||
const [shouldAutoScroll, setShouldAutoScroll] = useState(true);
|
||||
|
||||
const scrollToBottom = () => {
|
||||
if (containerRef.current) {
|
||||
containerRef.current.scrollTop = containerRef.current.scrollHeight
|
||||
containerRef.current.scrollTop = containerRef.current.scrollHeight;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleScroll = () => {
|
||||
if (containerRef.current) {
|
||||
const { scrollTop, scrollHeight, clientHeight } = containerRef.current
|
||||
const { scrollTop, scrollHeight, clientHeight } = containerRef.current;
|
||||
|
||||
const distanceFromBottom = Math.abs(
|
||||
scrollHeight - scrollTop - clientHeight
|
||||
)
|
||||
);
|
||||
|
||||
const isScrollingUp = previousScrollTop.current
|
||||
? scrollTop < previousScrollTop.current
|
||||
: false
|
||||
: false;
|
||||
|
||||
const scrollUpDistance = previousScrollTop.current
|
||||
? previousScrollTop.current - scrollTop
|
||||
: 0
|
||||
: 0;
|
||||
|
||||
const isDeliberateScrollUp =
|
||||
isScrollingUp && scrollUpDistance > MIN_SCROLL_UP_THRESHOLD
|
||||
isScrollingUp && scrollUpDistance > MIN_SCROLL_UP_THRESHOLD;
|
||||
|
||||
if (isDeliberateScrollUp) {
|
||||
setShouldAutoScroll(false)
|
||||
setShouldAutoScroll(false);
|
||||
} else {
|
||||
const isScrolledToBottom = distanceFromBottom < ACTIVATION_THRESHOLD
|
||||
setShouldAutoScroll(isScrolledToBottom)
|
||||
const isScrolledToBottom = distanceFromBottom < ACTIVATION_THRESHOLD;
|
||||
setShouldAutoScroll(isScrolledToBottom);
|
||||
}
|
||||
|
||||
previousScrollTop.current = scrollTop
|
||||
previousScrollTop.current = scrollTop;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleTouchStart = () => {
|
||||
setShouldAutoScroll(false)
|
||||
}
|
||||
setShouldAutoScroll(false);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (containerRef.current) {
|
||||
previousScrollTop.current = containerRef.current.scrollTop
|
||||
previousScrollTop.current = containerRef.current.scrollTop;
|
||||
}
|
||||
}, [])
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (shouldAutoScroll) {
|
||||
scrollToBottom()
|
||||
scrollToBottom();
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, dependencies)
|
||||
}, dependencies);
|
||||
|
||||
return {
|
||||
containerRef,
|
||||
|
@ -69,5 +69,5 @@ export function useAutoScroll(dependencies: React.DependencyList) {
|
|||
handleScroll,
|
||||
shouldAutoScroll,
|
||||
handleTouchStart,
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue