Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-15 14:08:00 +00:00

chore(pre-commit): add pre-commit hook to enforce llama_stack logger usage

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>

This commit is contained in:
  parent 81ecaf6221
  commit acd40800cc

58 changed files with 1302 additions and 122 deletions
@@ -161,6 +161,25 @@ repos:
         pass_filenames: false
         require_serial: true
 
+      - id: check-log-usage
+        name: Ensure 'llama_stack.log' usage for logging
+        entry: bash
+        language: system
+        types: [python]
+        pass_filenames: true
+        args:
+          - -c
+          - |
+            matches=$(grep -EnH '^[^#]*\b(import\s+logging|from\s+logging\b)' "$@" | grep -v -e '#\s*allow-direct-logging' || true)
+            if [ -n "$matches" ]; then
+              # GitHub Actions annotation format
+              while IFS=: read -r file line_num rest; do
+                echo "::error file=$file,line=$line_num::Do not use 'import logging' or 'from logging import' in $file. Use the custom log instead: from llama_stack.log import get_logger; logger = get_logger(). If direct logging is truly needed, add: # allow-direct-logging"
+              done <<< "$matches"
+              exit 1
+            fi
+            exit 0
+
 ci:
     autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
     autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate
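The hook shells out to grep, but its logic is easy to reason about in isolation. Below is a minimal Python sketch of the same check (illustrative only; the actual hook is the bash entry above), showing what the regex matches and how the `# allow-direct-logging` escape comment suppresses a finding:

import re
import sys

IMPORT_RE = re.compile(r"^[^#]*\b(import\s+logging|from\s+logging\b)")
ESCAPE_RE = re.compile(r"#\s*allow-direct-logging")


def find_violations(path: str) -> list[tuple[int, str]]:
    violations = []
    with open(path, encoding="utf-8") as f:
        for line_num, line in enumerate(f, start=1):
            # Flag direct stdlib logging imports unless the line carries
            # the escape-hatch comment.
            if IMPORT_RE.search(line) and not ESCAPE_RE.search(line):
                violations.append((line_num, line.rstrip()))
    return violations


if __name__ == "__main__":
    failed = False
    for path in sys.argv[1:]:
        for line_num, line in find_violations(path):
            # Mirrors the hook's GitHub Actions annotation format.
            print(f"::error file={path},line={line_num}::use llama_stack.log.get_logger instead: {line}")
            failed = True
    sys.exit(1 if failed else 0)

Note that `[^#]*` cannot cross a `#`, so commented-out imports never trigger an error, while any genuinely needed direct import can be whitelisted by appending the escape comment to that line.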
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import importlib.resources
-import logging
 import sys
 
 from pydantic import BaseModel
@@ -17,9 +16,10 @@ from llama_stack.core.external import load_external_apis
 from llama_stack.core.utils.exec import run_command
 from llama_stack.core.utils.image_types import LlamaStackImageType
 from llama_stack.distributions.template import DistributionTemplate
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="core")
 
 # These are the dependencies needed by the distribution server.
 # `llama-stack` is automatically installed by the installation script.
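The remaining hunks repeat one mechanical migration: drop the stdlib import, import `get_logger` from `llama_stack.log`, and construct a module logger bound to a subsystem category. A minimal sketch of the resulting pattern in a new module (the category string names whatever subsystem the file belongs to, e.g. "core", "inference", "safety", as seen throughout the hunks below):

from llama_stack.log import get_logger

# Named per module, grouped per subsystem: the category lets verbosity be
# tuned for a whole subsystem at once instead of file by file.
log = get_logger(name=__name__, category="core")


def build_distribution() -> None:
    log.info("starting build")  # routed through llama_stack's logging config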
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import logging
 import textwrap
 from typing import Any
 
@@ -21,9 +20,10 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
 from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
 from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.core.utils.prompt_for_config import prompt_for_config
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, ProviderSpec
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="core")
 
 
 def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
@@ -7,7 +7,7 @@
 import asyncio
 import inspect
 import json
-import logging
+import logging  # allow-direct-logging
 import os
 import sys
 from concurrent.futures import ThreadPoolExecutor
@@ -48,6 +48,7 @@ from llama_stack.core.stack import (
 from llama_stack.core.utils.config import redact_sensitive_fields
 from llama_stack.core.utils.context import preserve_contexts_async_generator
 from llama_stack.core.utils.exec import in_notebook
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.telemetry.tracing import (
     CURRENT_TRACE_CONTEXT,
     end_trace,
@@ -55,7 +56,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
     start_trace,
 )
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="core")
 
 T = TypeVar("T")
 
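Here the stdlib import is kept deliberately and tagged with the escape comment, since this module works with the logging machinery itself rather than only emitting records. A sketch of how the escape hatch reads in practice (the handler-tweaking line is an illustrative assumption, not code from this diff):

import logging  # allow-direct-logging

# Direct stdlib access is legitimate when a module must manipulate logging
# itself, e.g. silencing a chatty third-party library (hypothetical example):
logging.getLogger("httpx").setLevel(logging.WARNING)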
@@ -6,15 +6,15 @@
 
 import contextvars
 import json
-import logging
 from contextlib import AbstractContextManager
 from typing import Any
 
 from llama_stack.core.datatypes import User
+from llama_stack.log import get_logger
 
 from .utils.dynamic import instantiate_class_type
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="core")
 
 # Context variable for request provider data and auth attributes
 PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
@@ -9,7 +9,7 @@ import asyncio
 import functools
 import inspect
 import json
-import logging
+import logging  # allow-direct-logging
 import os
 import ssl
 import sys
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
+import importlib
 import os
 import signal
 import subprocess
@@ -12,9 +12,9 @@ import sys
 
 from termcolor import cprint
 
-log = logging.getLogger(__name__)
+from llama_stack.log import get_logger
 
-import importlib
+log = get_logger(name=__name__, category="core")
 
 
 def formulate_run_args(image_type: str, image_name: str) -> list:
@@ -6,7 +6,6 @@
 
 import inspect
 import json
-import logging
 from enum import Enum
 from typing import Annotated, Any, Literal, Union, get_args, get_origin
 
@@ -14,7 +13,9 @@ from pydantic import BaseModel
 from pydantic.fields import FieldInfo
 from pydantic_core import PydanticUndefinedType
 
-log = logging.getLogger(__name__)
+from llama_stack.log import get_logger
 
+log = get_logger(name=__name__, category="core")
 
 
 def is_list_of_primitives(field_type):
@@ -4,11 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
+import logging  # allow-direct-logging
 import os
 import re
 import sys
-from logging.config import dictConfig
+from logging.config import dictConfig  # allow-direct-logging
 
 from rich.console import Console
 from rich.errors import MarkupError
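The logging module itself is necessarily exempt: it is the wrapper around stdlib logging, so both of its direct imports carry the escape comment. For intuition, a toy version of a category-aware `get_logger` might look like the sketch below; this is an assumption for illustration, not the project's actual implementation (which configures handlers via `dictConfig` and rich console output):

import logging  # allow-direct-logging


def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter:
    # Namespacing by category makes per-subsystem filtering possible, e.g.
    # logging.getLogger("llama_stack.inference").setLevel(logging.DEBUG).
    logger = logging.getLogger(f"llama_stack.{category}.{name}")
    return logging.LoggerAdapter(logger, {"category": category})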
@@ -13,14 +13,15 @@
 
 # Copyright (c) Meta Platforms, Inc. and its affiliates.
 import math
-from logging import getLogger
 
 import torch
 import torch.nn.functional as F
 
+from llama_stack.log import get_logger
 
 from .utils import get_negative_inf_value, to_2tuple
 
-logger = getLogger()
+logger = get_logger(name=__name__, category="models::llama")
 
 
 def resize_local_position_embedding(orig_pos_embed, grid_size):
@@ -13,7 +13,6 @@
 
 import math
 from collections import defaultdict
-from logging import getLogger
 from typing import Any
 
 import torch
@@ -21,9 +20,11 @@ import torchvision.transforms as tv
 from PIL import Image
 from torchvision.transforms import functional as F
 
+from llama_stack.log import get_logger
+
 IMAGE_RES = 224
 
-logger = getLogger()
+logger = get_logger(name=__name__, category="models::llama")
 
 
 class VariableSizeImageTransform:
@@ -3,8 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 import math
 from collections.abc import Callable
 from functools import partial
@@ -22,6 +20,8 @@ from PIL import Image as PIL_Image
 from torch import Tensor, nn
 from torch.distributed import _functional_collectives as funcol
 
+from llama_stack.log import get_logger
+
 from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis
 from .encoder_utils import (
     build_encoder_attention_mask,
@@ -34,9 +34,10 @@ from .encoder_utils import (
 from .image_transform import VariableSizeImageTransform
 from .utils import get_negative_inf_value, to_2tuple
 
-logger = logging.getLogger(__name__)
 MP_SCALE = 8
 
+logger = get_logger(name=__name__, category="models")
 
 
 def reduce_from_tensor_model_parallel_region(input_):
     """All-reduce the input tensor across model parallel group."""
@@ -771,7 +772,7 @@ class TilePositionEmbedding(nn.Module):
         if embed is not None:
             # reshape the weights to the correct shape
             nt_old, nt_old, _, w = embed.shape
-            logging.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
+            logger.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
             embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles)
             # assign the weights to the module
             state_dict[prefix + "embedding"] = embed_new
@@ -4,8 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
 
 from collections.abc import Collection, Iterator, Sequence, Set
-from logging import getLogger
 from pathlib import Path
 from typing import (
     Literal,
@@ -14,11 +14,9 @@ from typing import (
 
 import tiktoken
 
+from llama_stack.log import get_logger
 from llama_stack.models.llama.tokenizer_utils import load_bpe_file
 
-logger = getLogger(__name__)
-
-
 # The tiktoken tokenizer can handle <=400k chars without
 # pyo3_runtime.PanicException.
 TIKTOKEN_MAX_ENCODE_CHARS = 400_000
@@ -31,6 +29,8 @@ MAX_NO_WHITESPACES_CHARS = 25_000
 
 _INSTANCE = None
 
+logger = get_logger(name=__name__, category="models::llama")
+
 
 class Tokenizer:
     """
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 import os
 from collections.abc import Callable
 
@@ -13,11 +12,13 @@ from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
 from torch import Tensor, nn
 from torch.nn import functional as F
 
+from llama_stack.log import get_logger
+
 from ...datatypes import QuantizationMode
 from ..model import Transformer, TransformerBlock
 from ..moe import MoE
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="models")
 
 
 def swiglu_wrapper_no_reduce(
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 from collections.abc import Collection, Iterator, Sequence, Set
-from logging import getLogger
 from pathlib import Path
 from typing import (
     Literal,
@@ -14,11 +13,9 @@ from typing import (
 
 import tiktoken
 
+from llama_stack.log import get_logger
 from llama_stack.models.llama.tokenizer_utils import load_bpe_file
 
-logger = getLogger(__name__)
-
-
 # The tiktoken tokenizer can handle <=400k chars without
 # pyo3_runtime.PanicException.
 TIKTOKEN_MAX_ENCODE_CHARS = 400_000
@@ -101,6 +98,8 @@ BASIC_SPECIAL_TOKENS = [
     "<|fim_suffix|>",
 ]
 
+logger = get_logger(name=__name__, category="models::llama")
+
 
 class Tokenizer:
     """
@@ -6,9 +6,10 @@
 
 # type: ignore
 import collections
-import logging
 
-log = logging.getLogger(__name__)
+from llama_stack.log import get_logger
 
+log = get_logger(name=__name__, category="llama")
 
 try:
     import fbgemm_gpu.experimental.gen_ai  # noqa: F401
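The `try:` context line above is the start of an optional-dependency probe, a pattern that pairs naturally with the new scoped logger. A hedged sketch of the idiom (the message strings are illustrative assumptions, not taken from this hunk):

from llama_stack.log import get_logger

log = get_logger(name=__name__, category="llama")

try:
    # noqa: F401 because the import is only probed, not referenced directly.
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401

    log.info("Using efficient FP8 operators in FBGEMM.")
except ImportError:
    log.warning("FBGEMM not installed; efficient FP8 operators unavailable.")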
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 import uuid
 from collections.abc import AsyncGenerator
 from datetime import UTC, datetime
@@ -42,6 +41,7 @@ from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.datatypes import AccessRule
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
 from llama_stack.providers.utils.pagination import paginate_records
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore
@@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
 from .persistence import AgentInfo
 from .responses.openai_responses import OpenAIResponsesImpl
 
-logger = logging.getLogger()
+logger = get_logger(name=__name__, category="agents")
 
 
 class MetaReferenceAgentsImpl(Agents):

(One file's diff is suppressed because it is too large.)
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import json
-import logging
 import uuid
 from datetime import UTC, datetime
 
@@ -15,9 +14,10 @@ from llama_stack.core.access_control.access_control import AccessDeniedError, is
 from llama_stack.core.access_control.datatypes import AccessRule
 from llama_stack.core.datatypes import User
 from llama_stack.core.request_headers import get_authenticated_user
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import KVStore
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="agents")
 
 
 class AgentSessionInfo(Session):
@@ -5,13 +5,13 @@
 # the root directory of this source tree.
 
 import asyncio
-import logging
 
 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.telemetry import tracing
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="agents")
 
 
 class SafetyException(Exception):  # noqa: N818
@@ -12,7 +12,6 @@
 
 import copy
 import json
-import logging
 import multiprocessing
 import os
 import tempfile
@@ -32,13 +31,14 @@ from fairscale.nn.model_parallel.initialize import (
 from pydantic import BaseModel, Field
 from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import GenerationResult
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
 )
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="inference")
 
 
 class ProcessingMessageName(str, Enum):
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 from collections.abc import AsyncGenerator
 
 from llama_stack.apis.inference import (
@@ -21,6 +20,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import ModelType
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
@@ -32,7 +32,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 
 from .config import SentenceTransformersInferenceConfig
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="inference")
 
 
 class SentenceTransformersInferenceImpl(
@@ -6,7 +6,6 @@
 
 import gc
 import json
-import logging
 import multiprocessing
 from pathlib import Path
 from typing import Any
@@ -28,6 +27,7 @@ from llama_stack.apis.post_training import (
     LoraFinetuningConfig,
     TrainingConfig,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 
 from ..config import HuggingFacePostTrainingConfig
@@ -44,7 +44,7 @@ from ..utils import (
     split_dataset,
 )
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="post_training")
 
 
 class HFFinetuningSingleDevice:
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import gc
-import logging
 import multiprocessing
 from pathlib import Path
 from typing import Any
@@ -24,6 +23,7 @@ from llama_stack.apis.post_training import (
     DPOAlignmentConfig,
     TrainingConfig,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 
 from ..config import HuggingFacePostTrainingConfig
@@ -40,7 +40,7 @@ from ..utils import (
     split_dataset,
 )
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="post_training")
 
 
 class HFDPOAlignmentSingleDevice:
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 import os
 import signal
 import sys
@@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM
 
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.post_training import Checkpoint, TrainingConfig
+from llama_stack.log import get_logger
 
 from .config import HuggingFacePostTrainingConfig
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="post_training")
 
 
 def setup_environment():
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 import os
 import time
 from datetime import UTC, datetime
@@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import modules, training
 from torchtune import utils as torchtune_utils
 from torchtune.data import padded_collate_sft
+from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
     get_adapter_params,
@@ -45,6 +45,7 @@ from llama_stack.apis.post_training import (
 )
 from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
 from llama_stack.core.utils.model_utils import model_local_dir
+from llama_stack.log import get_logger
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 from llama_stack.providers.inline.post_training.torchtune.common import utils
@@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="post_training")
 
-from torchtune.models.llama3._tokenizer import Llama3Tokenizer
-
 
 class LoraFinetuningSingleDevice:
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 from typing import Any
 
 from llama_stack.apis.inference import Message
@@ -15,13 +14,14 @@ from llama_stack.apis.safety import (
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 
 from .config import CodeScannerConfig
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="safety")
 
 ALLOWED_CODE_SCANNER_MODEL_IDS = [
     "CodeScanner",
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 import re
 import uuid
 from string import Template
@@ -25,6 +24,7 @@ from llama_stack.apis.safety import (
 from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
 from llama_stack.apis.shields import Shield
 from llama_stack.core.datatypes import Api
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import Role
 from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
@@ -137,6 +137,8 @@ Provide your safety assessment for ONLY THE LAST $agent_type message in the abov
 
 PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}")
 
+logger = get_logger(name=__name__, category="safety")
+
 
 class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
     def __init__(self, config: LlamaGuardConfig, deps) -> None:
@@ -412,7 +414,7 @@ class LlamaGuardShield:
         unsafe_code_list = [code.strip() for code in unsafe_code.split(",")]
         invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP]
         if invalid_codes:
-            logging.warning(f"Invalid safety codes returned: {invalid_codes}")
+            logger.warning(f"Invalid safety codes returned: {invalid_codes}")
             # just returning safe object, as we don't know what the invalid codes can map to
             return ModerationObject(
                 id=f"modr-{uuid.uuid4()}",
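The last hunk above is the one behavioral fix in this file: `logging.warning(...)` wrote through the root logger and bypassed the module's named logger entirely, whereas `logger.warning(...)` keeps the record attributable and filterable under the safety category. The difference, sketched:

import logging  # allow-direct-logging (demonstration only)

from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="safety")

# Root-logger call: ignores the module logger's name and level and, if the
# root logger has no handlers, implicitly calls logging.basicConfig().
logging.warning("Invalid safety codes returned: ['S99']")

# Scoped call: carries the module name and category, obeys configured levels.
logger.warning("Invalid safety codes returned: ['S99']")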
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 from typing import Any
 
 import torch
@@ -21,6 +20,7 @@ from llama_stack.apis.safety import (
 from llama_stack.apis.safety.safety import ModerationObject
 from llama_stack.apis.shields import Shield
 from llama_stack.core.utils.model_utils import model_local_dir
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
@@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import PromptGuardConfig, PromptGuardType
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="safety")
 
 PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
 
@@ -7,7 +7,6 @@
 import collections
 import functools
 import json
-import logging
 import random
 import re
 import string
@@ -20,7 +19,9 @@ import nltk
 from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai
 from pythainlp.tokenize import word_tokenize as word_tokenize_thai
 
-logger = logging.getLogger()
+from llama_stack.log import get_logger
 
+logger = get_logger(name=__name__, category="scoring")
 
 WORD_LIST = [
     "western",
@@ -4,13 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 import threading
 from typing import Any
 
 from opentelemetry import metrics, trace
 
-logger = logging.getLogger(__name__)
-
 from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
 from opentelemetry.sdk.metrics import MeterProvider
@@ -40,6 +37,7 @@ from llama_stack.apis.telemetry import (
     UnstructuredLogEvent,
 )
 from llama_stack.core.datatypes import Api
+from llama_stack.log import get_logger
 from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
     ConsoleSpanProcessor,
 )
@@ -61,6 +59,8 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
 _global_lock = threading.Lock()
 _TRACER_PROVIDER = None
 
+logger = get_logger(name=__name__, category="telemetry")
+
 
 def is_tracing_enabled(tracer):
     with tracer.start_as_current_span("check_tracing") as span:
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import asyncio
-import logging
 import secrets
 import string
 from typing import Any
@@ -32,6 +31,7 @@ from llama_stack.apis.tools import (
     ToolRuntime,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
 from llama_stack.providers.utils.memory.vector_store import (
@@ -42,7 +42,7 @@ from llama_stack.providers.utils.memory.vector_store import (
 from .config import RagToolRuntimeConfig
 from .context_retriever import generate_rag_query
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="tool_runtime")
 
 
 def make_random_string(length: int = 8):
@@ -8,7 +8,6 @@ import asyncio
 import base64
 import io
 import json
-import logging
 from typing import Any
 
 import faiss
@@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
     QueryChunksResponse,
     VectorIO,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
     HealthResponse,
     HealthStatus,
@@ -40,7 +40,7 @@ from llama_stack.providers.utils.memory.vector_store import (
 
 from .config import FaissVectorIOConfig
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="vector_io")
 
 VERSION = "v3"
 VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import asyncio
-import logging
 import re
 import sqlite3
 import struct
@@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
     QueryChunksResponse,
     VectorIO,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
@@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
     VectorDBWithIndex,
 )
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="vector_io")
 
 # Specifying search mode is dependent on the VectorIO provider.
 VECTOR_SEARCH = "vector"
@@ -3,15 +3,14 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import logging
+from llama_stack.log import get_logger
 
 from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .models import MODEL_ENTRIES
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="inference")
 
 
 class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 import warnings
 from collections.abc import AsyncIterator
 
@@ -33,6 +32,7 @@ from llama_stack.apis.inference import (
     ToolChoice,
     ToolConfig,
 )
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
@@ -54,7 +54,7 @@ from .openai_utils import (
 )
 from .utils import _is_nvidia_hosted
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="inference")
 
 
 class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
@@ -4,13 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 
 import httpx
 
+from llama_stack.log import get_logger
 
 from . import NVIDIAConfig
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="inference")
 
 
 def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
@@ -4,15 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
+from llama_stack.log import get_logger
 
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .config import OpenAIConfig
 from .models import MODEL_ENTRIES
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="inference")
 
 
 #
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 
-import logging
 from collections.abc import AsyncGenerator
 
 from huggingface_hub import AsyncInferenceClient, HfApi
@@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
+from llama_stack.log import get_logger
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
@@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="inference")
 
 
 def build_hf_repo_model_entries():
@@ -4,18 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 import warnings
 from typing import Any
 
 from pydantic import BaseModel
 
 from llama_stack.apis.post_training import TrainingConfig
+from llama_stack.log import get_logger
 from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig
 
 from .config import NvidiaPostTrainingConfig
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="integration")
 
 
 def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import json
-import logging
 from typing import Any
 
 from llama_stack.apis.inference import Message
@@ -16,12 +15,13 @@ from llama_stack.apis.safety import (
     ViolationLevel,
 )
 from llama_stack.apis.shields import Shield
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 
 from .config import BedrockSafetyConfig
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="safety")
 
 
 class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 from typing import Any
 
 import requests
@@ -12,12 +11,13 @@ import requests
 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
 from llama_stack.apis.shields import Shield
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
 
 from .config import NVIDIASafetyConfig
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="safety")
 
 
 class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import json
-import logging
 from typing import Any
 
 import litellm
@@ -20,12 +19,13 @@ from llama_stack.apis.safety import (
 )
 from llama_stack.apis.shields import Shield
 from llama_stack.core.request_headers import NeedsRequestProviderData
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
 
 from .config import SambaNovaSafetyConfig
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="safety")
 
 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
 
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 import asyncio
 import json
-import logging
 from typing import Any
 from urllib.parse import urlparse
 
@@ -20,6 +19,7 @@ from llama_stack.apis.vector_io import (
     QueryChunksResponse,
     VectorIO,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
 from llama_stack.providers.utils.kvstore import kvstore_impl
@@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import (
 
 from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="vector_io")
 
 ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
 

@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import asyncio
-import logging
 import os
 from typing import Any
 
@@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import (
     QueryChunksResponse,
     VectorIO,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
 from llama_stack.providers.utils.kvstore import kvstore_impl
@@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
 
 from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="vector_io")
 
 VERSION = "v3"
 VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 from typing import Any
 
 import psycopg2
@@ -22,6 +21,7 @@ from llama_stack.apis.vector_io import (
     QueryChunksResponse,
     VectorIO,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
@@ -34,7 +34,7 @@ from llama_stack.providers.utils.memory.vector_store import (
 
 from .config import PGVectorVectorIOConfig
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="vector_io")
 
 VERSION = "v3"
 VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"

@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import asyncio
-import logging
 import uuid
 from typing import Any
 
@@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreChunkingStrategy,
     VectorStoreFileObject,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
 from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
@@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
 
 from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="vector_io")
 CHUNK_ID_KEY = "_chunk_id"
 
 # KV store prefixes for vector databases

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import json
-import logging
 from typing import Any
 
 import weaviate
@@ -19,6 +18,7 @@ from llama_stack.apis.files.files import Files
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.core.request_headers import NeedsRequestProviderData
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
@@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
 
 from .config import WeaviateVectorIOConfig
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="vector_io")
 
 VERSION = "v3"
 VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"

@@ -5,10 +5,11 @@
 # the root directory of this source tree.
 
 import base64
-import logging
 import struct
 from typing import TYPE_CHECKING
 
+from llama_stack.log import get_logger
+
 if TYPE_CHECKING:
     from sentence_transformers import SentenceTransformer
 
@@ -27,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
 EMBEDDING_MODELS = {}
 
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="inference")
 
 
 class SentenceTransformerEmbeddingMixin:

@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 import base64
 import json
-import logging
 import struct
 import time
 import uuid
@@ -116,6 +115,7 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.inference import (
     OpenAIChoice as OpenAIChatCompletionChoice,
 )
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     StopReason,
@@ -128,7 +128,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     decode_assistant_message,
 )
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="inference")
 
 
 class OpenAICompatCompletionChoiceDelta(BaseModel):

@@ -4,16 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 from datetime import datetime
 
 from pymongo import AsyncMongoClient
 
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import KVStore
 
 from ..config import MongoDBKVStoreConfig
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="kvstore")
 
 
 class MongoDBKVStoreImpl(KVStore):

@@ -4,16 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 from datetime import datetime
 
 import psycopg2
 from psycopg2.extras import DictCursor
 
+from llama_stack.log import get_logger
+
 from ..api import KVStore
 from ..config import PostgresKVStoreConfig
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="kvstore")
 
 
 class PostgresKVStoreImpl(KVStore):

@@ -44,7 +44,7 @@ from llama_stack.providers.utils.memory.vector_store import (
     make_overlapped_chunks,
 )
 
-logger = get_logger(__name__, category="vector_io")
+logger = get_logger(name=__name__, category="memory")
 
 # Constants for OpenAI vector stores
 CHUNK_MULTIPLIER = 5

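This hunk differs from the rest: the file was already on get_logger, so the edit only normalizes the call to keyword arguments and moves the logger from the "vector_io" category into "memory", matching the vector_store utilities in the next file. Assuming categories are plain strings chosen per subsystem (nothing in this diff enforces a fixed list), the normalized call shape is:

    logger = get_logger(name=__name__, category="memory")  # keyword args, corrected category
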
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 import base64
 import io
-import logging
 import re
 import time
 from abc import ABC, abstractmethod
@@ -26,6 +25,7 @@ from llama_stack.apis.common.content_types import (
 from llama_stack.apis.tools import RAGDocument
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
+from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import Api
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -33,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )
 from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
 
-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="memory")
 
 
 class ChunkForDeletion(BaseModel):

@@ -6,7 +6,7 @@
 
 import asyncio
 import contextvars
-import logging
+import logging  # allow-direct-logging
 import queue
 import random
 import sys

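This file keeps its stdlib import on purpose: the pre-commit hook added in this commit skips any line tagged with '# allow-direct-logging', the escape hatch for modules that operate on the logging machinery itself (handlers, queues, record interception) rather than merely emitting messages. A sketch of the exemption in use (the handler wiring below is illustrative, not taken from this diff):

    import logging  # allow-direct-logging

    # legitimate direct use: configuring the logging machinery itself,
    # e.g. attaching a stream handler, not just emitting records
    handler = logging.StreamHandler()
    logging.getLogger().addHandler(handler)
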
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 import sys
 import time
 import uuid
@@ -19,10 +18,10 @@ from llama_stack.apis.post_training import (
     LoraFinetuningConfig,
     TrainingConfig,
 )
+from llama_stack.log import get_logger
 
 # Configure logging
-logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="post_training")
 
 
 skip_because_resource_intensive = pytest.mark.skip(

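One behavioral nuance in this test file: the removed logging.basicConfig(..., force=True) call reconfigured the root logger globally, while the replacement leans on get_logger for formatting and level handling, so the test no longer clobbers whatever the stack configured. If a test truly needed the old global setup, the exempted form would presumably look like this (hypothetical, shown only for contrast):

    import logging  # allow-direct-logging

    # force=True rebinds root handlers even if logging was already configured
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
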
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import logging
 import time
 from io import BytesIO
 
@@ -14,8 +13,9 @@ from openai import BadRequestError as OpenAIBadRequestError
 
 from llama_stack.apis.vector_io import Chunk
 from llama_stack.core.library_client import LlamaStackAsLibraryClient
+from llama_stack.log import get_logger
 
-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="vector_io")
 
 
 def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):

@@ -6,7 +6,7 @@
 
 import asyncio
 import json
-import logging
+import logging  # allow-direct-logging
 import threading
 import time
 from http.server import BaseHTTPRequestHandler, HTTPServer