Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-17 11:22:35 +00:00)
adding logo and favicon
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

chore: Enable keyword search for Milvus inline (#3073)

With https://github.com/milvus-io/milvus-lite/pull/294, Milvus Lite supports keyword search using BM25. When keyword search was first introduced, it was explicitly disabled for inline Milvus. This PR removes the need for that check and enables `inline::milvus` for tests.

Run llama stack with `inline::milvus` enabled:

```
pytest tests/integration/vector_io/test_openai_vector_stores.py::test_openai_vector_store_search_modes --stack-config=http://localhost:8321 --embedding-model=all-MiniLM-L6-v2 -v
```

```
INFO 2025-08-07 17:06:20,932 tests.integration.conftest:64 tests: Setting DISABLE_CODE_SANDBOX=1 for macOS
=========================================================================================== test session starts ============================================================================================
platform darwin -- Python 3.12.11, pytest-7.4.4, pluggy-1.5.0 -- /Users/vnarsing/miniconda3/envs/stack-client/bin/python
cachedir: .pytest_cache
metadata: {'Python': '3.12.11', 'Platform': 'macOS-14.7.6-arm64-arm-64bit', 'Packages': {'pytest': '7.4.4', 'pluggy': '1.5.0'}, 'Plugins': {'asyncio': '0.23.8', 'cov': '6.0.0', 'timeout': '2.2.0', 'socket': '0.7.0', 'html': '3.1.1', 'langsmith': '0.3.39', 'anyio': '4.8.0', 'metadata': '3.0.0'}}
rootdir: /Users/vnarsing/go/src/github/meta-llama/llama-stack
configfile: pyproject.toml
plugins: asyncio-0.23.8, cov-6.0.0, timeout-2.2.0, socket-0.7.0, html-3.1.1, langsmith-0.3.39, anyio-4.8.0, metadata-3.0.0
asyncio: mode=Mode.AUTO
collected 3 items

tests/integration/vector_io/test_openai_vector_stores.py::test_openai_vector_store_search_modes[None-None-all-MiniLM-L6-v2-None-384-vector] PASSED    [ 33%]
tests/integration/vector_io/test_openai_vector_stores.py::test_openai_vector_store_search_modes[None-None-all-MiniLM-L6-v2-None-384-keyword] PASSED   [ 66%]
tests/integration/vector_io/test_openai_vector_stores.py::test_openai_vector_store_search_modes[None-None-all-MiniLM-L6-v2-None-384-hybrid] PASSED    [100%]

============================================================================================ 3 passed in 4.75s =============================================================================================
```

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
Co-authored-by: Francisco Arceo <arceofrancisco@gmail.com>

chore: Fixup main pre commit (#3204)

build: Bump version to 0.2.18

chore: Faster npm pre-commit (#3206)

Adds npm installation and UI caching to pre-commit.yml; removes Node installation during the pre-commit run.

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

checking in for tonight, WIP moving to agents api

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

remove log

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

updated

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

fix: disable ui-prettier & ui-eslint (#3207)

chore(pre-commit): add pre-commit hook to enforce llama_stack logger usage (#3061)

This PR adds a pre-commit step that enforces use of the `llama_stack` logger. Various parts of the codebase currently use different loggers; since a custom `llama_stack` logger already exists and is used in the codebase, it is better to standardize on it.

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Co-authored-by: Matthew Farrellee <matt@cs.wisc.edu>

fix: fix `openai_embeddings` for asymmetric embedding NIMs (#3205)

NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. This PR adds `input_type="query"` as the default and updates the documentation to suggest using the `embedding` API for passage embeddings.

Resolves #2892

```
pytest -s -v tests/integration/inference/test_openai_embeddings.py --stack-config="inference=nvidia" --embedding-model="nvidia/llama-3.2-nv-embedqa-1b-v2" --env NVIDIA_API_KEY={nvidia_api_key} --env NVIDIA_BASE_URL="https://integrate.api.nvidia.com"
```

cleaning up

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

updating session manager to cache messages locally

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

fix linter

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

more cleanup

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
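For reference, the migration pattern that the new check-log-usage hook enforces throughout the diff below is a minimal sketch like this (the `category` value varies per subsystem, e.g. "core", "inference", "vector_io"):

```python
# Old pattern, now flagged by the check-log-usage pre-commit hook:
# import logging
# logger = logging.getLogger(__name__)

# New pattern, using the custom llama_stack logger:
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="core")
logger.info("standardized logging via llama_stack.log")

# Files that genuinely need the stdlib logger opt out with a trailing marker:
import logging  # allow-direct-logging
```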
This commit is contained in: parent e7be568d7e, commit 6620b625f1
76 changed files with 2343 additions and 1187 deletions
15  .github/workflows/pre-commit.yml  vendored
@@ -36,6 +36,21 @@ jobs:
             **/requirements*.txt
             .pre-commit-config.yaml

+      # npm ci may fail -
+      # npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing.
+      # npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18
+
+      # - name: Set up Node.js
+      #   uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
+      #   with:
+      #     node-version: '20'
+      #     cache: 'npm'
+      #     cache-dependency-path: 'llama_stack/ui/'
+
+      # - name: Install npm dependencies
+      #   run: npm ci
+      #   working-directory: llama_stack/ui
+
       - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
         continue-on-error: true
         env:
@@ -146,20 +146,50 @@ repos:
         pass_filenames: false
         require_serial: true
         files: ^.github/workflows/.*$
-      - id: ui-prettier
-        name: Format UI code with Prettier
-        entry: bash -c 'cd llama_stack/ui && npm run format'
-        language: system
-        files: ^llama_stack/ui/.*\.(ts|tsx)$
-        pass_filenames: false
-        require_serial: true
-      - id: ui-eslint
-        name: Lint UI code with ESLint
-        entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet'
-        language: system
-        files: ^llama_stack/ui/.*\.(ts|tsx)$
-        pass_filenames: false
-        require_serial: true
+      # ui-prettier and ui-eslint are disabled until we can avoid `npm ci`, which is slow and may fail -
+      # npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing.
+      # npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18
+      # and until we have infra for installing prettier and next via npm -
+      # Lint UI code with ESLint.....................................................Failed
+      # - hook id: ui-eslint
+      # - exit code: 127
+      # > ui@0.1.0 lint
+      # > next lint --fix --quiet
+      # sh: line 1: next: command not found
+      #
+      # - id: ui-prettier
+      #   name: Format UI code with Prettier
+      #   entry: bash -c 'cd llama_stack/ui && npm ci && npm run format'
+      #   language: system
+      #   files: ^llama_stack/ui/.*\.(ts|tsx)$
+      #   pass_filenames: false
+      #   require_serial: true
+      # - id: ui-eslint
+      #   name: Lint UI code with ESLint
+      #   entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet'
+      #   language: system
+      #   files: ^llama_stack/ui/.*\.(ts|tsx)$
+      #   pass_filenames: false
+      #   require_serial: true
+
+      - id: check-log-usage
+        name: Ensure 'llama_stack.log' usage for logging
+        entry: bash
+        language: system
+        types: [python]
+        pass_filenames: true
+        args:
+          - -c
+          - |
+            matches=$(grep -EnH '^[^#]*\b(import\s+logging|from\s+logging\b)' "$@" | grep -v -e '#\s*allow-direct-logging' || true)
+            if [ -n "$matches" ]; then
+              # GitHub Actions annotation format
+              while IFS=: read -r file line_num rest; do
+                echo "::error file=$file,line=$line_num::Do not use 'import logging' or 'from logging import' in $file. Use the custom log instead: from llama_stack.log import get_logger; logger = get_logger(). If direct logging is truly needed, add: # allow-direct-logging"
+              done <<< "$matches"
+              exit 1
+            fi
+            exit 0

 ci:
   autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
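To make the hook's behavior concrete, here is a small sketch of what it flags versus permits (the grep in the hook above is the authoritative rule):

```python
import logging                 # flagged: bare stdlib logging import
from logging import getLogger  # flagged: direct 'from logging' import

import logging  # allow-direct-logging
# the line above is permitted: the opt-out marker sits on the same line
```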
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import importlib.resources
-import logging
 import sys

 from pydantic import BaseModel
@@ -17,9 +16,10 @@ from llama_stack.core.external import load_external_apis
 from llama_stack.core.utils.exec import run_command
 from llama_stack.core.utils.image_types import LlamaStackImageType
 from llama_stack.distributions.template import DistributionTemplate
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="core")

 # These are the dependencies needed by the distribution server.
 # `llama-stack` is automatically installed by the installation script.
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import logging
 import textwrap
 from typing import Any

@@ -21,9 +20,10 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
 from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
 from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.core.utils.prompt_for_config import prompt_for_config
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, ProviderSpec

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="core")


 def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
@@ -7,7 +7,7 @@
 import asyncio
 import inspect
 import json
-import logging
+import logging  # allow-direct-logging
 import os
 import sys
 from concurrent.futures import ThreadPoolExecutor
@@ -48,6 +48,7 @@ from llama_stack.core.stack import (
 from llama_stack.core.utils.config import redact_sensitive_fields
 from llama_stack.core.utils.context import preserve_contexts_async_generator
 from llama_stack.core.utils.exec import in_notebook
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.telemetry.tracing import (
     CURRENT_TRACE_CONTEXT,
     end_trace,
@@ -55,7 +56,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
     start_trace,
 )

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="core")

 T = TypeVar("T")
@@ -6,15 +6,15 @@

 import contextvars
 import json
-import logging
 from contextlib import AbstractContextManager
 from typing import Any

 from llama_stack.core.datatypes import User
+from llama_stack.log import get_logger

 from .utils.dynamic import instantiate_class_type

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="core")

 # Context variable for request provider data and auth attributes
 PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
@@ -9,7 +9,7 @@ import asyncio
 import functools
 import inspect
 import json
-import logging
+import logging  # allow-direct-logging
 import os
 import ssl
 import sys
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
+import importlib
 import os
 import signal
 import subprocess
@@ -12,9 +12,9 @@ import sys

 from termcolor import cprint

-log = logging.getLogger(__name__)
+from llama_stack.log import get_logger

-import importlib
+log = get_logger(name=__name__, category="core")


 def formulate_run_args(image_type: str, image_name: str) -> list:
@@ -6,7 +6,6 @@

 import inspect
 import json
-import logging
 from enum import Enum
 from typing import Annotated, Any, Literal, Union, get_args, get_origin

@@ -14,7 +13,9 @@ from pydantic import BaseModel
 from pydantic.fields import FieldInfo
 from pydantic_core import PydanticUndefinedType

-log = logging.getLogger(__name__)
+from llama_stack.log import get_logger
+
+log = get_logger(name=__name__, category="core")


 def is_list_of_primitives(field_type):
@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
+import logging  # allow-direct-logging
 import os
 import re
-from logging.config import dictConfig
+from logging.config import dictConfig  # allow-direct-logging

 from rich.console import Console
 from rich.errors import MarkupError
@@ -13,14 +13,15 @@

 # Copyright (c) Meta Platforms, Inc. and its affiliates.
 import math
-from logging import getLogger

 import torch
 import torch.nn.functional as F

+from llama_stack.log import get_logger
+
 from .utils import get_negative_inf_value, to_2tuple

-logger = getLogger()
+logger = get_logger(name=__name__, category="models::llama")


 def resize_local_position_embedding(orig_pos_embed, grid_size):
@@ -13,7 +13,6 @@

 import math
 from collections import defaultdict
-from logging import getLogger
 from typing import Any

 import torch
@@ -21,9 +20,11 @@ import torchvision.transforms as tv
 from PIL import Image
 from torchvision.transforms import functional as F

+from llama_stack.log import get_logger
+
 IMAGE_RES = 224

-logger = getLogger()
+logger = get_logger(name=__name__, category="models::llama")


 class VariableSizeImageTransform:
@@ -3,8 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
-import logging
 import math
 from collections.abc import Callable
 from functools import partial
@@ -22,6 +20,8 @@ from PIL import Image as PIL_Image
 from torch import Tensor, nn
 from torch.distributed import _functional_collectives as funcol

+from llama_stack.log import get_logger
+
 from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis
 from .encoder_utils import (
     build_encoder_attention_mask,
@@ -34,9 +34,10 @@ from .encoder_utils import (
 from .image_transform import VariableSizeImageTransform
 from .utils import get_negative_inf_value, to_2tuple

-logger = logging.getLogger(__name__)
 MP_SCALE = 8

+logger = get_logger(name=__name__, category="models")
+

 def reduce_from_tensor_model_parallel_region(input_):
     """All-reduce the input tensor across model parallel group."""
@@ -771,7 +772,7 @@ class TilePositionEmbedding(nn.Module):
         if embed is not None:
             # reshape the weights to the correct shape
             nt_old, nt_old, _, w = embed.shape
-            logging.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
+            logger.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
             embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles)
             # assign the weights to the module
             state_dict[prefix + "embedding"] = embed_new
@@ -4,8 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+
 from collections.abc import Collection, Iterator, Sequence, Set
-from logging import getLogger
 from pathlib import Path
 from typing import (
     Literal,
@@ -14,11 +14,9 @@ from typing import (

 import tiktoken

+from llama_stack.log import get_logger
 from llama_stack.models.llama.tokenizer_utils import load_bpe_file

-logger = getLogger(__name__)
-
-
 # The tiktoken tokenizer can handle <=400k chars without
 # pyo3_runtime.PanicException.
 TIKTOKEN_MAX_ENCODE_CHARS = 400_000
@@ -31,6 +29,8 @@ MAX_NO_WHITESPACES_CHARS = 25_000

 _INSTANCE = None

+logger = get_logger(name=__name__, category="models::llama")
+

 class Tokenizer:
     """
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import os
 from collections.abc import Callable

@@ -13,11 +12,13 @@ from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
 from torch import Tensor, nn
 from torch.nn import functional as F

+from llama_stack.log import get_logger
+
 from ...datatypes import QuantizationMode
 from ..model import Transformer, TransformerBlock
 from ..moe import MoE

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="models")


 def swiglu_wrapper_no_reduce(
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 from collections.abc import Collection, Iterator, Sequence, Set
-from logging import getLogger
 from pathlib import Path
 from typing import (
     Literal,
@@ -14,11 +13,9 @@ from typing import (

 import tiktoken

+from llama_stack.log import get_logger
 from llama_stack.models.llama.tokenizer_utils import load_bpe_file

-logger = getLogger(__name__)
-
-
 # The tiktoken tokenizer can handle <=400k chars without
 # pyo3_runtime.PanicException.
 TIKTOKEN_MAX_ENCODE_CHARS = 400_000
@@ -101,6 +98,8 @@ BASIC_SPECIAL_TOKENS = [
     "<|fim_suffix|>",
 ]

+logger = get_logger(name=__name__, category="models::llama")
+

 class Tokenizer:
     """
@@ -6,9 +6,10 @@

 # type: ignore
 import collections
-import logging

-log = logging.getLogger(__name__)
+from llama_stack.log import get_logger
+
+log = get_logger(name=__name__, category="llama")

 try:
     import fbgemm_gpu.experimental.gen_ai  # noqa: F401
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import uuid
 from collections.abc import AsyncGenerator
 from datetime import UTC, datetime

@@ -42,6 +41,7 @@ from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.datatypes import AccessRule
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
 from llama_stack.providers.utils.pagination import paginate_records
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore

@@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
 from .persistence import AgentInfo
 from .responses.openai_responses import OpenAIResponsesImpl

-logger = logging.getLogger()
+logger = get_logger(name=__name__, category="agents")


 class MetaReferenceAgentsImpl(Agents):
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import json
-import logging
 import uuid
 from datetime import UTC, datetime

@@ -15,9 +14,10 @@ from llama_stack.core.access_control.access_control import AccessDeniedError, is
 from llama_stack.core.access_control.datatypes import AccessRule
 from llama_stack.core.datatypes import User
 from llama_stack.core.request_headers import get_authenticated_user
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import KVStore

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="agents")


 class AgentSessionInfo(Session):
@@ -5,13 +5,13 @@
 # the root directory of this source tree.

 import asyncio
-import logging

 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.telemetry import tracing

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="agents")


 class SafetyException(Exception):  # noqa: N818
@@ -12,7 +12,6 @@

 import copy
 import json
-import logging
 import multiprocessing
 import os
 import tempfile

@@ -32,13 +31,14 @@ from fairscale.nn.model_parallel.initialize import (
 from pydantic import BaseModel, Field
 from torch.distributed.launcher.api import LaunchConfig, elastic_launch

+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import GenerationResult
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
 )

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="inference")


 class ProcessingMessageName(str, Enum):
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 from collections.abc import AsyncGenerator

 from llama_stack.apis.inference import (
@@ -21,6 +20,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import ModelType
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
@@ -32,7 +32,7 @@ from llama_stack.providers.utils.inference.openai_compat import (

 from .config import SentenceTransformersInferenceConfig

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="inference")


 class SentenceTransformersInferenceImpl(
@@ -6,7 +6,6 @@

 import gc
 import json
-import logging
 import multiprocessing
 from pathlib import Path
 from typing import Any

@@ -28,6 +27,7 @@ from llama_stack.apis.post_training import (
     LoraFinetuningConfig,
     TrainingConfig,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device

 from ..config import HuggingFacePostTrainingConfig

@@ -44,7 +44,7 @@ from ..utils import (
     split_dataset,
 )

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="post_training")


 class HFFinetuningSingleDevice:
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import gc
-import logging
 import multiprocessing
 from pathlib import Path
 from typing import Any

@@ -24,6 +23,7 @@ from llama_stack.apis.post_training import (
     DPOAlignmentConfig,
     TrainingConfig,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device

 from ..config import HuggingFacePostTrainingConfig

@@ -40,7 +40,7 @@ from ..utils import (
     split_dataset,
 )

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="post_training")


 class HFDPOAlignmentSingleDevice:
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import os
 import signal
 import sys

@@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM

 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.post_training import Checkpoint, TrainingConfig
+from llama_stack.log import get_logger

 from .config import HuggingFacePostTrainingConfig

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="post_training")


 def setup_environment():
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import os
 import time
 from datetime import UTC, datetime

@@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import modules, training
 from torchtune import utils as torchtune_utils
 from torchtune.data import padded_collate_sft
+from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
     get_adapter_params,

@@ -45,6 +45,7 @@ from llama_stack.apis.post_training import (
 )
 from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
 from llama_stack.core.utils.model_utils import model_local_dir
+from llama_stack.log import get_logger
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 from llama_stack.providers.inline.post_training.torchtune.common import utils

@@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset

-log = logging.getLogger(__name__)
-
-from torchtune.models.llama3._tokenizer import Llama3Tokenizer
+log = get_logger(name=__name__, category="post_training")


 class LoraFinetuningSingleDevice:
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import uuid
 from typing import TYPE_CHECKING, Any

@@ -20,13 +19,14 @@ from llama_stack.apis.safety import (
 )
 from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
 from llama_stack.apis.shields import Shield
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )

 from .config import CodeScannerConfig

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="safety")

 ALLOWED_CODE_SCANNER_MODEL_IDS = [
     "code-scanner",
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import re
 import uuid
 from string import Template

@@ -21,6 +20,7 @@ from llama_stack.apis.safety import (
 from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
 from llama_stack.apis.shields import Shield
 from llama_stack.core.datatypes import Api
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import Role
 from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate

@@ -132,6 +132,8 @@ Provide your safety assessment for ONLY THE LAST $agent_type message in the abov

 PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}")

+logger = get_logger(name=__name__, category="safety")
+

 class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
     def __init__(self, config: LlamaGuardConfig, deps) -> None:

@@ -407,7 +409,7 @@ class LlamaGuardShield:
         unsafe_code_list = [code.strip() for code in unsafe_code.split(",")]
         invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP]
         if invalid_codes:
-            logging.warning(f"Invalid safety codes returned: {invalid_codes}")
+            logger.warning(f"Invalid safety codes returned: {invalid_codes}")
             # just returning safe object, as we don't know what the invalid codes can map to
             return ModerationObject(
                 id=f"modr-{uuid.uuid4()}",
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 from typing import Any

 import torch

@@ -21,6 +20,7 @@ from llama_stack.apis.safety import (
 from llama_stack.apis.safety.safety import ModerationObject
 from llama_stack.apis.shields import Shield
 from llama_stack.core.utils.model_utils import model_local_dir
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,

@@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import PromptGuardConfig, PromptGuardType

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="safety")

 PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
@@ -7,7 +7,6 @@
 import collections
 import functools
 import json
-import logging
 import random
 import re
 import string

@@ -20,7 +19,9 @@ import nltk
 from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai
 from pythainlp.tokenize import word_tokenize as word_tokenize_thai

-logger = logging.getLogger()
+from llama_stack.log import get_logger
+
+logger = get_logger(name=__name__, category="scoring")

 WORD_LIST = [
     "western",
@@ -4,13 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import threading
 from typing import Any

 from opentelemetry import metrics, trace
-
-logger = logging.getLogger(__name__)
 from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
 from opentelemetry.sdk.metrics import MeterProvider

@@ -40,6 +37,7 @@ from llama_stack.apis.telemetry import (
     UnstructuredLogEvent,
 )
 from llama_stack.core.datatypes import Api
+from llama_stack.log import get_logger
 from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
     ConsoleSpanProcessor,
 )

@@ -61,6 +59,8 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
 _global_lock = threading.Lock()
 _TRACER_PROVIDER = None

+logger = get_logger(name=__name__, category="telemetry")
+

 def is_tracing_enabled(tracer):
     with tracer.start_as_current_span("check_tracing") as span:
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import asyncio
-import logging
 import secrets
 import string
 from typing import Any

@@ -32,6 +31,7 @@ from llama_stack.apis.tools import (
     ToolRuntime,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
 from llama_stack.providers.utils.memory.vector_store import (

@@ -42,7 +42,7 @@ from llama_stack.providers.utils.memory.vector_store import (
 from .config import RagToolRuntimeConfig
 from .context_retriever import generate_rag_query

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="tool_runtime")


 def make_random_string(length: int = 8):
@@ -8,7 +8,6 @@ import asyncio
 import base64
 import io
 import json
-import logging
 from typing import Any

 import faiss

@@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
     QueryChunksResponse,
     VectorIO,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
     HealthResponse,
     HealthStatus,

@@ -40,7 +40,7 @@ from llama_stack.providers.utils.memory.vector_store import (

 from .config import FaissVectorIOConfig

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="vector_io")

 VERSION = "v3"
 VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import asyncio
-import logging
 import re
 import sqlite3
 import struct

@@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
     QueryChunksResponse,
     VectorIO,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore

@@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
     VectorDBWithIndex,
 )

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="vector_io")

 # Specifying search mode is dependent on the VectorIO provider.
 VECTOR_SEARCH = "vector"
@@ -3,15 +3,14 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import logging

+from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

 from .models import MODEL_ENTRIES

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="inference")


 class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
@@ -77,6 +77,10 @@ print(f"Response: {response.completion_message.content}")
 ```

 ### Create Embeddings
+> Note on OpenAI embeddings compatibility
+>
+> NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. The NVIDIA Inference Adapter automatically sets `input_type="query"` when using the OpenAI-compatible embeddings endpoint for NVIDIA. For passage embeddings, use the `embeddings` API with `task_type="document"`.

 ```python
 response = client.inference.embeddings(
     model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
@@ -4,11 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import warnings
 from collections.abc import AsyncIterator

-from openai import APIConnectionError, BadRequestError
+from openai import NOT_GIVEN, APIConnectionError, BadRequestError

 from llama_stack.apis.common.content_types import (
     InterleavedContent,

@@ -27,12 +26,16 @@ from llama_stack.apis.inference import (
     Inference,
     LogProbConfig,
     Message,
+    OpenAIEmbeddingData,
+    OpenAIEmbeddingsResponse,
+    OpenAIEmbeddingUsage,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
     ToolChoice,
     ToolConfig,
 )
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,

@@ -54,7 +57,7 @@ from .openai_utils import (
 )
 from .utils import _is_nvidia_hosted

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="inference")


 class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):

@@ -210,6 +213,57 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
         #
         return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])

+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        """
+        OpenAI-compatible embeddings for NVIDIA NIM.
+
+        Note: NVIDIA NIM asymmetric embedding models require an "input_type" field not present in the standard OpenAI embeddings API.
+        We default this to "query" to ensure requests succeed when using the
+        OpenAI-compatible endpoint. For passage embeddings, use the embeddings API with
+        `task_type='document'`.
+        """
+        extra_body: dict[str, object] = {"input_type": "query"}
+        logger.warning(
+            "NVIDIA OpenAI-compatible embeddings: defaulting to input_type='query'. "
+            "For passage embeddings, use the embeddings API with task_type='document'."
+        )
+
+        response = await self.client.embeddings.create(
+            model=await self._get_provider_model_id(model),
+            input=input,
+            encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
+            dimensions=dimensions if dimensions is not None else NOT_GIVEN,
+            user=user if user is not None else NOT_GIVEN,
+            extra_body=extra_body,
+        )
+
+        data = []
+        for i, embedding_data in enumerate(response.data):
+            data.append(
+                OpenAIEmbeddingData(
+                    embedding=embedding_data.embedding,
+                    index=i,
+                )
+            )
+
+        usage = OpenAIEmbeddingUsage(
+            prompt_tokens=response.usage.prompt_tokens,
+            total_tokens=response.usage.total_tokens,
+        )
+
+        return OpenAIEmbeddingsResponse(
+            data=data,
+            model=response.model,
+            usage=usage,
+        )
+
     async def chat_completion(
         self,
         model_id: str,
@@ -4,13 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
-
 import httpx

+from llama_stack.log import get_logger
+
 from . import NVIDIAConfig

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="inference")


 def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
@@ -4,15 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
-
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

 from .config import OpenAIConfig
 from .models import MODEL_ENTRIES

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="inference")


 #
@@ -5,7 +5,6 @@
 # the root directory of this source tree.


-import logging
 from collections.abc import AsyncGenerator

 from huggingface_hub import AsyncInferenceClient, HfApi

@@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
+from llama_stack.log import get_logger
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (

@@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="inference")


 def build_hf_repo_model_entries():
@@ -4,18 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import warnings
 from typing import Any

 from pydantic import BaseModel

 from llama_stack.apis.post_training import TrainingConfig
+from llama_stack.log import get_logger
 from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig

 from .config import NvidiaPostTrainingConfig

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="integration")


 def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import json
-import logging
 from typing import Any

 from llama_stack.apis.inference import Message

@@ -16,12 +15,13 @@ from llama_stack.apis.safety import (
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client

 from .config import BedrockSafetyConfig

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="safety")


 class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 from typing import Any

 import requests

@@ -12,12 +11,13 @@ import requests
 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
 from llama_stack.apis.shields import Shield
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new

 from .config import NVIDIASafetyConfig

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="safety")


 class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import json
-import logging
 from typing import Any

 import litellm

@@ -20,12 +19,13 @@ from llama_stack.apis.safety import (
 )
 from llama_stack.apis.shields import Shield
 from llama_stack.core.request_headers import NeedsRequestProviderData
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new

 from .config import SambaNovaSafetyConfig

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="safety")

 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 import asyncio
 import json
-import logging
 from typing import Any
 from urllib.parse import urlparse

@@ -20,6 +19,7 @@ from llama_stack.apis.vector_io import (
     QueryChunksResponse,
     VectorIO,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
 from llama_stack.providers.utils.kvstore import kvstore_impl

@@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import (

 from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="vector_io")

 ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import asyncio
-import logging
 import os
 from typing import Any

@@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import (
     QueryChunksResponse,
     VectorIO,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
 from llama_stack.providers.utils.kvstore import kvstore_impl

@@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti

 from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="vector_io")

 VERSION = "v3"
 VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"

@@ -413,15 +413,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
             raise VectorStoreNotFoundError(vector_db_id)

-        if params and params.get("mode") == "keyword":
-            # Check if this is inline Milvus (Milvus-Lite)
-            if hasattr(self.config, "db_path"):
-                raise NotImplementedError(
-                    "Keyword search is not supported in Milvus-Lite. "
-                    "Please use a remote Milvus server for keyword search functionality."
-                )
-
         return await index.query_chunks(query, params)

     async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
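With the Milvus-Lite guard deleted above, keyword (BM25) queries now flow straight through `query_chunks`. A minimal sketch, assuming an `inline::milvus` vector DB is already registered (the adapter instance and the vector DB id are hypothetical names):

```python
# Inside an async context, given a MilvusVectorIOAdapter `adapter` and a
# registered vector DB "my-db":
response = await adapter.query_chunks(
    vector_db_id="my-db",
    query="what does BM25 rank?",
    params={"mode": "keyword"},  # previously raised NotImplementedError on Milvus-Lite
)
for chunk in response.chunks:
    print(chunk.content)
```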
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 from typing import Any

 import psycopg2

@@ -22,6 +21,7 @@ from llama_stack.apis.vector_io import (
     QueryChunksResponse,
     VectorIO,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore

@@ -34,7 +34,7 @@ from llama_stack.providers.utils.memory.vector_store import (

 from .config import PGVectorVectorIOConfig

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="vector_io")

 VERSION = "v3"
 VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import asyncio
-import logging
 import uuid
 from typing import Any

@@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreChunkingStrategy,
     VectorStoreFileObject,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
 from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl

@@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (

 from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="vector_io")
 CHUNK_ID_KEY = "_chunk_id"

 # KV store prefixes for vector databases
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import json
-import logging
 from typing import Any

 import weaviate

@@ -19,6 +18,7 @@ from llama_stack.apis.files.files import Files
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.core.request_headers import NeedsRequestProviderData
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore

@@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti

 from .config import WeaviateVectorIOConfig

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="vector_io")

 VERSION = "v3"
 VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
@@ -5,10 +5,11 @@
 # the root directory of this source tree.

 import base64
-import logging
 import struct
 from typing import TYPE_CHECKING

+from llama_stack.log import get_logger
+
 if TYPE_CHECKING:
     from sentence_transformers import SentenceTransformer

@@ -27,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
 EMBEDDING_MODELS = {}


-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="inference")


 class SentenceTransformerEmbeddingMixin:
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 import base64
 import json
-import logging
 import struct
 import time
 import uuid

@@ -122,6 +121,7 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.inference import (
     OpenAIChoice as OpenAIChatCompletionChoice,
 )
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     StopReason,

@@ -134,7 +134,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     decode_assistant_message,
 )

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="inference")


 class OpenAICompatCompletionChoiceDelta(BaseModel):
@@ -4,16 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 from datetime import datetime

 from pymongo import AsyncMongoClient

+from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import KVStore

 from ..config import MongoDBKVStoreConfig

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="kvstore")


 class MongoDBKVStoreImpl(KVStore):
@@ -4,16 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 from datetime import datetime

 import psycopg2
 from psycopg2.extras import DictCursor

+from llama_stack.log import get_logger
+
 from ..api import KVStore
 from ..config import PostgresKVStoreConfig

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="kvstore")


 class PostgresKVStoreImpl(KVStore):
@@ -44,7 +44,7 @@ from llama_stack.providers.utils.memory.vector_store import (
     make_overlapped_chunks,
 )

-logger = get_logger(__name__, category="vector_io")
+logger = get_logger(name=__name__, category="memory")

 # Constants for OpenAI vector stores
 CHUNK_MULTIPLIER = 5
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
|
@ -26,6 +25,7 @@ from llama_stack.apis.common.content_types import (
|
|||
from llama_stack.apis.tools import RAGDocument
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
|
|
@ -33,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="memory")
|
||||
|
||||
|
||||
class ChunkForDeletion(BaseModel):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import logging
|
||||
import logging # allow-direct-logging
|
||||
import queue
|
||||
import random
|
||||
import sys
|
||||
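Taken together, the hunks above are one mechanical migration: module-level loggers move from the stdlib's `logging.getLogger(__name__)` to the project's `llama_stack.log.get_logger`, which takes an explicit `category` (`"vector_io"`, `"inference"`, `"kvstore"`, `"memory"`, ...) so output can be filtered per subsystem, while files that must keep the stdlib module mark the import with `# allow-direct-logging` to satisfy the pre-commit check. A minimal sketch of the pattern, assuming only the names visible in these diffs (the `log.info` call is illustrative, not from the diff):

```python
# Old style: flagged by the pre-commit hook unless explicitly exempted.
import logging  # allow-direct-logging

stdlib_log = logging.getLogger(__name__)

# New style: the project-wide logger with a subsystem category.
from llama_stack.log import get_logger

log = get_logger(name=__name__, category="vector_io")
log.info("vector db registered")  # illustrative call
```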
llama_stack/ui/app/chat-playground/page.test.tsx (new file, 587 lines)

@@ -0,0 +1,587 @@
import React from "react";
import {
  render,
  screen,
  fireEvent,
  waitFor,
  act,
} from "@testing-library/react";
import "@testing-library/jest-dom";
import ChatPlaygroundPage from "./page";

const mockClient = {
  agents: {
    list: jest.fn(),
    create: jest.fn(),
    retrieve: jest.fn(),
    delete: jest.fn(),
    session: {
      list: jest.fn(),
      create: jest.fn(),
      delete: jest.fn(),
      retrieve: jest.fn(),
    },
    turn: {
      create: jest.fn(),
    },
  },
  models: {
    list: jest.fn(),
  },
  toolgroups: {
    list: jest.fn(),
  },
};

jest.mock("@/hooks/use-auth-client", () => ({
  useAuthClient: jest.fn(() => mockClient),
}));

jest.mock("@/components/chat-playground/chat", () => ({
  Chat: jest.fn(
    ({
      className,
      messages,
      handleSubmit,
      input,
      handleInputChange,
      isGenerating,
      append,
      suggestions,
    }) => (
      <div data-testid="chat-component" className={className}>
        <div data-testid="messages-count">{messages.length}</div>
        <input
          data-testid="chat-input"
          value={input}
          onChange={handleInputChange}
          disabled={isGenerating}
        />
        <button data-testid="submit-button" onClick={handleSubmit}>
          Submit
        </button>
        {suggestions?.map((suggestion: string, index: number) => (
          <button
            key={index}
            data-testid={`suggestion-${index}`}
            onClick={() => append({ role: "user", content: suggestion })}
          >
            {suggestion}
          </button>
        ))}
      </div>
    )
  ),
}));

jest.mock("@/components/chat-playground/session-manager", () => ({
  SessionManager: jest.fn(({ selectedAgentId, onNewSession }) => (
    <div data-testid="session-manager">
      {selectedAgentId && (
        <>
          <div data-testid="selected-agent">{selectedAgentId}</div>
          <button data-testid="new-session-button" onClick={onNewSession}>
            New Session
          </button>
        </>
      )}
    </div>
  )),
  SessionUtils: {
    saveCurrentSessionId: jest.fn(),
    loadCurrentSessionId: jest.fn(),
    loadCurrentAgentId: jest.fn(),
    saveCurrentAgentId: jest.fn(),
    clearCurrentSession: jest.fn(),
    saveSessionData: jest.fn(),
    loadSessionData: jest.fn(),
    saveAgentConfig: jest.fn(),
    loadAgentConfig: jest.fn(),
    clearAgentCache: jest.fn(),
    createDefaultSession: jest.fn(() => ({
      id: "test-session-123",
      name: "Default Session",
      messages: [],
      selectedModel: "",
      systemMessage: "You are a helpful assistant.",
      agentId: "test-agent-123",
      createdAt: Date.now(),
      updatedAt: Date.now(),
    })),
  },
}));

const mockAgents = [
  {
    agent_id: "agent_123",
    agent_config: {
      name: "Test Agent",
      instructions: "You are a test assistant.",
    },
  },
  {
    agent_id: "agent_456",
    agent_config: {
      agent_name: "Another Agent",
      instructions: "You are another assistant.",
    },
  },
];

const mockModels = [
  {
    identifier: "test-model-1",
    model_type: "llm",
  },
  {
    identifier: "test-model-2",
    model_type: "llm",
  },
];

const mockToolgroups = [
  {
    identifier: "builtin::rag",
    provider_id: "test-provider",
    type: "tool_group",
    provider_resource_id: "test-resource",
  },
];

describe("ChatPlaygroundPage", () => {
  beforeEach(() => {
    jest.clearAllMocks();
    Element.prototype.scrollIntoView = jest.fn();
    mockClient.agents.list.mockResolvedValue({ data: mockAgents });
    mockClient.models.list.mockResolvedValue(mockModels);
    mockClient.toolgroups.list.mockResolvedValue(mockToolgroups);
    mockClient.agents.session.create.mockResolvedValue({
      session_id: "new-session-123",
    });
    mockClient.agents.session.list.mockResolvedValue({ data: [] });
    mockClient.agents.session.retrieve.mockResolvedValue({
      session_id: "test-session",
      session_name: "Test Session",
      started_at: new Date().toISOString(),
      turns: [],
    }); // No turns by default
    mockClient.agents.retrieve.mockResolvedValue({
      agent_id: "test-agent",
      agent_config: {
        toolgroups: ["builtin::rag"],
        instructions: "Test instructions",
        model: "test-model",
      },
    });
    mockClient.agents.delete.mockResolvedValue(undefined);
  });

  describe("Agent Selector Rendering", () => {
    test("shows agent selector when agents are available", async () => {
      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        expect(screen.getByText("Agent Session:")).toBeInTheDocument();
        expect(screen.getAllByRole("combobox")).toHaveLength(2);
        expect(screen.getByText("+ New Agent")).toBeInTheDocument();
        expect(screen.getByText("Clear Chat")).toBeInTheDocument();
      });
    });

    test("does not show agent selector when no agents are available", async () => {
      mockClient.agents.list.mockResolvedValue({ data: [] });

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument();
        expect(screen.getAllByRole("combobox")).toHaveLength(1);
        expect(screen.getByText("+ New Agent")).toBeInTheDocument();
        expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument();
      });
    });

    test("does not show agent selector while loading", async () => {
      mockClient.agents.list.mockImplementation(() => new Promise(() => {}));

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument();
      expect(screen.getAllByRole("combobox")).toHaveLength(1);
      expect(screen.getByText("+ New Agent")).toBeInTheDocument();
      expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument();
    });

    test("shows agent options in selector", async () => {
      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        const agentCombobox = screen.getAllByRole("combobox").find(element => {
          return (
            element.textContent?.includes("Test Agent") ||
            element.textContent?.includes("Select Agent")
          );
        });
        expect(agentCombobox).toBeDefined();
        fireEvent.click(agentCombobox!);
      });

      await waitFor(() => {
        expect(screen.getAllByText("Test Agent")).toHaveLength(2);
        expect(screen.getByText("Another Agent")).toBeInTheDocument();
      });
    });

    test("displays agent ID when no name is available", async () => {
      const agentWithoutName = {
        agent_id: "agent_789",
        agent_config: {
          instructions: "You are an agent without a name.",
        },
      };

      mockClient.agents.list.mockResolvedValue({ data: [agentWithoutName] });

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        const agentCombobox = screen.getAllByRole("combobox").find(element => {
          return (
            element.textContent?.includes("Agent agent_78") ||
            element.textContent?.includes("Select Agent")
          );
        });
        expect(agentCombobox).toBeDefined();
        fireEvent.click(agentCombobox!);
      });

      await waitFor(() => {
        expect(screen.getAllByText("Agent agent_78...")).toHaveLength(2);
      });
    });
  });

  describe("Agent Creation Modal", () => {
    test("opens agent creation modal when + New Agent is clicked", async () => {
      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      const newAgentButton = screen.getByText("+ New Agent");
      fireEvent.click(newAgentButton);

      expect(screen.getByText("Create New Agent")).toBeInTheDocument();
      expect(screen.getByText("Agent Name (optional)")).toBeInTheDocument();
      expect(screen.getAllByText("Model")).toHaveLength(2);
      expect(screen.getByText("System Instructions")).toBeInTheDocument();
      expect(screen.getByText("Tools (optional)")).toBeInTheDocument();
    });

    test("closes modal when Cancel is clicked", async () => {
      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      const newAgentButton = screen.getByText("+ New Agent");
      fireEvent.click(newAgentButton);

      const cancelButton = screen.getByText("Cancel");
      fireEvent.click(cancelButton);

      expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument();
    });

    test("creates agent when Create Agent is clicked", async () => {
      mockClient.agents.create.mockResolvedValue({ agent_id: "new-agent-123" });
      mockClient.agents.list
        .mockResolvedValueOnce({ data: mockAgents })
        .mockResolvedValueOnce({
          data: [
            ...mockAgents,
            { agent_id: "new-agent-123", agent_config: { name: "New Agent" } },
          ],
        });

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      const newAgentButton = screen.getByText("+ New Agent");
      await act(async () => {
        fireEvent.click(newAgentButton);
      });

      await waitFor(() => {
        expect(screen.getByText("Create New Agent")).toBeInTheDocument();
      });

      const nameInput = screen.getByPlaceholderText("My Custom Agent");
      await act(async () => {
        fireEvent.change(nameInput, { target: { value: "Test Agent Name" } });
      });

      const instructionsTextarea = screen.getByDisplayValue(
        "You are a helpful assistant."
      );
      await act(async () => {
        fireEvent.change(instructionsTextarea, {
          target: { value: "Custom instructions" },
        });
      });

      await waitFor(() => {
        const modalModelSelectors = screen
          .getAllByRole("combobox")
          .filter(el => {
            return (
              el.textContent?.includes("Select Model") ||
              el.closest('[class*="modal"]') ||
              el.closest('[class*="card"]')
            );
          });
        expect(modalModelSelectors.length).toBeGreaterThan(0);
      });

      const modalModelSelectors = screen.getAllByRole("combobox").filter(el => {
        return (
          el.textContent?.includes("Select Model") ||
          el.closest('[class*="modal"]') ||
          el.closest('[class*="card"]')
        );
      });

      await act(async () => {
        fireEvent.click(modalModelSelectors[0]);
      });

      await waitFor(() => {
        const modelOptions = screen.getAllByText("test-model-1");
        expect(modelOptions.length).toBeGreaterThan(0);
      });

      const modelOptions = screen.getAllByText("test-model-1");
      const dropdownOption = modelOptions.find(
        option =>
          option.closest('[role="option"]') ||
          option.id?.includes("radix") ||
          option.getAttribute("aria-selected") !== null
      );

      await act(async () => {
        fireEvent.click(
          dropdownOption || modelOptions[modelOptions.length - 1]
        );
      });

      await waitFor(() => {
        const createButton = screen.getByText("Create Agent");
        expect(createButton).not.toBeDisabled();
      });

      const createButton = screen.getByText("Create Agent");
      await act(async () => {
        fireEvent.click(createButton);
      });

      await waitFor(() => {
        expect(mockClient.agents.create).toHaveBeenCalledWith({
          agent_config: {
            model: expect.any(String),
            instructions: "Custom instructions",
            name: "Test Agent Name",
            enable_session_persistence: true,
          },
        });
      });

      await waitFor(() => {
        expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument();
      });
    });
  });

  describe("Agent Selection", () => {
    test("creates default session when agent is selected", async () => {
      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        // first agent should be auto-selected
        expect(mockClient.agents.session.create).toHaveBeenCalledWith(
          "agent_123",
          { session_name: "Default Session" }
        );
      });
    });

    test("switches agent when different agent is selected", async () => {
      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        const agentCombobox = screen.getAllByRole("combobox").find(element => {
          return (
            element.textContent?.includes("Test Agent") ||
            element.textContent?.includes("Select Agent")
          );
        });
        expect(agentCombobox).toBeDefined();
        fireEvent.click(agentCombobox!);
      });

      await waitFor(() => {
        const anotherAgentOption = screen.getByText("Another Agent");
        fireEvent.click(anotherAgentOption);
      });

      expect(mockClient.agents.session.create).toHaveBeenCalledWith(
        "agent_456",
        { session_name: "Default Session" }
      );
    });
  });

  describe("Agent Deletion", () => {
    test("shows delete button when multiple agents exist", async () => {
      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
      });
    });

    test("hides delete button when only one agent exists", async () => {
      mockClient.agents.list.mockResolvedValue({
        data: [mockAgents[0]],
      });

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        expect(
          screen.queryByTitle("Delete current agent")
        ).not.toBeInTheDocument();
      });
    });

    test("deletes agent and switches to another when confirmed", async () => {
      global.confirm = jest.fn(() => true);

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
      });

      mockClient.agents.delete.mockResolvedValue(undefined);
      mockClient.agents.list.mockResolvedValueOnce({ data: mockAgents });
      mockClient.agents.list.mockResolvedValueOnce({
        data: [mockAgents[1]],
      });

      const deleteButton = screen.getByTitle("Delete current agent");
      await act(async () => {
        deleteButton.click();
      });

      await waitFor(() => {
        expect(mockClient.agents.delete).toHaveBeenCalledWith("agent_123");
        expect(global.confirm).toHaveBeenCalledWith(
          "Are you sure you want to delete this agent? This action cannot be undone and will delete all associated sessions."
        );
      });

      (global.confirm as jest.Mock).mockRestore();
    });

    test("does not delete agent when cancelled", async () => {
      global.confirm = jest.fn(() => false);

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
      });

      const deleteButton = screen.getByTitle("Delete current agent");
      await act(async () => {
        deleteButton.click();
      });

      await waitFor(() => {
        expect(global.confirm).toHaveBeenCalled();
        expect(mockClient.agents.delete).not.toHaveBeenCalled();
      });

      (global.confirm as jest.Mock).mockRestore();
    });
  });

  describe("Error Handling", () => {
    test("handles agent loading errors gracefully", async () => {
      mockClient.agents.list.mockRejectedValue(
        new Error("Failed to load agents")
      );
      const consoleSpy = jest
        .spyOn(console, "error")
        .mockImplementation(() => {});

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        expect(consoleSpy).toHaveBeenCalledWith(
          "Error fetching agents:",
          expect.any(Error)
        );
      });

      expect(screen.getByText("+ New Agent")).toBeInTheDocument();

      consoleSpy.mockRestore();
    });

    test("handles model loading errors gracefully", async () => {
      mockClient.models.list.mockRejectedValue(
        new Error("Failed to load models")
      );
      const consoleSpy = jest
        .spyOn(console, "error")
        .mockImplementation(() => {});

      await act(async () => {
        render(<ChatPlaygroundPage />);
      });

      await waitFor(() => {
        expect(consoleSpy).toHaveBeenCalledWith(
          "Error fetching models:",
          expect.any(Error)
        );
      });

      consoleSpy.mockRestore();
    });
  });
});
(File diff suppressed because it is too large.)
(Binary file not shown: image removed, previously 25 KiB.)
@@ -120,3 +120,44 @@
    @apply bg-background text-foreground;
  }
}

@layer utilities {
  .animate-typing-dot-1 {
    animation: typing-dot-bounce-1 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
  }

  .animate-typing-dot-2 {
    animation: typing-dot-bounce-2 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
  }

  .animate-typing-dot-3 {
    animation: typing-dot-bounce-3 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
  }

  @keyframes typing-dot-bounce-1 {
    0%, 15%, 85%, 100% {
      transform: translateY(0);
    }
    7.5% {
      transform: translateY(-6px);
    }
  }

  @keyframes typing-dot-bounce-2 {
    0%, 15%, 35%, 85%, 100% {
      transform: translateY(0);
    }
    25% {
      transform: translateY(-6px);
    }
  }

  @keyframes typing-dot-bounce-3 {
    0%, 35%, 55%, 85%, 100% {
      transform: translateY(0);
    }
    45% {
      transform: translateY(-6px);
    }
  }
}
@@ -18,6 +18,9 @@ const geistMono = Geist_Mono({
export const metadata: Metadata = {
  title: "Llama Stack",
  description: "Llama Stack UI",
  icons: {
    icon: "/favicon.ico",
  },
};

import { SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
(File diff suppressed because it is too large.)

@@ -1,6 +1,6 @@
"use client";

import { useState, useEffect } from "react";
import { useState, useEffect, useCallback } from "react";
import { Button } from "@/components/ui/button";
import {
Select,

@@ -13,14 +13,20 @@ import { Input } from "@/components/ui/input";
import { Card } from "@/components/ui/card";
import { Trash2 } from "lucide-react";
import type { Message } from "@/components/chat-playground/chat-message";
import { useAuthClient } from "@/hooks/use-auth-client";
import type {
Session,
SessionCreateParams,
} from "llama-stack-client/resources/agents";

interface ChatSession {
export interface ChatSession {
id: string;
name: string;
messages: Message[];
selectedModel: string;
selectedVectorDb: string;
systemMessage: string;
agentId: string;
session?: Session;
createdAt: number;
updatedAt: number;
}

@@ -29,9 +35,9 @@ interface SessionManagerProps {
currentSession: ChatSession | null;
onSessionChange: (session: ChatSession) => void;
onNewSession: () => void;
selectedAgentId: string;
}

const SESSIONS_STORAGE_KEY = "chat-playground-sessions";
const CURRENT_SESSION_KEY = "chat-playground-current-session";

// ensures this only happens client side

@@ -63,16 +69,6 @@ const safeLocalStorage = {
},
};

function safeJsonParse<T>(jsonString: string | null, fallback: T): T {
if (!jsonString) return fallback;
try {
return JSON.parse(jsonString) as T;
} catch (err) {
console.error("Error parsing JSON:", err);
return fallback;
}
}

const generateSessionId = (): string => {
return globalThis.crypto.randomUUID();
};

@@ -80,60 +76,202 @@ const generateSessionId = (): string => {
export function SessionManager({
currentSession,
onSessionChange,
selectedAgentId,
}: SessionManagerProps) {
const [sessions, setSessions] = useState<ChatSession[]>([]);
const [showCreateForm, setShowCreateForm] = useState(false);
const [newSessionName, setNewSessionName] = useState("");
const [loading, setLoading] = useState(false);
const client = useAuthClient();

const loadAgentSessions = useCallback(async () => {
if (!selectedAgentId) return;

setLoading(true);
try {
const response = await client.agents.session.list(selectedAgentId);
console.log("Sessions response:", response);

if (!response.data || !Array.isArray(response.data)) {
console.warn("Invalid sessions response, starting fresh");
setSessions([]);
return;
}

const agentSessions: ChatSession[] = response.data
.filter(sessionData => {
const isValid =
sessionData &&
typeof sessionData === "object" &&
sessionData.session_id &&
sessionData.session_name;
if (!isValid) {
console.warn("Filtering out invalid session:", sessionData);
}
return isValid;
})
.map(sessionData => ({
id: sessionData.session_id,
name: sessionData.session_name,
messages: [],
selectedModel: currentSession?.selectedModel || "",
systemMessage:
currentSession?.systemMessage || "You are a helpful assistant.",
agentId: selectedAgentId,
session: sessionData,
createdAt: sessionData.started_at
? new Date(sessionData.started_at).getTime()
: Date.now(),
updatedAt: sessionData.started_at
? new Date(sessionData.started_at).getTime()
: Date.now(),
}));
setSessions(agentSessions);
} catch (error) {
console.error("Error loading agent sessions:", error);
setSessions([]);
} finally {
setLoading(false);
}
}, [
selectedAgentId,
client,
currentSession?.selectedModel,
currentSession?.systemMessage,
]);

useEffect(() => {
const savedSessions = safeLocalStorage.getItem(SESSIONS_STORAGE_KEY);
const sessions = safeJsonParse<ChatSession[]>(savedSessions, []);
setSessions(sessions);
}, []);
if (selectedAgentId) {
loadAgentSessions();
}
}, [selectedAgentId, loadAgentSessions]);

const saveSessions = (updatedSessions: ChatSession[]) => {
setSessions(updatedSessions);
safeLocalStorage.setItem(
SESSIONS_STORAGE_KEY,
JSON.stringify(updatedSessions)
);
};
const createNewSession = async () => {
if (!selectedAgentId) return;

const createNewSession = () => {
const sessionName =
newSessionName.trim() || `Session ${sessions.length + 1}`;
const newSession: ChatSession = {
id: generateSessionId(),
name: sessionName,
messages: [],
selectedModel: currentSession?.selectedModel || "",
selectedVectorDb: currentSession?.selectedVectorDb || "",
systemMessage:
currentSession?.systemMessage || "You are a helpful assistant.",
createdAt: Date.now(),
updatedAt: Date.now(),
};
setLoading(true);

const updatedSessions = [...sessions, newSession];
saveSessions(updatedSessions);
try {
const response = await client.agents.session.create(selectedAgentId, {
session_name: sessionName,
} as SessionCreateParams);

safeLocalStorage.setItem(CURRENT_SESSION_KEY, newSession.id);
onSessionChange(newSession);
const newSession: ChatSession = {
id: response.session_id,
name: sessionName,
messages: [],
selectedModel: currentSession?.selectedModel || "",
systemMessage:
currentSession?.systemMessage || "You are a helpful assistant.",
agentId: selectedAgentId,
createdAt: Date.now(),
updatedAt: Date.now(),
};

setNewSessionName("");
setShowCreateForm(false);
};
setSessions(prev => [...prev, newSession]);
SessionUtils.saveCurrentSessionId(newSession.id, selectedAgentId);
onSessionChange(newSession);

const switchToSession = (sessionId: string) => {
const session = sessions.find(s => s.id === sessionId);
if (session) {
safeLocalStorage.setItem(CURRENT_SESSION_KEY, sessionId);
onSessionChange(session);
setNewSessionName("");
setShowCreateForm(false);
} catch (error) {
console.error("Error creating session:", error);
} finally {
setLoading(false);
}
};

const deleteSession = (sessionId: string) => {
if (sessions.length <= 1) {
const loadSessionMessages = useCallback(
async (agentId: string, sessionId: string): Promise<Message[]> => {
try {
const session = await client.agents.session.retrieve(
agentId,
sessionId
);

if (!session || !session.turns || !Array.isArray(session.turns)) {
return [];
}

const messages: Message[] = [];
for (const turn of session.turns) {
// Add user messages from input_messages
if (turn.input_messages && Array.isArray(turn.input_messages)) {
for (const input of turn.input_messages) {
if (input.role === "user" && input.content) {
messages.push({
id: `${turn.turn_id}-user-${messages.length}`,
role: "user",
content:
typeof input.content === "string"
? input.content
: JSON.stringify(input.content),
createdAt: new Date(turn.started_at || Date.now()),
});
}
}
}

// Add assistant message from output_message
if (turn.output_message && turn.output_message.content) {
messages.push({
id: `${turn.turn_id}-assistant-${messages.length}`,
role: "assistant",
content:
typeof turn.output_message.content === "string"
? turn.output_message.content
: JSON.stringify(turn.output_message.content),
createdAt: new Date(
turn.completed_at || turn.started_at || Date.now()
),
});
}
}

return messages;
} catch (error) {
console.error("Error loading session messages:", error);
return [];
}
},
[client]
);

const switchToSession = useCallback(
async (sessionId: string) => {
const session = sessions.find(s => s.id === sessionId);
if (session) {
setLoading(true);
try {
// Load messages for this session
const messages = await loadSessionMessages(
selectedAgentId,
sessionId
);
const sessionWithMessages = {
...session,
messages,
};

SessionUtils.saveCurrentSessionId(sessionId, selectedAgentId);
onSessionChange(sessionWithMessages);
} catch (error) {
console.error("Error switching to session:", error);
// Fallback to session without messages
SessionUtils.saveCurrentSessionId(sessionId, selectedAgentId);
onSessionChange(session);
} finally {
setLoading(false);
}
}
},
[sessions, selectedAgentId, loadSessionMessages, onSessionChange]
);

const deleteSession = async (sessionId: string) => {
if (sessions.length <= 1 || !selectedAgentId) {
return;
}

@@ -142,21 +280,30 @@ export function SessionManager({
"Are you sure you want to delete this session? This action cannot be undone."
)
) {
const updatedSessions = sessions.filter(s => s.id !== sessionId);
saveSessions(updatedSessions);
setLoading(true);
try {
await client.agents.session.delete(selectedAgentId, sessionId);

if (currentSession?.id === sessionId) {
const newCurrentSession = updatedSessions[0] || null;
if (newCurrentSession) {
safeLocalStorage.setItem(CURRENT_SESSION_KEY, newCurrentSession.id);
onSessionChange(newCurrentSession);
} else {
safeLocalStorage.removeItem(CURRENT_SESSION_KEY);
const defaultSession = SessionUtils.createDefaultSession();
saveSessions([defaultSession]);
safeLocalStorage.setItem(CURRENT_SESSION_KEY, defaultSession.id);
onSessionChange(defaultSession);
const updatedSessions = sessions.filter(s => s.id !== sessionId);
setSessions(updatedSessions);

if (currentSession?.id === sessionId) {
const newCurrentSession = updatedSessions[0] || null;
if (newCurrentSession) {
SessionUtils.saveCurrentSessionId(
newCurrentSession.id,
selectedAgentId
);
onSessionChange(newCurrentSession);
} else {
SessionUtils.clearCurrentSession(selectedAgentId);
onNewSession();
}
}
} catch (error) {
console.error("Error deleting session:", error);
} finally {
setLoading(false);
}
}
};

@@ -172,16 +319,16 @@ export function SessionManager({
updatedSessions.push(currentSession);
}

safeLocalStorage.setItem(
SESSIONS_STORAGE_KEY,
JSON.stringify(updatedSessions)
);

return updatedSessions;
});
}
}, [currentSession]);

// Don't render if no agent is selected
if (!selectedAgentId) {
return null;
}

return (
<div className="relative">
<div className="flex items-center gap-2">

@@ -205,6 +352,7 @@ export function SessionManager({
onClick={() => setShowCreateForm(true)}
variant="outline"
size="sm"
disabled={loading || !selectedAgentId}
>
+ New
</Button>

@@ -241,8 +389,12 @@ export function SessionManager({
/>

<div className="flex gap-2">
<Button onClick={createNewSession} className="flex-1">
Create
<Button
onClick={createNewSession}
className="flex-1"
disabled={loading}
>
{loading ? "Creating..." : "Create"}
</Button>
<Button
variant="outline"

@@ -270,72 +422,147 @@ export function SessionManager({
}

export const SessionUtils = {
loadCurrentSession: (): ChatSession | null => {
const currentSessionId = safeLocalStorage.getItem(CURRENT_SESSION_KEY);
const savedSessions = safeLocalStorage.getItem(SESSIONS_STORAGE_KEY);

if (currentSessionId && savedSessions) {
const sessions = safeJsonParse<ChatSession[]>(savedSessions, []);
return sessions.find(s => s.id === currentSessionId) || null;
}
return null;
loadCurrentSessionId: (agentId?: string): string | null => {
const key = agentId
? `${CURRENT_SESSION_KEY}-${agentId}`
: CURRENT_SESSION_KEY;
return safeLocalStorage.getItem(key);
},

saveCurrentSession: (session: ChatSession) => {
const savedSessions = safeLocalStorage.getItem(SESSIONS_STORAGE_KEY);
const sessions = safeJsonParse<ChatSession[]>(savedSessions, []);

const existingIndex = sessions.findIndex(s => s.id === session.id);
if (existingIndex >= 0) {
sessions[existingIndex] = { ...session, updatedAt: Date.now() };
} else {
sessions.push({
...session,
createdAt: Date.now(),
updatedAt: Date.now(),
});
}

safeLocalStorage.setItem(SESSIONS_STORAGE_KEY, JSON.stringify(sessions));
safeLocalStorage.setItem(CURRENT_SESSION_KEY, session.id);
saveCurrentSessionId: (sessionId: string, agentId?: string) => {
const key = agentId
? `${CURRENT_SESSION_KEY}-${agentId}`
: CURRENT_SESSION_KEY;
safeLocalStorage.setItem(key, sessionId);
},

createDefaultSession: (
inheritModel?: string,
inheritVectorDb?: string
agentId: string,
inheritModel?: string
): ChatSession => ({
id: generateSessionId(),
name: "Default Session",
messages: [],
selectedModel: inheritModel || "",
selectedVectorDb: inheritVectorDb || "",
systemMessage: "You are a helpful assistant.",
agentId,
createdAt: Date.now(),
updatedAt: Date.now(),
}),

deleteSession: (
sessionId: string
): {
deletedSession: ChatSession | null;
remainingSessions: ChatSession[];
} => {
const savedSessions = safeLocalStorage.getItem(SESSIONS_STORAGE_KEY);
const sessions = safeJsonParse<ChatSession[]>(savedSessions, []);
clearCurrentSession: (agentId?: string) => {
const key = agentId
? `${CURRENT_SESSION_KEY}-${agentId}`
: CURRENT_SESSION_KEY;
safeLocalStorage.removeItem(key);
},

const sessionToDelete = sessions.find(s => s.id === sessionId);
const remainingSessions = sessions.filter(s => s.id !== sessionId);
loadCurrentAgentId: (): string | null => {
return safeLocalStorage.getItem("chat-playground-current-agent");
},

saveCurrentAgentId: (agentId: string) => {
safeLocalStorage.setItem("chat-playground-current-agent", agentId);
},

// Comprehensive session caching
saveSessionData: (agentId: string, sessionData: ChatSession) => {
const key = `chat-playground-session-data-${agentId}-${sessionData.id}`;
safeLocalStorage.setItem(
SESSIONS_STORAGE_KEY,
JSON.stringify(remainingSessions)
key,
JSON.stringify({
...sessionData,
cachedAt: Date.now(),
})
);
},

const currentSessionId = safeLocalStorage.getItem(CURRENT_SESSION_KEY);
if (currentSessionId === sessionId) {
safeLocalStorage.removeItem(CURRENT_SESSION_KEY);
loadSessionData: (agentId: string, sessionId: string): ChatSession | null => {
const key = `chat-playground-session-data-${agentId}-${sessionId}`;
const cached = safeLocalStorage.getItem(key);
if (!cached) return null;

try {
const data = JSON.parse(cached);
// Check if cache is fresh (less than 1 hour old)
const cacheAge = Date.now() - (data.cachedAt || 0);
if (cacheAge > 60 * 60 * 1000) {
safeLocalStorage.removeItem(key);
return null;
}

// Convert date strings back to Date objects
return {
...data,
messages: data.messages.map(
(msg: { createdAt: string; [key: string]: unknown }) => ({
...msg,
createdAt: new Date(msg.createdAt),
})
),
};
} catch (error) {
console.error("Error parsing cached session data:", error);
safeLocalStorage.removeItem(key);
return null;
}
},

return { deletedSession: sessionToDelete || null, remainingSessions };
// Agent config caching
saveAgentConfig: (
agentId: string,
config: {
toolgroups?: Array<
string | { name: string; args: Record<string, unknown> }
>;
[key: string]: unknown;
}
) => {
const key = `chat-playground-agent-config-${agentId}`;
safeLocalStorage.setItem(
key,
JSON.stringify({
config,
cachedAt: Date.now(),
})
);
},

loadAgentConfig: (
agentId: string
): {
toolgroups?: Array<
string | { name: string; args: Record<string, unknown> }
>;
[key: string]: unknown;
} | null => {
const key = `chat-playground-agent-config-${agentId}`;
const cached = safeLocalStorage.getItem(key);
if (!cached) return null;

try {
const data = JSON.parse(cached);
// Check if cache is fresh (less than 30 minutes old)
const cacheAge = Date.now() - (data.cachedAt || 0);
if (cacheAge > 30 * 60 * 1000) {
safeLocalStorage.removeItem(key);
return null;
}
return data.config;
} catch (error) {
console.error("Error parsing cached agent config:", error);
safeLocalStorage.removeItem(key);
return null;
}
},

// Clear all cached data for an agent
clearAgentCache: (agentId: string) => {
const keys = Object.keys(localStorage).filter(
key =>
key.includes(`chat-playground-session-data-${agentId}`) ||
key.includes(`chat-playground-agent-config-${agentId}`)
);
keys.forEach(key => safeLocalStorage.removeItem(key));
},
};
@@ -5,9 +5,9 @@ export function TypingIndicator() {
<div className="justify-left flex space-x-1">
<div className="rounded-lg bg-muted p-3">
<div className="flex -space-x-2.5">
<Dot className="h-5 w-5 animate-typing-dot-bounce" />
<Dot className="h-5 w-5 animate-typing-dot-bounce [animation-delay:90ms]" />
<Dot className="h-5 w-5 animate-typing-dot-bounce [animation-delay:180ms]" />
<Dot className="h-5 w-5 animate-typing-dot-1" />
<Dot className="h-5 w-5 animate-typing-dot-2" />
<Dot className="h-5 w-5 animate-typing-dot-3" />
</div>
</div>
</div>
@@ -11,6 +11,7 @@ import {
} from "lucide-react";
import Link from "next/link";
import { usePathname } from "next/navigation";
import Image from "next/image";
import { cn } from "@/lib/utils";

import {

@@ -110,7 +111,16 @@ export function AppSidebar() {
return (
<Sidebar>
<SidebarHeader>
<Link href="/">Llama Stack</Link>
<Link href="/" className="flex items-center gap-2 p-2">
<Image
src="/logo.webp"
alt="Llama Stack"
width={32}
height={32}
className="h-8 w-8"
/>
<span className="font-semibold text-lg">Llama Stack</span>
</Link>
</SidebarHeader>
<SidebarContent>
<SidebarGroup>
llama_stack/ui/package-lock.json (generated, 8 lines changed)

@@ -18,7 +18,7 @@
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"framer-motion": "^11.18.2",
"llama-stack-client": "0.2.17",
"llama-stack-client": "^0.2.18",
"lucide-react": "^0.510.0",
"next": "15.3.3",
"next-auth": "^4.24.11",

@@ -9926,9 +9926,9 @@
"license": "MIT"
},
"node_modules/llama-stack-client": {
"version": "0.2.17",
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.17.tgz",
"integrity": "sha512-+/fEO8M7XPiVLjhH7ge18i1ijKp4+h3dOkE0C8g2cvGuDUtDYIJlf8NSyr9OMByjiWpCibWU7VOKL50LwGLS3Q==",
"version": "0.2.18",
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.18.tgz",
"integrity": "sha512-k+xQOz/TIU0cINP4Aih8q6xs7f/6qs0fLDMXTTKQr5C0F1jtCjRiwsas7bTsDfpKfYhg/7Xy/wPw/uZgi6aIVg==",
"license": "MIT",
"dependencies": {
"@types/node": "^18.11.18",
@@ -23,7 +23,7 @@
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"framer-motion": "^11.18.2",
"llama-stack-client": "^0.2.17",
"llama-stack-client": "^0.2.18",
"lucide-react": "^0.510.0",
"next": "15.3.3",
"next-auth": "^4.24.11",
llama_stack/ui/public/favicon.ico (new binary file, 4.2 KiB)
llama_stack/ui/public/logo.webp (new binary file, 19 KiB)
@@ -7,7 +7,7 @@ required-version = ">=0.7.0"

[project]
name = "llama_stack"
version = "0.2.17"
version = "0.2.18"
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
description = "Llama Stack"
readme = "README.md"

@@ -31,7 +31,7 @@ dependencies = [
"huggingface-hub>=0.34.0,<1.0",
"jinja2>=3.1.6",
"jsonschema",
"llama-stack-client>=0.2.17",
"llama-stack-client>=0.2.18",
"llama-api-client>=0.1.2",
"openai>=1.99.6,<1.100.0",
"prompt-toolkit",

@@ -56,7 +56,7 @@ dependencies = [
ui = [
"streamlit",
"pandas",
"llama-stack-client>=0.2.17",
"llama-stack-client>=0.2.18",
"streamlit-option-menu",
]

@@ -93,6 +93,7 @@ unit = [
"blobfile",
"faiss-cpu",
"pymilvus>=2.5.12",
"milvus-lite>=2.5.0",
"litellm",
"together",
"coverage",

@@ -118,6 +119,7 @@ test = [
"sqlalchemy[asyncio]>=2.0.41",
"requests",
"pymilvus>=2.5.12",
"milvus-lite>=2.5.0",
"weaviate-client>=4.16.4",
]
docs = [
@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import sys
import time
import uuid

@@ -19,10 +18,10 @@ from llama_stack.apis.post_training import (
LoraFinetuningConfig,
TrainingConfig,
)
from llama_stack.log import get_logger

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="post_training")

skip_because_resource_intensive = pytest.mark.skip(

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import time
from io import BytesIO

@@ -14,8 +13,9 @@ from openai import BadRequestError as OpenAIBadRequestError

from llama_stack.apis.vector_io import Chunk
from llama_stack.core.library_client import LlamaStackAsLibraryClient
from llama_stack.log import get_logger

logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="vector_io")

def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):

@@ -56,6 +56,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
"keyword": [
"inline::sqlite-vec",
"remote::milvus",
"inline::milvus",
],
"hybrid": [
"inline::sqlite-vec",
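For context on the hunk just above: the search-mode gate is a plain mode-to-providers map, and adding `inline::milvus` to the `keyword` list is what lets the BM25 keyword tests run against Milvus Lite. A hedged sketch of how such a gate can be written (the helper name and map below are illustrative, not the repo's exact code):

```python
import pytest

# Illustrative support matrix mirroring the list edited above.
SUPPORTED_SEARCH_MODES = {
    "keyword": ["inline::sqlite-vec", "remote::milvus", "inline::milvus"],
    "hybrid": ["inline::sqlite-vec"],
}


def skip_if_unsupported(provider_type: str, search_mode: str) -> None:
    # Skip (rather than fail) when the provider lacks the search mode.
    if provider_type not in SUPPORTED_SEARCH_MODES.get(search_mode, []):
        pytest.skip(f"{provider_type} does not support {search_mode} search")
```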
@@ -45,7 +45,6 @@ from llama_stack.providers.inline.agents.meta_reference.responses.utils import (

class TestConvertChatChoiceToResponseMessage:
@pytest.mark.asyncio
async def test_convert_string_content(self):
choice = OpenAIChoice(
message=OpenAIAssistantMessageParam(content="Test message"),

@@ -61,7 +60,6 @@ class TestConvertChatChoiceToResponseMessage:
assert isinstance(result.content[0], OpenAIResponseOutputMessageContentOutputText)
assert result.content[0].text == "Test message"

@pytest.mark.asyncio
async def test_convert_text_param_content(self):
choice = OpenAIChoice(
message=OpenAIAssistantMessageParam(

@@ -78,12 +76,10 @@ class TestConvertChatChoiceToResponseMessage:

class TestConvertResponseContentToChatContent:
@pytest.mark.asyncio
async def test_convert_string_content(self):
result = await convert_response_content_to_chat_content("Simple string")
assert result == "Simple string"

@pytest.mark.asyncio
async def test_convert_text_content_parts(self):
content = [
OpenAIResponseInputMessageContentText(text="First part"),

@@ -98,7 +94,6 @@ class TestConvertResponseContentToChatContent:
assert isinstance(result[1], OpenAIChatCompletionContentPartTextParam)
assert result[1].text == "Second part"

@pytest.mark.asyncio
async def test_convert_image_content(self):
content = [OpenAIResponseInputMessageContentImage(image_url="https://example.com/image.jpg", detail="high")]

@@ -111,7 +106,6 @@ class TestConvertResponseContentToChatContent:

class TestConvertResponseInputToChatMessages:
@pytest.mark.asyncio
async def test_convert_string_input(self):
result = await convert_response_input_to_chat_messages("User message")

@@ -119,7 +113,6 @@ class TestConvertResponseInputToChatMessages:
assert isinstance(result[0], OpenAIUserMessageParam)
assert result[0].content == "User message"

@pytest.mark.asyncio
async def test_convert_function_tool_call_output(self):
input_items = [
OpenAIResponseInputFunctionToolCallOutput(

@@ -135,7 +128,6 @@ class TestConvertResponseInputToChatMessages:
assert result[0].content == "Tool output"
assert result[0].tool_call_id == "call_123"

@pytest.mark.asyncio
async def test_convert_function_tool_call(self):
input_items = [
OpenAIResponseOutputMessageFunctionToolCall(

@@ -154,7 +146,6 @@ class TestConvertResponseInputToChatMessages:
assert result[0].tool_calls[0].function.name == "test_function"
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'

@pytest.mark.asyncio
async def test_convert_response_message(self):
input_items = [
OpenAIResponseMessage(

@@ -173,7 +164,6 @@ class TestConvertResponseInputToChatMessages:

class TestConvertResponseTextToChatResponseFormat:
@pytest.mark.asyncio
async def test_convert_text_format(self):
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
result = await convert_response_text_to_chat_response_format(text)

@@ -181,14 +171,12 @@ class TestConvertResponseTextToChatResponseFormat:
assert isinstance(result, OpenAIResponseFormatText)
assert result.type == "text"

@pytest.mark.asyncio
async def test_convert_json_object_format(self):
text = OpenAIResponseText(format={"type": "json_object"})
result = await convert_response_text_to_chat_response_format(text)

assert isinstance(result, OpenAIResponseFormatJSONObject)

@pytest.mark.asyncio
async def test_convert_json_schema_format(self):
schema_def = {"type": "object", "properties": {"test": {"type": "string"}}}
text = OpenAIResponseText(

@@ -204,7 +192,6 @@ class TestConvertResponseTextToChatResponseFormat:
assert result.json_schema["name"] == "test_schema"
assert result.json_schema["schema"] == schema_def

@pytest.mark.asyncio
async def test_default_text_format(self):
text = OpenAIResponseText()
result = await convert_response_text_to_chat_response_format(text)

@@ -214,27 +201,22 @@

class TestGetMessageTypeByRole:
@pytest.mark.asyncio
async def test_user_role(self):
result = await get_message_type_by_role("user")
assert result == OpenAIUserMessageParam

@pytest.mark.asyncio
async def test_system_role(self):
result = await get_message_type_by_role("system")
assert result == OpenAISystemMessageParam

@pytest.mark.asyncio
async def test_assistant_role(self):
result = await get_message_type_by_role("assistant")
assert result == OpenAIAssistantMessageParam

@pytest.mark.asyncio
async def test_developer_role(self):
result = await get_message_type_by_role("developer")
assert result == OpenAIDeveloperMessageParam

@pytest.mark.asyncio
async def test_unknown_role(self):
result = await get_message_type_by_role("unknown")
assert result is None
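The deletions above strip every per-test `@pytest.mark.asyncio` marker; that only works if pytest-asyncio is, presumably, running in auto mode, where bare coroutine tests are collected and awaited without explicit decoration. A minimal self-contained sketch of a test written that way, assuming `asyncio_mode = "auto"` in the pytest configuration (the helper below is a stand-in, not from the repo):

```python
# With pytest-asyncio in auto mode, a bare `async def test_*` is collected
# and awaited; no @pytest.mark.asyncio decorator is needed.
async def echo(value: str) -> str:
    # Stand-in coroutine; the real suite awaits conversion helpers instead.
    return value


async def test_echo_runs_without_marker():
    assert await echo("Simple string") == "Simple string"
```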
@@ -6,7 +6,7 @@

import asyncio
import json
import logging
import logging  # allow-direct-logging
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
uv.lock (generated, 16 lines changed)

@@ -1719,7 +1719,7 @@ wheels = [

[[package]]
name = "llama-stack"
version = "0.2.17"
version = "0.2.18"
source = { editable = "." }
dependencies = [
{ name = "aiohttp" },

@@ -1809,6 +1809,7 @@ test = [
{ name = "chardet" },
{ name = "datasets" },
{ name = "mcp" },
{ name = "milvus-lite" },
{ name = "openai" },
{ name = "pymilvus" },
{ name = "pypdf" },

@@ -1831,6 +1832,7 @@ unit = [
{ name = "faiss-cpu" },
{ name = "litellm" },
{ name = "mcp" },
{ name = "milvus-lite" },
{ name = "ollama" },
{ name = "openai" },
{ name = "pymilvus" },

@@ -1854,8 +1856,8 @@ requires-dist = [
{ name = "jinja2", specifier = ">=3.1.6" },
{ name = "jsonschema" },
{ name = "llama-api-client", specifier = ">=0.1.2" },
{ name = "llama-stack-client", specifier = ">=0.2.17" },
{ name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.17" },
{ name = "llama-stack-client", specifier = ">=0.2.18" },
{ name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.18" },
{ name = "openai", specifier = ">=1.99.6,<1.100.0" },
{ name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
{ name = "opentelemetry-sdk", specifier = ">=1.30.0" },

@@ -1925,6 +1927,7 @@ test = [
{ name = "chardet" },
{ name = "datasets" },
{ name = "mcp" },
{ name = "milvus-lite", specifier = ">=2.5.0" },
{ name = "openai" },
{ name = "pymilvus", specifier = ">=2.5.12" },
{ name = "pypdf" },

@@ -1946,6 +1949,7 @@ unit = [
{ name = "faiss-cpu" },
{ name = "litellm" },
{ name = "mcp" },
{ name = "milvus-lite", specifier = ">=2.5.0" },
{ name = "ollama" },
{ name = "openai" },
{ name = "pymilvus", specifier = ">=2.5.12" },

@@ -1959,7 +1963,7 @@ unit = [

[[package]]
name = "llama-stack-client"
version = "0.2.17"
version = "0.2.18"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },

@@ -1978,9 +1982,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c5/2a/bb2949d6a5c494d21da0c185d426e25eaa8016f8287b689249afc6c96fb5/llama_stack_client-0.2.17.tar.gz", hash = "sha256:1fe2070133c6356761e394fa346045e9b6b567d4c63157b9bc6be89b9a6e7a41", size = 257636, upload-time = "2025-08-05T01:42:55.911Z" }
sdist = { url = "https://files.pythonhosted.org/packages/69/da/5e5a745495f8a2b8ef24fc4d01fe9031aa2277c36447cb22192ec8c8cc1e/llama_stack_client-0.2.18.tar.gz", hash = "sha256:860c885c9e549445178ac55cc9422e6e2a91215ac7aff5aaccfb42f3ce07e79e", size = 277284, upload-time = "2025-08-19T22:12:09.106Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/81/fc/5eccc86b83c5ced3a3bca071d250a86ccafa4ff17546cf781deb7758ab74/llama_stack_client-0.2.17-py3-none-any.whl", hash = "sha256:336c32f8688700ff64717b8109f405dc87a990fbe310c2027ac9ed6d39d67d16", size = 350329, upload-time = "2025-08-05T01:42:54.381Z" },
{ url = "https://files.pythonhosted.org/packages/0a/e4/e97f8fdd8a07aa1efc7f7e37b5657d84357b664bf70dd1885a437edc0699/llama_stack_client-0.2.18-py3-none-any.whl", hash = "sha256:90f827d5476f7fc15fd993f1863af6a6e72bd064646bf6a99435eb43a1327f70", size = 367586, upload-time = "2025-08-19T22:12:07.899Z" },
]

[[package]]