adding logo and favicon

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

chore: Enable keyword search for Milvus inline (#3073)

With https://github.com/milvus-io/milvus-lite/pull/294 - Milvus Lite
supports keyword search using BM25. While introducing keyword search we
had explicitly disabled it for inline milvus. This PR removes the need
for the check, and enables `inline::milvus` for tests.

<!-- If resolving an issue, uncomment and update the line below -->
<!-- Closes #[issue-number] -->

Run llama stack with `inline::milvus` enabled:

```
pytest tests/integration/vector_io/test_openai_vector_stores.py::test_openai_vector_store_search_modes --stack-config=http://localhost:8321 --embedding-model=all-MiniLM-L6-v2 -v
```

```
INFO     2025-08-07 17:06:20,932 tests.integration.conftest:64 tests: Setting DISABLE_CODE_SANDBOX=1 for macOS
=========================================================================================== test session starts ============================================================================================
platform darwin -- Python 3.12.11, pytest-7.4.4, pluggy-1.5.0 -- /Users/vnarsing/miniconda3/envs/stack-client/bin/python
cachedir: .pytest_cache
metadata: {'Python': '3.12.11', 'Platform': 'macOS-14.7.6-arm64-arm-64bit', 'Packages': {'pytest': '7.4.4', 'pluggy': '1.5.0'}, 'Plugins': {'asyncio': '0.23.8', 'cov': '6.0.0', 'timeout': '2.2.0', 'socket': '0.7.0', 'html': '3.1.1', 'langsmith': '0.3.39', 'anyio': '4.8.0', 'metadata': '3.0.0'}}
rootdir: /Users/vnarsing/go/src/github/meta-llama/llama-stack
configfile: pyproject.toml
plugins: asyncio-0.23.8, cov-6.0.0, timeout-2.2.0, socket-0.7.0, html-3.1.1, langsmith-0.3.39, anyio-4.8.0, metadata-3.0.0
asyncio: mode=Mode.AUTO
collected 3 items

tests/integration/vector_io/test_openai_vector_stores.py::test_openai_vector_store_search_modes[None-None-all-MiniLM-L6-v2-None-384-vector] PASSED                                                   [ 33%]
tests/integration/vector_io/test_openai_vector_stores.py::test_openai_vector_store_search_modes[None-None-all-MiniLM-L6-v2-None-384-keyword] PASSED                                                  [ 66%]
tests/integration/vector_io/test_openai_vector_stores.py::test_openai_vector_store_search_modes[None-None-all-MiniLM-L6-v2-None-384-hybrid] PASSED                                                   [100%]

============================================================================================ 3 passed in 4.75s =============================================================================================
```

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
Co-authored-by: Francisco Arceo <arceofrancisco@gmail.com>

chore: Fixup main pre commit (#3204)

build: Bump version to 0.2.18

chore: Faster npm pre-commit (#3206)

Adds npm to pre-commit.yml installation and caches ui
Removes node installation during pre-commit.

<!-- If resolving an issue, uncomment and update the line below -->
<!-- Closes #[issue-number] -->

<!-- Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.* -->

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

chiecking in for tonight, wip moving to agents api

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

remove log

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

updated

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

fix: disable ui-prettier & ui-eslint (#3207)

chore(pre-commit): add pre-commit hook to enforce llama_stack logger usage (#3061)

This PR adds a step in pre-commit to enforce using `llama_stack` logger.

Currently, various parts of the code base uses different loggers. As a
custom `llama_stack` logger exist and used in the codebase, it is better
to standardize its utilization.

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Co-authored-by: Matthew Farrellee <matt@cs.wisc.edu>

fix: fix ```openai_embeddings``` for asymmetric embedding NIMs (#3205)

NVIDIA asymmetric embedding models (e.g.,
`nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter
not present in the standard OpenAI embeddings API. This PR adds the
`input_type="query"` as default and updates the documentation to suggest
using the `embedding` API for passage embeddings.

<!-- If resolving an issue, uncomment and update the line below -->
Resolves #2892

```
pytest -s -v tests/integration/inference/test_openai_embeddings.py   --stack-config="inference=nvidia"   --embedding-model="nvidia/llama-3.2-nv-embedqa-1b-v2"   --env NVIDIA_API_KEY={nvidia_api_key}   --env NVIDIA_BASE_URL="https://integrate.api.nvidia.com"
```

cleaning up

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

updating session manager to cache messages locally

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

fix linter

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

more cleanup

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-08-19 16:44:20 -04:00
parent e7be568d7e
commit 6620b625f1
76 changed files with 2343 additions and 1187 deletions

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import importlib.resources
import logging
import sys
from pydantic import BaseModel
@ -17,9 +16,10 @@ from llama_stack.core.external import load_external_apis
from llama_stack.core.utils.exec import run_command
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.distributions.template import DistributionTemplate
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="core")
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.

View file

@ -3,7 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import textwrap
from typing import Any
@ -21,9 +20,10 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.prompt_for_config import prompt_for_config
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, ProviderSpec
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="core")
def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:

View file

@ -7,7 +7,7 @@
import asyncio
import inspect
import json
import logging
import logging # allow-direct-logging
import os
import sys
from concurrent.futures import ThreadPoolExecutor
@ -48,6 +48,7 @@ from llama_stack.core.stack import (
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.core.utils.exec import in_notebook
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace,
@ -55,7 +56,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
start_trace,
)
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="core")
T = TypeVar("T")

View file

@ -6,15 +6,15 @@
import contextvars
import json
import logging
from contextlib import AbstractContextManager
from typing import Any
from llama_stack.core.datatypes import User
from llama_stack.log import get_logger
from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="core")
# Context variable for request provider data and auth attributes
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)

View file

@ -9,7 +9,7 @@ import asyncio
import functools
import inspect
import json
import logging
import logging # allow-direct-logging
import os
import ssl
import sys

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import importlib
import os
import signal
import subprocess
@ -12,9 +12,9 @@ import sys
from termcolor import cprint
log = logging.getLogger(__name__)
from llama_stack.log import get_logger
import importlib
log = get_logger(name=__name__, category="core")
def formulate_run_args(image_type: str, image_name: str) -> list:

View file

@ -6,7 +6,6 @@
import inspect
import json
import logging
from enum import Enum
from typing import Annotated, Any, Literal, Union, get_args, get_origin
@ -14,7 +13,9 @@ from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType
log = logging.getLogger(__name__)
from llama_stack.log import get_logger
log = get_logger(name=__name__, category="core")
def is_list_of_primitives(field_type):

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import logging # allow-direct-logging
import os
import re
from logging.config import dictConfig
from logging.config import dictConfig # allow-direct-logging
from rich.console import Console
from rich.errors import MarkupError

View file

@ -13,14 +13,15 @@
# Copyright (c) Meta Platforms, Inc. and its affiliates.
import math
from logging import getLogger
import torch
import torch.nn.functional as F
from llama_stack.log import get_logger
from .utils import get_negative_inf_value, to_2tuple
logger = getLogger()
logger = get_logger(name=__name__, category="models::llama")
def resize_local_position_embedding(orig_pos_embed, grid_size):

View file

@ -13,7 +13,6 @@
import math
from collections import defaultdict
from logging import getLogger
from typing import Any
import torch
@ -21,9 +20,11 @@ import torchvision.transforms as tv
from PIL import Image
from torchvision.transforms import functional as F
from llama_stack.log import get_logger
IMAGE_RES = 224
logger = getLogger()
logger = get_logger(name=__name__, category="models::llama")
class VariableSizeImageTransform:

View file

@ -3,8 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import math
from collections.abc import Callable
from functools import partial
@ -22,6 +20,8 @@ from PIL import Image as PIL_Image
from torch import Tensor, nn
from torch.distributed import _functional_collectives as funcol
from llama_stack.log import get_logger
from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis
from .encoder_utils import (
build_encoder_attention_mask,
@ -34,9 +34,10 @@ from .encoder_utils import (
from .image_transform import VariableSizeImageTransform
from .utils import get_negative_inf_value, to_2tuple
logger = logging.getLogger(__name__)
MP_SCALE = 8
logger = get_logger(name=__name__, category="models")
def reduce_from_tensor_model_parallel_region(input_):
"""All-reduce the input tensor across model parallel group."""
@ -771,7 +772,7 @@ class TilePositionEmbedding(nn.Module):
if embed is not None:
# reshape the weights to the correct shape
nt_old, nt_old, _, w = embed.shape
logging.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
logger.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles)
# assign the weights to the module
state_dict[prefix + "embedding"] = embed_new

View file

@ -4,8 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger
from pathlib import Path
from typing import (
Literal,
@ -14,11 +14,9 @@ from typing import (
import tiktoken
from llama_stack.log import get_logger
from llama_stack.models.llama.tokenizer_utils import load_bpe_file
logger = getLogger(__name__)
# The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
@ -31,6 +29,8 @@ MAX_NO_WHITESPACES_CHARS = 25_000
_INSTANCE = None
logger = get_logger(name=__name__, category="models::llama")
class Tokenizer:
"""

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
from collections.abc import Callable
@ -13,11 +12,13 @@ from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from torch import Tensor, nn
from torch.nn import functional as F
from llama_stack.log import get_logger
from ...datatypes import QuantizationMode
from ..model import Transformer, TransformerBlock
from ..moe import MoE
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="models")
def swiglu_wrapper_no_reduce(

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger
from pathlib import Path
from typing import (
Literal,
@ -14,11 +13,9 @@ from typing import (
import tiktoken
from llama_stack.log import get_logger
from llama_stack.models.llama.tokenizer_utils import load_bpe_file
logger = getLogger(__name__)
# The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
@ -101,6 +98,8 @@ BASIC_SPECIAL_TOKENS = [
"<|fim_suffix|>",
]
logger = get_logger(name=__name__, category="models::llama")
class Tokenizer:
"""

View file

@ -6,9 +6,10 @@
# type: ignore
import collections
import logging
log = logging.getLogger(__name__)
from llama_stack.log import get_logger
log = get_logger(name=__name__, category="llama")
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import uuid
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
@ -42,6 +41,7 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
from .persistence import AgentInfo
from .responses.openai_responses import OpenAIResponsesImpl
logger = logging.getLogger()
logger = get_logger(name=__name__, category="agents")
class MetaReferenceAgentsImpl(Agents):

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import json
import logging
import uuid
from datetime import UTC, datetime
@ -15,9 +14,10 @@ from llama_stack.core.access_control.access_control import AccessDeniedError, is
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="agents")
class AgentSessionInfo(Session):

View file

@ -5,13 +5,13 @@
# the root directory of this source tree.
import asyncio
import logging
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="agents")
class SafetyException(Exception): # noqa: N818

View file

@ -12,7 +12,6 @@
import copy
import json
import logging
import multiprocessing
import os
import tempfile
@ -32,13 +31,14 @@ from fairscale.nn.model_parallel.initialize import (
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="inference")
class ProcessingMessageName(str, Enum):

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from collections.abc import AsyncGenerator
from llama_stack.apis.inference import (
@ -21,6 +20,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
@ -32,7 +32,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
from .config import SentenceTransformersInferenceConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="inference")
class SentenceTransformersInferenceImpl(

View file

@ -6,7 +6,6 @@
import gc
import json
import logging
import multiprocessing
from pathlib import Path
from typing import Any
@ -28,6 +27,7 @@ from llama_stack.apis.post_training import (
LoraFinetuningConfig,
TrainingConfig,
)
from llama_stack.log import get_logger
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig
@ -44,7 +44,7 @@ from ..utils import (
split_dataset,
)
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="post_training")
class HFFinetuningSingleDevice:

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import gc
import logging
import multiprocessing
from pathlib import Path
from typing import Any
@ -24,6 +23,7 @@ from llama_stack.apis.post_training import (
DPOAlignmentConfig,
TrainingConfig,
)
from llama_stack.log import get_logger
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig
@ -40,7 +40,7 @@ from ..utils import (
split_dataset,
)
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="post_training")
class HFDPOAlignmentSingleDevice:

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import signal
import sys
@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.post_training import Checkpoint, TrainingConfig
from llama_stack.log import get_logger
from .config import HuggingFacePostTrainingConfig
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="post_training")
def setup_environment():

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import time
from datetime import UTC, datetime
@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training
from torchtune import utils as torchtune_utils
from torchtune.data import padded_collate_sft
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
get_adapter_params,
@ -45,6 +45,7 @@ from llama_stack.apis.post_training import (
)
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from llama_stack.providers.inline.post_training.torchtune.common import utils
@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
)
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
log = logging.getLogger(__name__)
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
log = get_logger(name=__name__, category="post_training")
class LoraFinetuningSingleDevice:

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import uuid
from typing import TYPE_CHECKING, Any
@ -20,13 +19,14 @@ from llama_stack.apis.safety import (
)
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from .config import CodeScannerConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="safety")
ALLOWED_CODE_SCANNER_MODEL_IDS = [
"code-scanner",

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import re
import uuid
from string import Template
@ -21,6 +20,7 @@ from llama_stack.apis.safety import (
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import Api
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import Role
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
@ -132,6 +132,8 @@ Provide your safety assessment for ONLY THE LAST $agent_type message in the abov
PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}")
logger = get_logger(name=__name__, category="safety")
class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
def __init__(self, config: LlamaGuardConfig, deps) -> None:
@ -407,7 +409,7 @@ class LlamaGuardShield:
unsafe_code_list = [code.strip() for code in unsafe_code.split(",")]
invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP]
if invalid_codes:
logging.warning(f"Invalid safety codes returned: {invalid_codes}")
logger.warning(f"Invalid safety codes returned: {invalid_codes}")
# just returning safe object, as we don't know what the invalid codes can map to
return ModerationObject(
id=f"modr-{uuid.uuid4()}",

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any
import torch
@ -21,6 +20,7 @@ from llama_stack.apis.safety import (
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import PromptGuardConfig, PromptGuardType
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="safety")
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"

View file

@ -7,7 +7,6 @@
import collections
import functools
import json
import logging
import random
import re
import string
@ -20,7 +19,9 @@ import nltk
from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai
from pythainlp.tokenize import word_tokenize as word_tokenize_thai
logger = logging.getLogger()
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="scoring")
WORD_LIST = [
"western",

View file

@ -4,13 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import threading
from typing import Any
from opentelemetry import metrics, trace
logger = logging.getLogger(__name__)
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
@ -40,6 +37,7 @@ from llama_stack.apis.telemetry import (
UnstructuredLogEvent,
)
from llama_stack.core.datatypes import Api
from llama_stack.log import get_logger
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
ConsoleSpanProcessor,
)
@ -61,6 +59,8 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
_global_lock = threading.Lock()
_TRACER_PROVIDER = None
logger = get_logger(name=__name__, category="telemetry")
def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span:

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import logging
import secrets
import string
from typing import Any
@ -32,6 +31,7 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
@ -42,7 +42,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="tool_runtime")
def make_random_string(length: int = 8):

View file

@ -8,7 +8,6 @@ import asyncio
import base64
import io
import json
import logging
from typing import Any
import faiss
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
@ -40,7 +40,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import FaissVectorIOConfig
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="vector_io")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import logging
import re
import sqlite3
import struct
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
VectorDBWithIndex,
)
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="vector_io")
# Specifying search mode is dependent on the VectorIO provider.
VECTOR_SEARCH = "vector"

View file

@ -3,15 +3,14 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="inference")
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):

View file

@ -77,6 +77,10 @@ print(f"Response: {response.completion_message.content}")
```
### Create Embeddings
> Note on OpenAI embeddings compatibility
>
> NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. The NVIDIA Inference Adapter automatically sets `input_type="query"` when using the OpenAI-compatible embeddings endpoint for NVIDIA. For passage embeddings, use the `embeddings` API with `task_type="document"`.
```python
response = client.inference.embeddings(
model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",

View file

@ -4,11 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import warnings
from collections.abc import AsyncIterator
from openai import APIConnectionError, BadRequestError
from openai import NOT_GIVEN, APIConnectionError, BadRequestError
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -27,12 +26,16 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
@ -54,7 +57,7 @@ from .openai_utils import (
)
from .utils import _is_nvidia_hosted
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="inference")
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
@ -210,6 +213,57 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
#
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
"""
OpenAI-compatible embeddings for NVIDIA NIM.
Note: NVIDIA NIM asymmetric embedding models require an "input_type" field not present in the standard OpenAI embeddings API.
We default this to "query" to ensure requests succeed when using the
OpenAI-compatible endpoint. For passage embeddings, use the embeddings API with
`task_type='document'`.
"""
extra_body: dict[str, object] = {"input_type": "query"}
logger.warning(
"NVIDIA OpenAI-compatible embeddings: defaulting to input_type='query'. "
"For passage embeddings, use the embeddings API with task_type='document'."
)
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model),
input=input,
encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
dimensions=dimensions if dimensions is not None else NOT_GIVEN,
user=user if user is not None else NOT_GIVEN,
extra_body=extra_body,
)
data = []
for i, embedding_data in enumerate(response.data):
data.append(
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
return OpenAIEmbeddingsResponse(
data=data,
model=response.model,
usage=usage,
)
async def chat_completion(
self,
model_id: str,

View file

@ -4,13 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import httpx
from llama_stack.log import get_logger
from . import NVIDIAConfig
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="inference")
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:

View file

@ -4,15 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig
from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="inference")
#

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import logging
from collections.abc import AsyncGenerator
from huggingface_hub import AsyncInferenceClient, HfApi
@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="inference")
def build_hf_repo_model_entries():

View file

@ -4,18 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import warnings
from typing import Any
from pydantic import BaseModel
from llama_stack.apis.post_training import TrainingConfig
from llama_stack.log import get_logger
from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig
from .config import NvidiaPostTrainingConfig
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="integration")
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import json
import logging
from typing import Any
from llama_stack.apis.inference import Message
@ -16,12 +15,13 @@ from llama_stack.apis.safety import (
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="safety")
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any
import requests
@ -12,12 +11,13 @@ import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import NVIDIASafetyConfig
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="safety")
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import json
import logging
from typing import Any
import litellm
@ -20,12 +19,13 @@ from llama_stack.apis.safety import (
)
from llama_stack.apis.shields import Shield
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import SambaNovaSafetyConfig
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="safety")
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import json
import logging
from typing import Any
from urllib.parse import urlparse
@ -20,6 +19,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="vector_io")
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import logging
import os
from typing import Any
@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="vector_io")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
@ -413,15 +413,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise VectorStoreNotFoundError(vector_db_id)
if params and params.get("mode") == "keyword":
# Check if this is inline Milvus (Milvus-Lite)
if hasattr(self.config, "db_path"):
raise NotImplementedError(
"Keyword search is not supported in Milvus-Lite. "
"Please use a remote Milvus server for keyword search functionality."
)
return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any
import psycopg2
@ -22,6 +21,7 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
@ -34,7 +34,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import PGVectorVectorIOConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="vector_io")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import logging
import uuid
from typing import Any
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
VectorStoreChunkingStrategy,
VectorStoreFileObject,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="vector_io")
CHUNK_ID_KEY = "_chunk_id"
# KV store prefixes for vector databases

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
from typing import Any
import weaviate
@ -19,6 +18,7 @@ from llama_stack.apis.files.files import Files
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
from .config import WeaviateVectorIOConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="vector_io")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"

View file

@ -5,10 +5,11 @@
# the root directory of this source tree.
import base64
import logging
import struct
from typing import TYPE_CHECKING
from llama_stack.log import get_logger
if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer
@ -27,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
EMBEDDING_MODELS = {}
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="inference")
class SentenceTransformerEmbeddingMixin:

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import base64
import json
import logging
import struct
import time
import uuid
@ -122,6 +121,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.inference import (
OpenAIChoice as OpenAIChatCompletionChoice,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,
@ -134,7 +134,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
decode_assistant_message,
)
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="inference")
class OpenAICompatCompletionChoiceDelta(BaseModel):

View file

@ -4,16 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from datetime import datetime
from pymongo import AsyncMongoClient
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
from ..config import MongoDBKVStoreConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="kvstore")
class MongoDBKVStoreImpl(KVStore):

View file

@ -4,16 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from datetime import datetime
import psycopg2
from psycopg2.extras import DictCursor
from llama_stack.log import get_logger
from ..api import KVStore
from ..config import PostgresKVStoreConfig
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="kvstore")
class PostgresKVStoreImpl(KVStore):

View file

@ -44,7 +44,7 @@ from llama_stack.providers.utils.memory.vector_store import (
make_overlapped_chunks,
)
logger = get_logger(__name__, category="vector_io")
logger = get_logger(name=__name__, category="memory")
# Constants for OpenAI vector stores
CHUNK_MULTIPLIER = 5

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import base64
import io
import logging
import re
import time
from abc import ABC, abstractmethod
@ -26,6 +25,7 @@ from llama_stack.apis.common.content_types import (
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -33,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
log = logging.getLogger(__name__)
log = get_logger(name=__name__, category="memory")
class ChunkForDeletion(BaseModel):

View file

@ -6,7 +6,7 @@
import asyncio
import contextvars
import logging
import logging # allow-direct-logging
import queue
import random
import sys

View file

@ -0,0 +1,587 @@
import React from "react";
import {
render,
screen,
fireEvent,
waitFor,
act,
} from "@testing-library/react";
import "@testing-library/jest-dom";
import ChatPlaygroundPage from "./page";
const mockClient = {
agents: {
list: jest.fn(),
create: jest.fn(),
retrieve: jest.fn(),
delete: jest.fn(),
session: {
list: jest.fn(),
create: jest.fn(),
delete: jest.fn(),
retrieve: jest.fn(),
},
turn: {
create: jest.fn(),
},
},
models: {
list: jest.fn(),
},
toolgroups: {
list: jest.fn(),
},
};
jest.mock("@/hooks/use-auth-client", () => ({
useAuthClient: jest.fn(() => mockClient),
}));
jest.mock("@/components/chat-playground/chat", () => ({
Chat: jest.fn(
({
className,
messages,
handleSubmit,
input,
handleInputChange,
isGenerating,
append,
suggestions,
}) => (
<div data-testid="chat-component" className={className}>
<div data-testid="messages-count">{messages.length}</div>
<input
data-testid="chat-input"
value={input}
onChange={handleInputChange}
disabled={isGenerating}
/>
<button data-testid="submit-button" onClick={handleSubmit}>
Submit
</button>
{suggestions?.map((suggestion: string, index: number) => (
<button
key={index}
data-testid={`suggestion-${index}`}
onClick={() => append({ role: "user", content: suggestion })}
>
{suggestion}
</button>
))}
</div>
)
),
}));
jest.mock("@/components/chat-playground/session-manager", () => ({
SessionManager: jest.fn(({ selectedAgentId, onNewSession }) => (
<div data-testid="session-manager">
{selectedAgentId && (
<>
<div data-testid="selected-agent">{selectedAgentId}</div>
<button data-testid="new-session-button" onClick={onNewSession}>
New Session
</button>
</>
)}
</div>
)),
SessionUtils: {
saveCurrentSessionId: jest.fn(),
loadCurrentSessionId: jest.fn(),
loadCurrentAgentId: jest.fn(),
saveCurrentAgentId: jest.fn(),
clearCurrentSession: jest.fn(),
saveSessionData: jest.fn(),
loadSessionData: jest.fn(),
saveAgentConfig: jest.fn(),
loadAgentConfig: jest.fn(),
clearAgentCache: jest.fn(),
createDefaultSession: jest.fn(() => ({
id: "test-session-123",
name: "Default Session",
messages: [],
selectedModel: "",
systemMessage: "You are a helpful assistant.",
agentId: "test-agent-123",
createdAt: Date.now(),
updatedAt: Date.now(),
})),
},
}));
const mockAgents = [
{
agent_id: "agent_123",
agent_config: {
name: "Test Agent",
instructions: "You are a test assistant.",
},
},
{
agent_id: "agent_456",
agent_config: {
agent_name: "Another Agent",
instructions: "You are another assistant.",
},
},
];
const mockModels = [
{
identifier: "test-model-1",
model_type: "llm",
},
{
identifier: "test-model-2",
model_type: "llm",
},
];
const mockToolgroups = [
{
identifier: "builtin::rag",
provider_id: "test-provider",
type: "tool_group",
provider_resource_id: "test-resource",
},
];
describe("ChatPlaygroundPage", () => {
beforeEach(() => {
jest.clearAllMocks();
Element.prototype.scrollIntoView = jest.fn();
mockClient.agents.list.mockResolvedValue({ data: mockAgents });
mockClient.models.list.mockResolvedValue(mockModels);
mockClient.toolgroups.list.mockResolvedValue(mockToolgroups);
mockClient.agents.session.create.mockResolvedValue({
session_id: "new-session-123",
});
mockClient.agents.session.list.mockResolvedValue({ data: [] });
mockClient.agents.session.retrieve.mockResolvedValue({
session_id: "test-session",
session_name: "Test Session",
started_at: new Date().toISOString(),
turns: [],
}); // No turns by default
mockClient.agents.retrieve.mockResolvedValue({
agent_id: "test-agent",
agent_config: {
toolgroups: ["builtin::rag"],
instructions: "Test instructions",
model: "test-model",
},
});
mockClient.agents.delete.mockResolvedValue(undefined);
});
describe("Agent Selector Rendering", () => {
test("shows agent selector when agents are available", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.getByText("Agent Session:")).toBeInTheDocument();
expect(screen.getAllByRole("combobox")).toHaveLength(2);
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
expect(screen.getByText("Clear Chat")).toBeInTheDocument();
});
});
test("does not show agent selector when no agents are available", async () => {
mockClient.agents.list.mockResolvedValue({ data: [] });
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument();
expect(screen.getAllByRole("combobox")).toHaveLength(1);
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument();
});
});
test("does not show agent selector while loading", async () => {
mockClient.agents.list.mockImplementation(() => new Promise(() => {}));
await act(async () => {
render(<ChatPlaygroundPage />);
});
expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument();
expect(screen.getAllByRole("combobox")).toHaveLength(1);
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument();
});
test("shows agent options in selector", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
const agentCombobox = screen.getAllByRole("combobox").find(element => {
return (
element.textContent?.includes("Test Agent") ||
element.textContent?.includes("Select Agent")
);
});
expect(agentCombobox).toBeDefined();
fireEvent.click(agentCombobox!);
});
await waitFor(() => {
expect(screen.getAllByText("Test Agent")).toHaveLength(2);
expect(screen.getByText("Another Agent")).toBeInTheDocument();
});
});
test("displays agent ID when no name is available", async () => {
const agentWithoutName = {
agent_id: "agent_789",
agent_config: {
instructions: "You are an agent without a name.",
},
};
mockClient.agents.list.mockResolvedValue({ data: [agentWithoutName] });
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
const agentCombobox = screen.getAllByRole("combobox").find(element => {
return (
element.textContent?.includes("Agent agent_78") ||
element.textContent?.includes("Select Agent")
);
});
expect(agentCombobox).toBeDefined();
fireEvent.click(agentCombobox!);
});
await waitFor(() => {
expect(screen.getAllByText("Agent agent_78...")).toHaveLength(2);
});
});
});
describe("Agent Creation Modal", () => {
test("opens agent creation modal when + New Agent is clicked", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
const newAgentButton = screen.getByText("+ New Agent");
fireEvent.click(newAgentButton);
expect(screen.getByText("Create New Agent")).toBeInTheDocument();
expect(screen.getByText("Agent Name (optional)")).toBeInTheDocument();
expect(screen.getAllByText("Model")).toHaveLength(2);
expect(screen.getByText("System Instructions")).toBeInTheDocument();
expect(screen.getByText("Tools (optional)")).toBeInTheDocument();
});
test("closes modal when Cancel is clicked", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
const newAgentButton = screen.getByText("+ New Agent");
fireEvent.click(newAgentButton);
const cancelButton = screen.getByText("Cancel");
fireEvent.click(cancelButton);
expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument();
});
test("creates agent when Create Agent is clicked", async () => {
mockClient.agents.create.mockResolvedValue({ agent_id: "new-agent-123" });
mockClient.agents.list
.mockResolvedValueOnce({ data: mockAgents })
.mockResolvedValueOnce({
data: [
...mockAgents,
{ agent_id: "new-agent-123", agent_config: { name: "New Agent" } },
],
});
await act(async () => {
render(<ChatPlaygroundPage />);
});
const newAgentButton = screen.getByText("+ New Agent");
await act(async () => {
fireEvent.click(newAgentButton);
});
await waitFor(() => {
expect(screen.getByText("Create New Agent")).toBeInTheDocument();
});
const nameInput = screen.getByPlaceholderText("My Custom Agent");
await act(async () => {
fireEvent.change(nameInput, { target: { value: "Test Agent Name" } });
});
const instructionsTextarea = screen.getByDisplayValue(
"You are a helpful assistant."
);
await act(async () => {
fireEvent.change(instructionsTextarea, {
target: { value: "Custom instructions" },
});
});
await waitFor(() => {
const modalModelSelectors = screen
.getAllByRole("combobox")
.filter(el => {
return (
el.textContent?.includes("Select Model") ||
el.closest('[class*="modal"]') ||
el.closest('[class*="card"]')
);
});
expect(modalModelSelectors.length).toBeGreaterThan(0);
});
const modalModelSelectors = screen.getAllByRole("combobox").filter(el => {
return (
el.textContent?.includes("Select Model") ||
el.closest('[class*="modal"]') ||
el.closest('[class*="card"]')
);
});
await act(async () => {
fireEvent.click(modalModelSelectors[0]);
});
await waitFor(() => {
const modelOptions = screen.getAllByText("test-model-1");
expect(modelOptions.length).toBeGreaterThan(0);
});
const modelOptions = screen.getAllByText("test-model-1");
const dropdownOption = modelOptions.find(
option =>
option.closest('[role="option"]') ||
option.id?.includes("radix") ||
option.getAttribute("aria-selected") !== null
);
await act(async () => {
fireEvent.click(
dropdownOption || modelOptions[modelOptions.length - 1]
);
});
await waitFor(() => {
const createButton = screen.getByText("Create Agent");
expect(createButton).not.toBeDisabled();
});
const createButton = screen.getByText("Create Agent");
await act(async () => {
fireEvent.click(createButton);
});
await waitFor(() => {
expect(mockClient.agents.create).toHaveBeenCalledWith({
agent_config: {
model: expect.any(String),
instructions: "Custom instructions",
name: "Test Agent Name",
enable_session_persistence: true,
},
});
});
await waitFor(() => {
expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument();
});
});
});
describe("Agent Selection", () => {
test("creates default session when agent is selected", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
// first agent should be auto-selected
expect(mockClient.agents.session.create).toHaveBeenCalledWith(
"agent_123",
{ session_name: "Default Session" }
);
});
});
test("switches agent when different agent is selected", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
const agentCombobox = screen.getAllByRole("combobox").find(element => {
return (
element.textContent?.includes("Test Agent") ||
element.textContent?.includes("Select Agent")
);
});
expect(agentCombobox).toBeDefined();
fireEvent.click(agentCombobox!);
});
await waitFor(() => {
const anotherAgentOption = screen.getByText("Another Agent");
fireEvent.click(anotherAgentOption);
});
expect(mockClient.agents.session.create).toHaveBeenCalledWith(
"agent_456",
{ session_name: "Default Session" }
);
});
});
describe("Agent Deletion", () => {
test("shows delete button when multiple agents exist", async () => {
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
});
});
test("hides delete button when only one agent exists", async () => {
mockClient.agents.list.mockResolvedValue({
data: [mockAgents[0]],
});
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(
screen.queryByTitle("Delete current agent")
).not.toBeInTheDocument();
});
});
test("deletes agent and switches to another when confirmed", async () => {
global.confirm = jest.fn(() => true);
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
});
mockClient.agents.delete.mockResolvedValue(undefined);
mockClient.agents.list.mockResolvedValueOnce({ data: mockAgents });
mockClient.agents.list.mockResolvedValueOnce({
data: [mockAgents[1]],
});
const deleteButton = screen.getByTitle("Delete current agent");
await act(async () => {
deleteButton.click();
});
await waitFor(() => {
expect(mockClient.agents.delete).toHaveBeenCalledWith("agent_123");
expect(global.confirm).toHaveBeenCalledWith(
"Are you sure you want to delete this agent? This action cannot be undone and will delete all associated sessions."
);
});
(global.confirm as jest.Mock).mockRestore();
});
test("does not delete agent when cancelled", async () => {
global.confirm = jest.fn(() => false);
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
});
const deleteButton = screen.getByTitle("Delete current agent");
await act(async () => {
deleteButton.click();
});
await waitFor(() => {
expect(global.confirm).toHaveBeenCalled();
expect(mockClient.agents.delete).not.toHaveBeenCalled();
});
(global.confirm as jest.Mock).mockRestore();
});
});
describe("Error Handling", () => {
test("handles agent loading errors gracefully", async () => {
mockClient.agents.list.mockRejectedValue(
new Error("Failed to load agents")
);
const consoleSpy = jest
.spyOn(console, "error")
.mockImplementation(() => {});
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(consoleSpy).toHaveBeenCalledWith(
"Error fetching agents:",
expect.any(Error)
);
});
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
consoleSpy.mockRestore();
});
test("handles model loading errors gracefully", async () => {
mockClient.models.list.mockRejectedValue(
new Error("Failed to load models")
);
const consoleSpy = jest
.spyOn(console, "error")
.mockImplementation(() => {});
await act(async () => {
render(<ChatPlaygroundPage />);
});
await waitFor(() => {
expect(consoleSpy).toHaveBeenCalledWith(
"Error fetching models:",
expect.any(Error)
);
});
consoleSpy.mockRestore();
});
});
});

File diff suppressed because it is too large Load diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 25 KiB

View file

@ -120,3 +120,44 @@
@apply bg-background text-foreground;
}
}
@layer utilities {
.animate-typing-dot-1 {
animation: typing-dot-bounce-1 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
}
.animate-typing-dot-2 {
animation: typing-dot-bounce-2 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
}
.animate-typing-dot-3 {
animation: typing-dot-bounce-3 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
}
@keyframes typing-dot-bounce-1 {
0%, 15%, 85%, 100% {
transform: translateY(0);
}
7.5% {
transform: translateY(-6px);
}
}
@keyframes typing-dot-bounce-2 {
0%, 15%, 35%, 85%, 100% {
transform: translateY(0);
}
25% {
transform: translateY(-6px);
}
}
@keyframes typing-dot-bounce-3 {
0%, 35%, 55%, 85%, 100% {
transform: translateY(0);
}
45% {
transform: translateY(-6px);
}
}
}

View file

@ -18,6 +18,9 @@ const geistMono = Geist_Mono({
export const metadata: Metadata = {
title: "Llama Stack",
description: "Llama Stack UI",
icons: {
icon: "/favicon.ico",
},
};
import { SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";

View file

@ -1,6 +1,6 @@
"use client";
import { useState, useEffect } from "react";
import { useState, useEffect, useCallback } from "react";
import { Button } from "@/components/ui/button";
import {
Select,
@ -13,14 +13,20 @@ import { Input } from "@/components/ui/input";
import { Card } from "@/components/ui/card";
import { Trash2 } from "lucide-react";
import type { Message } from "@/components/chat-playground/chat-message";
import { useAuthClient } from "@/hooks/use-auth-client";
import type {
Session,
SessionCreateParams,
} from "llama-stack-client/resources/agents";
interface ChatSession {
export interface ChatSession {
id: string;
name: string;
messages: Message[];
selectedModel: string;
selectedVectorDb: string;
systemMessage: string;
agentId: string;
session?: Session;
createdAt: number;
updatedAt: number;
}
@ -29,9 +35,9 @@ interface SessionManagerProps {
currentSession: ChatSession | null;
onSessionChange: (session: ChatSession) => void;
onNewSession: () => void;
selectedAgentId: string;
}
const SESSIONS_STORAGE_KEY = "chat-playground-sessions";
const CURRENT_SESSION_KEY = "chat-playground-current-session";
// ensures this only happens client side
@ -63,16 +69,6 @@ const safeLocalStorage = {
},
};
function safeJsonParse<T>(jsonString: string | null, fallback: T): T {
if (!jsonString) return fallback;
try {
return JSON.parse(jsonString) as T;
} catch (err) {
console.error("Error parsing JSON:", err);
return fallback;
}
}
const generateSessionId = (): string => {
return globalThis.crypto.randomUUID();
};
@ -80,60 +76,202 @@ const generateSessionId = (): string => {
export function SessionManager({
currentSession,
onSessionChange,
selectedAgentId,
}: SessionManagerProps) {
const [sessions, setSessions] = useState<ChatSession[]>([]);
const [showCreateForm, setShowCreateForm] = useState(false);
const [newSessionName, setNewSessionName] = useState("");
const [loading, setLoading] = useState(false);
const client = useAuthClient();
const loadAgentSessions = useCallback(async () => {
if (!selectedAgentId) return;
setLoading(true);
try {
const response = await client.agents.session.list(selectedAgentId);
console.log("Sessions response:", response);
if (!response.data || !Array.isArray(response.data)) {
console.warn("Invalid sessions response, starting fresh");
setSessions([]);
return;
}
const agentSessions: ChatSession[] = response.data
.filter(sessionData => {
const isValid =
sessionData &&
typeof sessionData === "object" &&
sessionData.session_id &&
sessionData.session_name;
if (!isValid) {
console.warn("Filtering out invalid session:", sessionData);
}
return isValid;
})
.map(sessionData => ({
id: sessionData.session_id,
name: sessionData.session_name,
messages: [],
selectedModel: currentSession?.selectedModel || "",
systemMessage:
currentSession?.systemMessage || "You are a helpful assistant.",
agentId: selectedAgentId,
session: sessionData,
createdAt: sessionData.started_at
? new Date(sessionData.started_at).getTime()
: Date.now(),
updatedAt: sessionData.started_at
? new Date(sessionData.started_at).getTime()
: Date.now(),
}));
setSessions(agentSessions);
} catch (error) {
console.error("Error loading agent sessions:", error);
setSessions([]);
} finally {
setLoading(false);
}
}, [
selectedAgentId,
client,
currentSession?.selectedModel,
currentSession?.systemMessage,
]);
useEffect(() => {
const savedSessions = safeLocalStorage.getItem(SESSIONS_STORAGE_KEY);
const sessions = safeJsonParse<ChatSession[]>(savedSessions, []);
setSessions(sessions);
}, []);
if (selectedAgentId) {
loadAgentSessions();
}
}, [selectedAgentId, loadAgentSessions]);
const saveSessions = (updatedSessions: ChatSession[]) => {
setSessions(updatedSessions);
safeLocalStorage.setItem(
SESSIONS_STORAGE_KEY,
JSON.stringify(updatedSessions)
);
};
const createNewSession = async () => {
if (!selectedAgentId) return;
const createNewSession = () => {
const sessionName =
newSessionName.trim() || `Session ${sessions.length + 1}`;
const newSession: ChatSession = {
id: generateSessionId(),
name: sessionName,
messages: [],
selectedModel: currentSession?.selectedModel || "",
selectedVectorDb: currentSession?.selectedVectorDb || "",
systemMessage:
currentSession?.systemMessage || "You are a helpful assistant.",
createdAt: Date.now(),
updatedAt: Date.now(),
};
setLoading(true);
const updatedSessions = [...sessions, newSession];
saveSessions(updatedSessions);
try {
const response = await client.agents.session.create(selectedAgentId, {
session_name: sessionName,
} as SessionCreateParams);
safeLocalStorage.setItem(CURRENT_SESSION_KEY, newSession.id);
onSessionChange(newSession);
const newSession: ChatSession = {
id: response.session_id,
name: sessionName,
messages: [],
selectedModel: currentSession?.selectedModel || "",
systemMessage:
currentSession?.systemMessage || "You are a helpful assistant.",
agentId: selectedAgentId,
createdAt: Date.now(),
updatedAt: Date.now(),
};
setNewSessionName("");
setShowCreateForm(false);
};
setSessions(prev => [...prev, newSession]);
SessionUtils.saveCurrentSessionId(newSession.id, selectedAgentId);
onSessionChange(newSession);
const switchToSession = (sessionId: string) => {
const session = sessions.find(s => s.id === sessionId);
if (session) {
safeLocalStorage.setItem(CURRENT_SESSION_KEY, sessionId);
onSessionChange(session);
setNewSessionName("");
setShowCreateForm(false);
} catch (error) {
console.error("Error creating session:", error);
} finally {
setLoading(false);
}
};
const deleteSession = (sessionId: string) => {
if (sessions.length <= 1) {
const loadSessionMessages = useCallback(
async (agentId: string, sessionId: string): Promise<Message[]> => {
try {
const session = await client.agents.session.retrieve(
agentId,
sessionId
);
if (!session || !session.turns || !Array.isArray(session.turns)) {
return [];
}
const messages: Message[] = [];
for (const turn of session.turns) {
// Add user messages from input_messages
if (turn.input_messages && Array.isArray(turn.input_messages)) {
for (const input of turn.input_messages) {
if (input.role === "user" && input.content) {
messages.push({
id: `${turn.turn_id}-user-${messages.length}`,
role: "user",
content:
typeof input.content === "string"
? input.content
: JSON.stringify(input.content),
createdAt: new Date(turn.started_at || Date.now()),
});
}
}
}
// Add assistant message from output_message
if (turn.output_message && turn.output_message.content) {
messages.push({
id: `${turn.turn_id}-assistant-${messages.length}`,
role: "assistant",
content:
typeof turn.output_message.content === "string"
? turn.output_message.content
: JSON.stringify(turn.output_message.content),
createdAt: new Date(
turn.completed_at || turn.started_at || Date.now()
),
});
}
}
return messages;
} catch (error) {
console.error("Error loading session messages:", error);
return [];
}
},
[client]
);
const switchToSession = useCallback(
async (sessionId: string) => {
const session = sessions.find(s => s.id === sessionId);
if (session) {
setLoading(true);
try {
// Load messages for this session
const messages = await loadSessionMessages(
selectedAgentId,
sessionId
);
const sessionWithMessages = {
...session,
messages,
};
SessionUtils.saveCurrentSessionId(sessionId, selectedAgentId);
onSessionChange(sessionWithMessages);
} catch (error) {
console.error("Error switching to session:", error);
// Fallback to session without messages
SessionUtils.saveCurrentSessionId(sessionId, selectedAgentId);
onSessionChange(session);
} finally {
setLoading(false);
}
}
},
[sessions, selectedAgentId, loadSessionMessages, onSessionChange]
);
const deleteSession = async (sessionId: string) => {
if (sessions.length <= 1 || !selectedAgentId) {
return;
}
@ -142,21 +280,30 @@ export function SessionManager({
"Are you sure you want to delete this session? This action cannot be undone."
)
) {
const updatedSessions = sessions.filter(s => s.id !== sessionId);
saveSessions(updatedSessions);
setLoading(true);
try {
await client.agents.session.delete(selectedAgentId, sessionId);
if (currentSession?.id === sessionId) {
const newCurrentSession = updatedSessions[0] || null;
if (newCurrentSession) {
safeLocalStorage.setItem(CURRENT_SESSION_KEY, newCurrentSession.id);
onSessionChange(newCurrentSession);
} else {
safeLocalStorage.removeItem(CURRENT_SESSION_KEY);
const defaultSession = SessionUtils.createDefaultSession();
saveSessions([defaultSession]);
safeLocalStorage.setItem(CURRENT_SESSION_KEY, defaultSession.id);
onSessionChange(defaultSession);
const updatedSessions = sessions.filter(s => s.id !== sessionId);
setSessions(updatedSessions);
if (currentSession?.id === sessionId) {
const newCurrentSession = updatedSessions[0] || null;
if (newCurrentSession) {
SessionUtils.saveCurrentSessionId(
newCurrentSession.id,
selectedAgentId
);
onSessionChange(newCurrentSession);
} else {
SessionUtils.clearCurrentSession(selectedAgentId);
onNewSession();
}
}
} catch (error) {
console.error("Error deleting session:", error);
} finally {
setLoading(false);
}
}
};
@ -172,16 +319,16 @@ export function SessionManager({
updatedSessions.push(currentSession);
}
safeLocalStorage.setItem(
SESSIONS_STORAGE_KEY,
JSON.stringify(updatedSessions)
);
return updatedSessions;
});
}
}, [currentSession]);
// Don't render if no agent is selected
if (!selectedAgentId) {
return null;
}
return (
<div className="relative">
<div className="flex items-center gap-2">
@ -205,6 +352,7 @@ export function SessionManager({
onClick={() => setShowCreateForm(true)}
variant="outline"
size="sm"
disabled={loading || !selectedAgentId}
>
+ New
</Button>
@ -241,8 +389,12 @@ export function SessionManager({
/>
<div className="flex gap-2">
<Button onClick={createNewSession} className="flex-1">
Create
<Button
onClick={createNewSession}
className="flex-1"
disabled={loading}
>
{loading ? "Creating..." : "Create"}
</Button>
<Button
variant="outline"
@ -270,72 +422,147 @@ export function SessionManager({
}
export const SessionUtils = {
loadCurrentSession: (): ChatSession | null => {
const currentSessionId = safeLocalStorage.getItem(CURRENT_SESSION_KEY);
const savedSessions = safeLocalStorage.getItem(SESSIONS_STORAGE_KEY);
if (currentSessionId && savedSessions) {
const sessions = safeJsonParse<ChatSession[]>(savedSessions, []);
return sessions.find(s => s.id === currentSessionId) || null;
}
return null;
loadCurrentSessionId: (agentId?: string): string | null => {
const key = agentId
? `${CURRENT_SESSION_KEY}-${agentId}`
: CURRENT_SESSION_KEY;
return safeLocalStorage.getItem(key);
},
saveCurrentSession: (session: ChatSession) => {
const savedSessions = safeLocalStorage.getItem(SESSIONS_STORAGE_KEY);
const sessions = safeJsonParse<ChatSession[]>(savedSessions, []);
const existingIndex = sessions.findIndex(s => s.id === session.id);
if (existingIndex >= 0) {
sessions[existingIndex] = { ...session, updatedAt: Date.now() };
} else {
sessions.push({
...session,
createdAt: Date.now(),
updatedAt: Date.now(),
});
}
safeLocalStorage.setItem(SESSIONS_STORAGE_KEY, JSON.stringify(sessions));
safeLocalStorage.setItem(CURRENT_SESSION_KEY, session.id);
saveCurrentSessionId: (sessionId: string, agentId?: string) => {
const key = agentId
? `${CURRENT_SESSION_KEY}-${agentId}`
: CURRENT_SESSION_KEY;
safeLocalStorage.setItem(key, sessionId);
},
createDefaultSession: (
inheritModel?: string,
inheritVectorDb?: string
agentId: string,
inheritModel?: string
): ChatSession => ({
id: generateSessionId(),
name: "Default Session",
messages: [],
selectedModel: inheritModel || "",
selectedVectorDb: inheritVectorDb || "",
systemMessage: "You are a helpful assistant.",
agentId,
createdAt: Date.now(),
updatedAt: Date.now(),
}),
deleteSession: (
sessionId: string
): {
deletedSession: ChatSession | null;
remainingSessions: ChatSession[];
} => {
const savedSessions = safeLocalStorage.getItem(SESSIONS_STORAGE_KEY);
const sessions = safeJsonParse<ChatSession[]>(savedSessions, []);
clearCurrentSession: (agentId?: string) => {
const key = agentId
? `${CURRENT_SESSION_KEY}-${agentId}`
: CURRENT_SESSION_KEY;
safeLocalStorage.removeItem(key);
},
const sessionToDelete = sessions.find(s => s.id === sessionId);
const remainingSessions = sessions.filter(s => s.id !== sessionId);
loadCurrentAgentId: (): string | null => {
return safeLocalStorage.getItem("chat-playground-current-agent");
},
saveCurrentAgentId: (agentId: string) => {
safeLocalStorage.setItem("chat-playground-current-agent", agentId);
},
// Comprehensive session caching
saveSessionData: (agentId: string, sessionData: ChatSession) => {
const key = `chat-playground-session-data-${agentId}-${sessionData.id}`;
safeLocalStorage.setItem(
SESSIONS_STORAGE_KEY,
JSON.stringify(remainingSessions)
key,
JSON.stringify({
...sessionData,
cachedAt: Date.now(),
})
);
},
const currentSessionId = safeLocalStorage.getItem(CURRENT_SESSION_KEY);
if (currentSessionId === sessionId) {
safeLocalStorage.removeItem(CURRENT_SESSION_KEY);
loadSessionData: (agentId: string, sessionId: string): ChatSession | null => {
const key = `chat-playground-session-data-${agentId}-${sessionId}`;
const cached = safeLocalStorage.getItem(key);
if (!cached) return null;
try {
const data = JSON.parse(cached);
// Check if cache is fresh (less than 1 hour old)
const cacheAge = Date.now() - (data.cachedAt || 0);
if (cacheAge > 60 * 60 * 1000) {
safeLocalStorage.removeItem(key);
return null;
}
// Convert date strings back to Date objects
return {
...data,
messages: data.messages.map(
(msg: { createdAt: string; [key: string]: unknown }) => ({
...msg,
createdAt: new Date(msg.createdAt),
})
),
};
} catch (error) {
console.error("Error parsing cached session data:", error);
safeLocalStorage.removeItem(key);
return null;
}
},
return { deletedSession: sessionToDelete || null, remainingSessions };
// Agent config caching
saveAgentConfig: (
agentId: string,
config: {
toolgroups?: Array<
string | { name: string; args: Record<string, unknown> }
>;
[key: string]: unknown;
}
) => {
const key = `chat-playground-agent-config-${agentId}`;
safeLocalStorage.setItem(
key,
JSON.stringify({
config,
cachedAt: Date.now(),
})
);
},
loadAgentConfig: (
agentId: string
): {
toolgroups?: Array<
string | { name: string; args: Record<string, unknown> }
>;
[key: string]: unknown;
} | null => {
const key = `chat-playground-agent-config-${agentId}`;
const cached = safeLocalStorage.getItem(key);
if (!cached) return null;
try {
const data = JSON.parse(cached);
// Check if cache is fresh (less than 30 minutes old)
const cacheAge = Date.now() - (data.cachedAt || 0);
if (cacheAge > 30 * 60 * 1000) {
safeLocalStorage.removeItem(key);
return null;
}
return data.config;
} catch (error) {
console.error("Error parsing cached agent config:", error);
safeLocalStorage.removeItem(key);
return null;
}
},
// Clear all cached data for an agent
clearAgentCache: (agentId: string) => {
const keys = Object.keys(localStorage).filter(
key =>
key.includes(`chat-playground-session-data-${agentId}`) ||
key.includes(`chat-playground-agent-config-${agentId}`)
);
keys.forEach(key => safeLocalStorage.removeItem(key));
},
};

View file

@ -5,9 +5,9 @@ export function TypingIndicator() {
<div className="justify-left flex space-x-1">
<div className="rounded-lg bg-muted p-3">
<div className="flex -space-x-2.5">
<Dot className="h-5 w-5 animate-typing-dot-bounce" />
<Dot className="h-5 w-5 animate-typing-dot-bounce [animation-delay:90ms]" />
<Dot className="h-5 w-5 animate-typing-dot-bounce [animation-delay:180ms]" />
<Dot className="h-5 w-5 animate-typing-dot-1" />
<Dot className="h-5 w-5 animate-typing-dot-2" />
<Dot className="h-5 w-5 animate-typing-dot-3" />
</div>
</div>
</div>

View file

@ -11,6 +11,7 @@ import {
} from "lucide-react";
import Link from "next/link";
import { usePathname } from "next/navigation";
import Image from "next/image";
import { cn } from "@/lib/utils";
import {
@ -110,7 +111,16 @@ export function AppSidebar() {
return (
<Sidebar>
<SidebarHeader>
<Link href="/">Llama Stack</Link>
<Link href="/" className="flex items-center gap-2 p-2">
<Image
src="/logo.webp"
alt="Llama Stack"
width={32}
height={32}
className="h-8 w-8"
/>
<span className="font-semibold text-lg">Llama Stack</span>
</Link>
</SidebarHeader>
<SidebarContent>
<SidebarGroup>

View file

@ -18,7 +18,7 @@
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"framer-motion": "^11.18.2",
"llama-stack-client": "0.2.17",
"llama-stack-client": "^0.2.18",
"lucide-react": "^0.510.0",
"next": "15.3.3",
"next-auth": "^4.24.11",
@ -9926,9 +9926,9 @@
"license": "MIT"
},
"node_modules/llama-stack-client": {
"version": "0.2.17",
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.17.tgz",
"integrity": "sha512-+/fEO8M7XPiVLjhH7ge18i1ijKp4+h3dOkE0C8g2cvGuDUtDYIJlf8NSyr9OMByjiWpCibWU7VOKL50LwGLS3Q==",
"version": "0.2.18",
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.18.tgz",
"integrity": "sha512-k+xQOz/TIU0cINP4Aih8q6xs7f/6qs0fLDMXTTKQr5C0F1jtCjRiwsas7bTsDfpKfYhg/7Xy/wPw/uZgi6aIVg==",
"license": "MIT",
"dependencies": {
"@types/node": "^18.11.18",

View file

@ -23,7 +23,7 @@
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"framer-motion": "^11.18.2",
"llama-stack-client": "^0.2.17",
"llama-stack-client": "^0.2.18",
"lucide-react": "^0.510.0",
"next": "15.3.3",
"next-auth": "^4.24.11",

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB