mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 14:08:00 +00:00
feat: add auto-generated CI documentation pre-commit hook (#2890)
Our CI is entirely undocumented, this commit adds a README.md file with a table of the current CI and what is does --------- Signed-off-by: Nathan Weinberg <nweinber@redhat.com>
This commit is contained in:
parent
7f834339ba
commit
b381ed6d64
93 changed files with 495 additions and 477 deletions
|
@ -145,6 +145,24 @@ repos:
|
|||
pass_filenames: false
|
||||
require_serial: true
|
||||
files: ^.github/workflows/.*$
|
||||
- id: check-log-usage
|
||||
name: Check for proper log usage (use llama_stack.log instead)
|
||||
entry: bash
|
||||
language: system
|
||||
types: [python]
|
||||
pass_filenames: true
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
matches=$(grep -EnH '^[^#]*\b(import logging|from logging\b)' "$@" | grep -v '# allow-direct-logging' || true)
|
||||
if [ -n "$matches" ]; then
|
||||
# GitHub Actions annotation format
|
||||
while IFS=: read -r file line_num rest; do
|
||||
echo "::error file=$file,line=$line_num::Do not use 'import logging' or 'from logging import' in $file. Use the custom log instead: from llama_stack.log import get_logger; logger = get_logger(). If direct logging is truly needed, add: # allow-direct-logging"
|
||||
done <<< "$matches"
|
||||
exit 1
|
||||
fi
|
||||
exit 0
|
||||
|
||||
ci:
|
||||
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
|
||||
|
|
|
@ -15,7 +15,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
logger = get_logger(name=__name__, category="server")
|
||||
log = get_logger(name=__name__, category="server")
|
||||
|
||||
|
||||
class StackRun(Subcommand):
|
||||
|
@ -126,7 +126,7 @@ class StackRun(Subcommand):
|
|||
self.parser.error("Config file is required for venv environment")
|
||||
|
||||
if config_file:
|
||||
logger.info(f"Using run configuration: {config_file}")
|
||||
log.info(f"Using run configuration: {config_file}")
|
||||
|
||||
try:
|
||||
config_dict = yaml.safe_load(config_file.read_text())
|
||||
|
@ -145,7 +145,7 @@ class StackRun(Subcommand):
|
|||
# If neither image type nor image name is provided, assume the server should be run directly
|
||||
# using the current environment packages.
|
||||
if not image_type and not image_name:
|
||||
logger.info("No image type or image name provided. Assuming environment packages.")
|
||||
log.info("No image type or image name provided. Assuming environment packages.")
|
||||
from llama_stack.core.server.server import main as server_main
|
||||
|
||||
# Build the server args from the current args passed to the CLI
|
||||
|
@ -185,11 +185,11 @@ class StackRun(Subcommand):
|
|||
run_command(run_args)
|
||||
|
||||
def _start_ui_development_server(self, stack_server_port: int):
|
||||
logger.info("Attempting to start UI development server...")
|
||||
log.info("Attempting to start UI development server...")
|
||||
# Check if npm is available
|
||||
npm_check = subprocess.run(["npm", "--version"], capture_output=True, text=True, check=False)
|
||||
if npm_check.returncode != 0:
|
||||
logger.warning(
|
||||
log.warning(
|
||||
f"'npm' command not found or not executable. UI development server will not be started. Error: {npm_check.stderr}"
|
||||
)
|
||||
return
|
||||
|
@ -214,13 +214,13 @@ class StackRun(Subcommand):
|
|||
stderr=stderr_log_file,
|
||||
env={**os.environ, "NEXT_PUBLIC_LLAMA_STACK_BASE_URL": f"http://localhost:{stack_server_port}"},
|
||||
)
|
||||
logger.info(f"UI development server process started in {ui_dir} with PID {process.pid}.")
|
||||
logger.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}")
|
||||
logger.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}")
|
||||
log.info(f"UI development server process started in {ui_dir} with PID {process.pid}.")
|
||||
log.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}")
|
||||
log.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}")
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
log.error(
|
||||
"Failed to start UI development server: 'npm' command not found. Make sure npm is installed and in your PATH."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start UI development server in {ui_dir}: {e}")
|
||||
log.error(f"Failed to start UI development server in {ui_dir}: {e}")
|
||||
|
|
|
@ -8,7 +8,7 @@ import argparse
|
|||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="cli")
|
||||
log = get_logger(name=__name__, category="cli")
|
||||
|
||||
|
||||
# TODO: this can probably just be inlined now?
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import importlib.resources
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
@ -17,10 +16,9 @@ from llama_stack.core.external import load_external_apis
|
|||
from llama_stack.core.utils.exec import run_command
|
||||
from llama_stack.core.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.distributions.template import DistributionTemplate
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# These are the dependencies needed by the distribution server.
|
||||
# `llama-stack` is automatically installed by the installation script.
|
||||
SERVER_DEPENDENCIES = [
|
||||
|
@ -33,6 +31,8 @@ SERVER_DEPENDENCIES = [
|
|||
"opentelemetry-exporter-otlp-proto-http",
|
||||
]
|
||||
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class ApiInput(BaseModel):
|
||||
api: Api
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import logging
|
||||
import textwrap
|
||||
from typing import Any
|
||||
|
||||
|
@ -21,9 +20,10 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
|||
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.core.utils.prompt_for_config import prompt_for_config
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
|
||||
|
@ -49,7 +49,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
|
|||
is_nux = len(config.providers) == 0
|
||||
|
||||
if is_nux:
|
||||
logger.info(
|
||||
log.info(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
Llama Stack is composed of several APIs working together. For each API served by the Stack,
|
||||
|
@ -75,12 +75,12 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
|
|||
|
||||
existing_providers = config.providers.get(api_str, [])
|
||||
if existing_providers:
|
||||
logger.info(f"Re-configuring existing providers for API `{api_str}`...")
|
||||
log.info(f"Re-configuring existing providers for API `{api_str}`...")
|
||||
updated_providers = []
|
||||
for p in existing_providers:
|
||||
logger.info(f"> Configuring provider `({p.provider_type})`")
|
||||
log.info(f"> Configuring provider `({p.provider_type})`")
|
||||
updated_providers.append(configure_single_provider(provider_registry[api], p))
|
||||
logger.info("")
|
||||
log.info("")
|
||||
else:
|
||||
# we are newly configuring this API
|
||||
plist = build_spec.providers.get(api_str, [])
|
||||
|
@ -89,17 +89,17 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
|
|||
if not plist:
|
||||
raise ValueError(f"No provider configured for API {api_str}?")
|
||||
|
||||
logger.info(f"Configuring API `{api_str}`...")
|
||||
log.info(f"Configuring API `{api_str}`...")
|
||||
updated_providers = []
|
||||
for i, provider in enumerate(plist):
|
||||
if i >= 1:
|
||||
others = ", ".join(p.provider_type for p in plist[i:])
|
||||
logger.info(
|
||||
log.info(
|
||||
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
|
||||
)
|
||||
break
|
||||
|
||||
logger.info(f"> Configuring provider `({provider.provider_type})`")
|
||||
log.info(f"> Configuring provider `({provider.provider_type})`")
|
||||
pid = provider.provider_type.split("::")[-1]
|
||||
updated_providers.append(
|
||||
configure_single_provider(
|
||||
|
@ -111,7 +111,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
|
|||
),
|
||||
)
|
||||
)
|
||||
logger.info("")
|
||||
log.info("")
|
||||
|
||||
config.providers[api_str] = updated_providers
|
||||
|
||||
|
@ -169,7 +169,7 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
|
|||
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
|
||||
|
||||
if "routing_table" in config_dict:
|
||||
logger.info("Upgrading config...")
|
||||
log.info("Upgrading config...")
|
||||
config_dict = upgrade_from_routing_table(config_dict)
|
||||
|
||||
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
|
|
|
@ -23,7 +23,7 @@ from llama_stack.providers.datatypes import (
|
|||
remote_provider_spec,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def stack_apis() -> list[Api]:
|
||||
|
@ -141,18 +141,18 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
|
|||
registry: dict[Api, dict[str, ProviderSpec]] = {}
|
||||
for api in providable_apis():
|
||||
name = api.name.lower()
|
||||
logger.debug(f"Importing module {name}")
|
||||
log.debug(f"Importing module {name}")
|
||||
try:
|
||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||
registry[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
except ImportError as e:
|
||||
logger.warning(f"Failed to import module {name}: {e}")
|
||||
log.warning(f"Failed to import module {name}: {e}")
|
||||
|
||||
# Refresh providable APIs with external APIs if any
|
||||
external_apis = load_external_apis(config)
|
||||
for api, api_spec in external_apis.items():
|
||||
name = api_spec.name.lower()
|
||||
logger.info(f"Importing external API {name} module {api_spec.module}")
|
||||
log.info(f"Importing external API {name} module {api_spec.module}")
|
||||
try:
|
||||
module = importlib.import_module(api_spec.module)
|
||||
registry[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
|
@ -161,7 +161,7 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
|
|||
# This assume that the in-tree provider(s) are not available for this API which means
|
||||
# that users will need to use external providers for this API.
|
||||
registry[api] = {}
|
||||
logger.error(
|
||||
log.error(
|
||||
f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
|
||||
"Install the API package to load any in-tree providers for this API."
|
||||
)
|
||||
|
@ -183,13 +183,13 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
|
|||
def get_external_providers_from_dir(
|
||||
registry: dict[Api, dict[str, ProviderSpec]], config
|
||||
) -> dict[Api, dict[str, ProviderSpec]]:
|
||||
logger.warning(
|
||||
log.warning(
|
||||
"Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead."
|
||||
)
|
||||
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
|
||||
if not os.path.exists(external_providers_dir):
|
||||
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
|
||||
logger.info(f"Loading external providers from {external_providers_dir}")
|
||||
log.info(f"Loading external providers from {external_providers_dir}")
|
||||
|
||||
for api in providable_apis():
|
||||
api_name = api.name.lower()
|
||||
|
@ -198,13 +198,13 @@ def get_external_providers_from_dir(
|
|||
for provider_type in ["remote", "inline"]:
|
||||
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
|
||||
if not os.path.exists(api_dir):
|
||||
logger.debug(f"No {provider_type} provider directory found for {api_name}")
|
||||
log.debug(f"No {provider_type} provider directory found for {api_name}")
|
||||
continue
|
||||
|
||||
# Look for provider spec files in the API directory
|
||||
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
|
||||
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
|
||||
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
|
||||
log.info(f"Loading {provider_type} provider spec from {spec_path}")
|
||||
|
||||
try:
|
||||
with open(spec_path) as f:
|
||||
|
@ -217,16 +217,16 @@ def get_external_providers_from_dir(
|
|||
spec = _load_inline_provider_spec(spec_data, api, provider_name)
|
||||
provider_type_key = f"inline::{provider_name}"
|
||||
|
||||
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
|
||||
log.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
|
||||
if provider_type_key in registry[api]:
|
||||
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
|
||||
log.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
|
||||
registry[api][provider_type_key] = spec
|
||||
logger.info(f"Successfully loaded external provider {provider_type_key}")
|
||||
log.info(f"Successfully loaded external provider {provider_type_key}")
|
||||
except yaml.YAMLError as yaml_err:
|
||||
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
|
||||
log.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
|
||||
raise yaml_err
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
|
||||
log.error(f"Failed to load provider spec from {spec_path}: {e}")
|
||||
raise e
|
||||
|
||||
return registry
|
||||
|
@ -241,7 +241,7 @@ def get_external_providers_from_module(
|
|||
else:
|
||||
provider_list = config.providers.items()
|
||||
if provider_list is None:
|
||||
logger.warning("Could not get list of providers from config")
|
||||
log.warning("Could not get list of providers from config")
|
||||
return registry
|
||||
for provider_api, providers in provider_list:
|
||||
for provider in providers:
|
||||
|
@ -272,6 +272,6 @@ def get_external_providers_from_module(
|
|||
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
|
||||
) from exc
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load provider spec from module {provider.module}: {e}")
|
||||
log.error(f"Failed to load provider spec from module {provider.module}: {e}")
|
||||
raise e
|
||||
return registry
|
||||
|
|
|
@ -11,7 +11,7 @@ from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
|||
from llama_stack.core.datatypes import BuildConfig, StackRunConfig
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
|
||||
|
@ -28,10 +28,10 @@ def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api,
|
|||
|
||||
external_apis_dir = config.external_apis_dir.expanduser().resolve()
|
||||
if not external_apis_dir.is_dir():
|
||||
logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
|
||||
log.error(f"External APIs directory is not a directory: {external_apis_dir}")
|
||||
return {}
|
||||
|
||||
logger.info(f"Loading external APIs from {external_apis_dir}")
|
||||
log.info(f"Loading external APIs from {external_apis_dir}")
|
||||
external_apis: dict[Api, ExternalApiSpec] = {}
|
||||
|
||||
# Look for YAML files in the external APIs directory
|
||||
|
@ -42,13 +42,13 @@ def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api,
|
|||
|
||||
spec = ExternalApiSpec(**spec_data)
|
||||
api = Api.add(spec.name)
|
||||
logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
|
||||
log.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
|
||||
external_apis[api] = spec
|
||||
except yaml.YAMLError as yaml_err:
|
||||
logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
|
||||
log.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(f"Failed to load external API spec from {yaml_path}")
|
||||
log.exception(f"Failed to load external API spec from {yaml_path}")
|
||||
raise
|
||||
|
||||
return external_apis
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
@ -48,6 +47,7 @@ from llama_stack.core.stack import (
|
|||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
from llama_stack.core.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.core.utils.exec import in_notebook
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
CURRENT_TRACE_CONTEXT,
|
||||
end_trace,
|
||||
|
@ -55,7 +55,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
start_trace,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
@ -84,7 +84,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
|
|||
try:
|
||||
return [convert_to_pydantic(item_type, item) for item in value]
|
||||
except Exception:
|
||||
logger.error(f"Error converting list {value} into {item_type}")
|
||||
log.error(f"Error converting list {value} into {item_type}")
|
||||
return value
|
||||
|
||||
elif origin is dict:
|
||||
|
@ -92,7 +92,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
|
|||
try:
|
||||
return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
|
||||
except Exception:
|
||||
logger.error(f"Error converting dict {value} into {val_type}")
|
||||
log.error(f"Error converting dict {value} into {val_type}")
|
||||
return value
|
||||
|
||||
try:
|
||||
|
@ -108,7 +108,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
|
|||
return convert_to_pydantic(union_type, value)
|
||||
except Exception:
|
||||
continue
|
||||
logger.warning(
|
||||
log.warning(
|
||||
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
|
||||
)
|
||||
raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e
|
||||
|
@ -171,13 +171,15 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
|
||||
def _remove_root_logger_handlers(self):
|
||||
"""
|
||||
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
|
||||
Remove all handlers from the root log. Needed to avoid polluting the console with logs.
|
||||
"""
|
||||
import logging # allow-direct-logging
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
for handler in root_logger.handlers[:]:
|
||||
root_logger.removeHandler(handler)
|
||||
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
|
||||
log.info(f"Removed handler {handler.__class__.__name__} from root log")
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
loop = self.loop
|
||||
|
|
|
@ -16,7 +16,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus
|
|||
from .datatypes import StackRunConfig
|
||||
from .utils.config import redact_sensitive_fields
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class ProviderImplConfig(BaseModel):
|
||||
|
@ -38,7 +38,7 @@ class ProviderImpl(Providers):
|
|||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("ProviderImpl.shutdown")
|
||||
log.debug("ProviderImpl.shutdown")
|
||||
pass
|
||||
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
|
|
|
@ -6,19 +6,19 @@
|
|||
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Context variable for request provider data and auth attributes
|
||||
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
||||
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class RequestProviderDataContext(AbstractContextManager):
|
||||
"""Context manager for request provider data"""
|
||||
|
|
|
@ -54,7 +54,7 @@ from llama_stack.providers.datatypes import (
|
|||
VectorDBsProtocolPrivate,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class InvalidProviderError(Exception):
|
||||
|
@ -101,7 +101,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
|
|||
|
||||
protocols[api] = api_class
|
||||
except (ImportError, AttributeError):
|
||||
logger.exception(f"Failed to load external API {api_spec.name}")
|
||||
log.exception(f"Failed to load external API {api_spec.name}")
|
||||
|
||||
return protocols
|
||||
|
||||
|
@ -223,7 +223,7 @@ def validate_and_prepare_providers(
|
|||
specs = {}
|
||||
for provider in providers:
|
||||
if not provider.provider_id or provider.provider_id == "__disabled__":
|
||||
logger.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
|
||||
log.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
|
||||
continue
|
||||
|
||||
validate_provider(provider, api, provider_registry)
|
||||
|
@ -245,10 +245,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
|
|||
|
||||
p = provider_registry[api][provider.provider_type]
|
||||
if p.deprecation_error:
|
||||
logger.error(p.deprecation_error)
|
||||
log.error(p.deprecation_error)
|
||||
raise InvalidProviderError(p.deprecation_error)
|
||||
elif p.deprecation_warning:
|
||||
logger.warning(
|
||||
log.warning(
|
||||
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
||||
)
|
||||
|
||||
|
@ -261,9 +261,9 @@ def sort_providers_by_deps(
|
|||
{k: list(v.values()) for k, v in providers_with_specs.items()}
|
||||
)
|
||||
|
||||
logger.debug(f"Resolved {len(sorted_providers)} providers")
|
||||
log.debug(f"Resolved {len(sorted_providers)} providers")
|
||||
for api_str, provider in sorted_providers:
|
||||
logger.debug(f" {api_str} => {provider.provider_id}")
|
||||
log.debug(f" {api_str} => {provider.provider_id}")
|
||||
return sorted_providers
|
||||
|
||||
|
||||
|
@ -348,7 +348,7 @@ async def instantiate_provider(
|
|||
if not hasattr(provider_spec, "module") or provider_spec.module is None:
|
||||
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
||||
|
||||
logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
|
||||
log.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
args = []
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
|
@ -418,7 +418,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
|||
obj_params = set(obj_sig.parameters)
|
||||
obj_params.discard("self")
|
||||
if not (proto_params <= obj_params):
|
||||
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
|
||||
log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
|
||||
missing_methods.append((name, "signature_mismatch"))
|
||||
else:
|
||||
# Check if the method has a concrete implementation (not just a protocol stub)
|
||||
|
|
|
@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class DatasetIORouter(DatasetIO):
|
||||
|
@ -20,15 +20,15 @@ class DatasetIORouter(DatasetIO):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing DatasetIORouter")
|
||||
log.debug("Initializing DatasetIORouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("DatasetIORouter.initialize")
|
||||
log.debug("DatasetIORouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("DatasetIORouter.shutdown")
|
||||
log.debug("DatasetIORouter.shutdown")
|
||||
pass
|
||||
|
||||
async def register_dataset(
|
||||
|
@ -38,7 +38,7 @@ class DatasetIORouter(DatasetIO):
|
|||
metadata: dict[str, Any] | None = None,
|
||||
dataset_id: str | None = None,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
log.debug(
|
||||
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
|
||||
)
|
||||
await self.routing_table.register_dataset(
|
||||
|
@ -54,7 +54,7 @@ class DatasetIORouter(DatasetIO):
|
|||
start_index: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> PaginatedResponse:
|
||||
logger.debug(
|
||||
log.debug(
|
||||
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(dataset_id)
|
||||
|
@ -65,7 +65,7 @@ class DatasetIORouter(DatasetIO):
|
|||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
||||
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
|
||||
log.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
|
||||
provider = await self.routing_table.get_provider_impl(dataset_id)
|
||||
return await provider.append_rows(
|
||||
dataset_id=dataset_id,
|
||||
|
|
|
@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class ScoringRouter(Scoring):
|
||||
|
@ -24,15 +24,15 @@ class ScoringRouter(Scoring):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing ScoringRouter")
|
||||
log.debug("Initializing ScoringRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("ScoringRouter.initialize")
|
||||
log.debug("ScoringRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("ScoringRouter.shutdown")
|
||||
log.debug("ScoringRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def score_batch(
|
||||
|
@ -41,7 +41,7 @@ class ScoringRouter(Scoring):
|
|||
scoring_functions: dict[str, ScoringFnParams | None] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
|
||||
log.debug(f"ScoringRouter.score_batch: {dataset_id}")
|
||||
res = {}
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
provider = await self.routing_table.get_provider_impl(fn_identifier)
|
||||
|
@ -63,7 +63,7 @@ class ScoringRouter(Scoring):
|
|||
input_rows: list[dict[str, Any]],
|
||||
scoring_functions: dict[str, ScoringFnParams | None] = None,
|
||||
) -> ScoreResponse:
|
||||
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
|
||||
log.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
|
||||
res = {}
|
||||
# look up and map each scoring function to its provider impl
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
|
@ -82,15 +82,15 @@ class EvalRouter(Eval):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing EvalRouter")
|
||||
log.debug("Initializing EvalRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("EvalRouter.initialize")
|
||||
log.debug("EvalRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("EvalRouter.shutdown")
|
||||
log.debug("EvalRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def run_eval(
|
||||
|
@ -98,7 +98,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
|
||||
log.debug(f"EvalRouter.run_eval: {benchmark_id}")
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
return await provider.run_eval(
|
||||
benchmark_id=benchmark_id,
|
||||
|
@ -112,7 +112,7 @@ class EvalRouter(Eval):
|
|||
scoring_functions: list[str],
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
||||
log.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
return await provider.evaluate_rows(
|
||||
benchmark_id=benchmark_id,
|
||||
|
@ -126,7 +126,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> Job:
|
||||
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
||||
log.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
return await provider.job_status(benchmark_id, job_id)
|
||||
|
||||
|
@ -135,7 +135,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> None:
|
||||
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
|
||||
log.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
await provider.job_cancel(
|
||||
benchmark_id,
|
||||
|
@ -147,7 +147,7 @@ class EvalRouter(Eval):
|
|||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> EvaluateResponse:
|
||||
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
|
||||
log.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
|
||||
provider = await self.routing_table.get_provider_impl(benchmark_id)
|
||||
return await provider.job_result(
|
||||
benchmark_id,
|
||||
|
|
|
@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
|||
from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class InferenceRouter(Inference):
|
||||
|
@ -70,7 +70,7 @@ class InferenceRouter(Inference):
|
|||
telemetry: Telemetry | None = None,
|
||||
store: InferenceStore | None = None,
|
||||
) -> None:
|
||||
logger.debug("Initializing InferenceRouter")
|
||||
log.debug("Initializing InferenceRouter")
|
||||
self.routing_table = routing_table
|
||||
self.telemetry = telemetry
|
||||
self.store = store
|
||||
|
@ -79,10 +79,10 @@ class InferenceRouter(Inference):
|
|||
self.formatter = ChatFormat(self.tokenizer)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("InferenceRouter.initialize")
|
||||
log.debug("InferenceRouter.initialize")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("InferenceRouter.shutdown")
|
||||
log.debug("InferenceRouter.shutdown")
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
|
@ -92,7 +92,7 @@ class InferenceRouter(Inference):
|
|||
metadata: dict[str, Any] | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
log.debug(
|
||||
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
|
||||
)
|
||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||
|
@ -117,7 +117,7 @@ class InferenceRouter(Inference):
|
|||
"""
|
||||
span = get_current_span()
|
||||
if span is None:
|
||||
logger.warning("No span found for token usage metrics")
|
||||
log.warning("No span found for token usage metrics")
|
||||
return []
|
||||
metrics = [
|
||||
("prompt_tokens", prompt_tokens),
|
||||
|
@ -182,7 +182,7 @@ class InferenceRouter(Inference):
|
|||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
logger.debug(
|
||||
log.debug(
|
||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||
)
|
||||
if sampling_params is None:
|
||||
|
@ -288,7 +288,7 @@ class InferenceRouter(Inference):
|
|||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchChatCompletionResponse:
|
||||
logger.debug(
|
||||
log.debug(
|
||||
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
|
@ -313,7 +313,7 @@ class InferenceRouter(Inference):
|
|||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
logger.debug(
|
||||
log.debug(
|
||||
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
|
||||
)
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
|
@ -374,7 +374,7 @@ class InferenceRouter(Inference):
|
|||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchCompletionResponse:
|
||||
logger.debug(
|
||||
log.debug(
|
||||
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
|
@ -388,7 +388,7 @@ class InferenceRouter(Inference):
|
|||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
logger.debug(f"InferenceRouter.embeddings: {model_id}")
|
||||
log.debug(f"InferenceRouter.embeddings: {model_id}")
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ModelNotFoundError(model_id)
|
||||
|
@ -426,7 +426,7 @@ class InferenceRouter(Inference):
|
|||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
logger.debug(
|
||||
log.debug(
|
||||
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
|
||||
)
|
||||
model_obj = await self.routing_table.get_model(model)
|
||||
|
@ -487,7 +487,7 @@ class InferenceRouter(Inference):
|
|||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
logger.debug(
|
||||
log.debug(
|
||||
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
|
||||
)
|
||||
model_obj = await self.routing_table.get_model(model)
|
||||
|
@ -558,7 +558,7 @@ class InferenceRouter(Inference):
|
|||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
logger.debug(
|
||||
log.debug(
|
||||
f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
|
||||
)
|
||||
model_obj = await self.routing_table.get_model(model)
|
||||
|
|
|
@ -14,7 +14,7 @@ from llama_stack.apis.shields import Shield
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class SafetyRouter(Safety):
|
||||
|
@ -22,15 +22,15 @@ class SafetyRouter(Safety):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing SafetyRouter")
|
||||
log.debug("Initializing SafetyRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("SafetyRouter.initialize")
|
||||
log.debug("SafetyRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("SafetyRouter.shutdown")
|
||||
log.debug("SafetyRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def register_shield(
|
||||
|
@ -40,7 +40,7 @@ class SafetyRouter(Safety):
|
|||
provider_id: str | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> Shield:
|
||||
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||
log.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
|
@ -53,7 +53,7 @@ class SafetyRouter(Safety):
|
|||
messages: list[Message],
|
||||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||
log.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||
provider = await self.routing_table.get_provider_impl(shield_id)
|
||||
return await provider.run_shield(
|
||||
shield_id=shield_id,
|
||||
|
|
|
@ -22,7 +22,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class ToolRuntimeRouter(ToolRuntime):
|
||||
|
@ -31,7 +31,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
self,
|
||||
routing_table: ToolGroupsRoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
|
||||
log.debug("Initializing ToolRuntimeRouter.RagToolImpl")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def query(
|
||||
|
@ -40,7 +40,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
vector_db_ids: list[str],
|
||||
query_config: RAGQueryConfig | None = None,
|
||||
) -> RAGQueryResult:
|
||||
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||
log.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||
provider = await self.routing_table.get_provider_impl("knowledge_search")
|
||||
return await provider.query(content, vector_db_ids, query_config)
|
||||
|
||||
|
@ -50,7 +50,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
log.debug(
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl("insert_into_memory")
|
||||
|
@ -60,7 +60,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
self,
|
||||
routing_table: ToolGroupsRoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing ToolRuntimeRouter")
|
||||
log.debug("Initializing ToolRuntimeRouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
|
||||
|
@ -69,15 +69,15 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("ToolRuntimeRouter.initialize")
|
||||
log.debug("ToolRuntimeRouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("ToolRuntimeRouter.shutdown")
|
||||
log.debug("ToolRuntimeRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
|
||||
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||
log.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||
provider = await self.routing_table.get_provider_impl(tool_name)
|
||||
return await provider.invoke_tool(
|
||||
tool_name=tool_name,
|
||||
|
@ -87,5 +87,5 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
) -> ListToolsResponse:
|
||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
log.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
return await self.routing_table.list_tools(tool_group_id)
|
||||
|
|
|
@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class VectorIORouter(VectorIO):
|
||||
|
@ -40,15 +40,15 @@ class VectorIORouter(VectorIO):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing VectorIORouter")
|
||||
log.debug("Initializing VectorIORouter")
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("VectorIORouter.initialize")
|
||||
log.debug("VectorIORouter.initialize")
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("VectorIORouter.shutdown")
|
||||
log.debug("VectorIORouter.shutdown")
|
||||
pass
|
||||
|
||||
async def _get_first_embedding_model(self) -> tuple[str, int] | None:
|
||||
|
@ -70,10 +70,10 @@ class VectorIORouter(VectorIO):
|
|||
raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
|
||||
return embedding_models[0].identifier, dimension
|
||||
else:
|
||||
logger.warning("No embedding models found in the routing table")
|
||||
log.warning("No embedding models found in the routing table")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting embedding models: {e}")
|
||||
log.error(f"Error getting embedding models: {e}")
|
||||
return None
|
||||
|
||||
async def register_vector_db(
|
||||
|
@ -85,7 +85,7 @@ class VectorIORouter(VectorIO):
|
|||
vector_db_name: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> None:
|
||||
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
|
||||
log.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
|
||||
await self.routing_table.register_vector_db(
|
||||
vector_db_id,
|
||||
embedding_model,
|
||||
|
@ -101,7 +101,7 @@ class VectorIORouter(VectorIO):
|
|||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
log.debug(
|
||||
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(vector_db_id)
|
||||
|
@ -113,7 +113,7 @@ class VectorIORouter(VectorIO):
|
|||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||
log.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||
provider = await self.routing_table.get_provider_impl(vector_db_id)
|
||||
return await provider.query_chunks(vector_db_id, query, params)
|
||||
|
||||
|
@ -129,7 +129,7 @@ class VectorIORouter(VectorIO):
|
|||
embedding_dimension: int | None = None,
|
||||
provider_id: str | None = None,
|
||||
) -> VectorStoreObject:
|
||||
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
|
||||
log.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
|
||||
|
||||
# If no embedding model is provided, use the first available one
|
||||
if embedding_model is None:
|
||||
|
@ -137,7 +137,7 @@ class VectorIORouter(VectorIO):
|
|||
if embedding_model_info is None:
|
||||
raise ValueError("No embedding model provided and no embedding models available in the system")
|
||||
embedding_model, embedding_dimension = embedding_model_info
|
||||
logger.info(f"No embedding model specified, using first available: {embedding_model}")
|
||||
log.info(f"No embedding model specified, using first available: {embedding_model}")
|
||||
|
||||
vector_db_id = f"vs_{uuid.uuid4()}"
|
||||
registered_vector_db = await self.routing_table.register_vector_db(
|
||||
|
@ -168,7 +168,7 @@ class VectorIORouter(VectorIO):
|
|||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
) -> VectorStoreListResponse:
|
||||
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
|
||||
log.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
|
||||
# Route to default provider for now - could aggregate from all providers in the future
|
||||
# call retrieve on each vector dbs to get list of vector stores
|
||||
vector_dbs = await self.routing_table.get_all_with_type("vector_db")
|
||||
|
@ -179,7 +179,7 @@ class VectorIORouter(VectorIO):
|
|||
vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
|
||||
all_stores.append(vector_store)
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
|
||||
log.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
|
||||
continue
|
||||
|
||||
# Sort by created_at
|
||||
|
@ -215,7 +215,7 @@ class VectorIORouter(VectorIO):
|
|||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreObject:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
|
||||
log.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
|
||||
return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
|
||||
|
||||
async def openai_update_vector_store(
|
||||
|
@ -225,7 +225,7 @@ class VectorIORouter(VectorIO):
|
|||
expires_after: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> VectorStoreObject:
|
||||
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
|
||||
log.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
|
||||
return await self.routing_table.openai_update_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
name=name,
|
||||
|
@ -237,7 +237,7 @@ class VectorIORouter(VectorIO):
|
|||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
|
||||
log.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
|
||||
return await self.routing_table.openai_delete_vector_store(vector_store_id)
|
||||
|
||||
async def openai_search_vector_store(
|
||||
|
@ -250,7 +250,7 @@ class VectorIORouter(VectorIO):
|
|||
rewrite_query: bool | None = False,
|
||||
search_mode: str | None = "vector",
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
|
||||
log.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
|
||||
return await self.routing_table.openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
|
@ -268,7 +268,7 @@ class VectorIORouter(VectorIO):
|
|||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
|
||||
log.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
|
||||
return await self.routing_table.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
|
@ -285,7 +285,7 @@ class VectorIORouter(VectorIO):
|
|||
before: str | None = None,
|
||||
filter: VectorStoreFileStatus | None = None,
|
||||
) -> list[VectorStoreFileObject]:
|
||||
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
|
||||
log.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
|
||||
return await self.routing_table.openai_list_files_in_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
limit=limit,
|
||||
|
@ -300,7 +300,7 @@ class VectorIORouter(VectorIO):
|
|||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileObject:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
|
||||
log.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
|
||||
return await self.routing_table.openai_retrieve_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
|
@ -311,7 +311,7 @@ class VectorIORouter(VectorIO):
|
|||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
|
||||
log.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
|
||||
return await self.routing_table.openai_retrieve_vector_store_file_contents(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
|
@ -323,7 +323,7 @@ class VectorIORouter(VectorIO):
|
|||
file_id: str,
|
||||
attributes: dict[str, Any],
|
||||
) -> VectorStoreFileObject:
|
||||
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
|
||||
log.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
|
||||
return await self.routing_table.openai_update_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
|
@ -335,7 +335,7 @@ class VectorIORouter(VectorIO):
|
|||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileDeleteResponse:
|
||||
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
|
||||
log.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
|
||||
return await self.routing_table.openai_delete_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
|
|
|
@ -14,7 +14,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
||||
|
|
|
@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def get_impl_api(p: Any) -> Api:
|
||||
|
@ -177,7 +177,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
# Check if user has permission to access this object
|
||||
if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()):
|
||||
logger.debug(f"Access denied to {type} '{identifier}'")
|
||||
log.debug(f"Access denied to {type} '{identifier}'")
|
||||
return None
|
||||
|
||||
return obj
|
||||
|
@ -205,7 +205,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
raise AccessDeniedError("create", obj, creator)
|
||||
if creator:
|
||||
obj.owner = creator
|
||||
logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
|
||||
log.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
|
||||
|
||||
registered_obj = await register_object_with_provider(obj, p)
|
||||
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||
|
@ -250,7 +250,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
|
|||
if model is not None:
|
||||
return model
|
||||
|
||||
logger.warning(
|
||||
log.warning(
|
||||
f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
|
||||
"searching in all providers. This is only for backwards compatibility and will stop working "
|
||||
"soon. Migrate your calls to use fully scoped `provider_id/model_id` names."
|
||||
|
|
|
@ -26,7 +26,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||
|
|
|
@ -17,7 +17,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
|
@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
try:
|
||||
models = await provider.list_models()
|
||||
except Exception as e:
|
||||
logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
log.exception(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
continue
|
||||
|
||||
self.listed_providers.add(provider_id)
|
||||
|
@ -132,7 +132,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
model_ids[model.provider_resource_id] = model.identifier
|
||||
continue
|
||||
|
||||
logger.debug(f"unregistering model {model.identifier}")
|
||||
log.debug(f"unregistering model {model.identifier}")
|
||||
await self.unregister_object(model)
|
||||
|
||||
for model in models:
|
||||
|
@ -143,7 +143,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
if model.identifier == model.provider_resource_id:
|
||||
model.identifier = f"{provider_id}/{model.provider_resource_id}"
|
||||
|
||||
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
|
||||
log.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
|
||||
await self.register_object(
|
||||
ModelWithOwner(
|
||||
identifier=model.identifier,
|
||||
|
|
|
@ -19,7 +19,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||
|
|
|
@ -15,7 +15,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
|
|
|
@ -14,7 +14,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
|
||||
|
|
|
@ -30,7 +30,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||
|
@ -57,7 +57,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
if len(self.impls_by_provider_id) > 0:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
if len(self.impls_by_provider_id) > 1:
|
||||
logger.warning(
|
||||
log.warning(
|
||||
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
|
||||
)
|
||||
else:
|
||||
|
|
|
@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
|
|||
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
log = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
|
@ -105,13 +105,13 @@ class AuthenticationMiddleware:
|
|||
try:
|
||||
validation_result = await self.auth_provider.validate_token(token, scope)
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Authentication request timed out")
|
||||
log.exception("Authentication request timed out")
|
||||
return await self._send_auth_error(send, "Authentication service timeout")
|
||||
except ValueError as e:
|
||||
logger.exception("Error during authentication")
|
||||
log.exception("Error during authentication")
|
||||
return await self._send_auth_error(send, str(e))
|
||||
except Exception:
|
||||
logger.exception("Error during authentication")
|
||||
log.exception("Error during authentication")
|
||||
return await self._send_auth_error(send, "Authentication service error")
|
||||
|
||||
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
|
||||
|
@ -122,7 +122,7 @@ class AuthenticationMiddleware:
|
|||
scope["principal"] = validation_result.principal
|
||||
if validation_result.attributes:
|
||||
scope["user_attributes"] = validation_result.attributes
|
||||
logger.debug(
|
||||
log.debug(
|
||||
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
|
||||
)
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
|
|||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
log = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
|
@ -163,7 +163,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Token introspection failed with status code: {response.status_code}")
|
||||
log.warning(f"Token introspection failed with status code: {response.status_code}")
|
||||
raise ValueError(f"Token introspection failed: {response.status_code}")
|
||||
|
||||
fields = response.json()
|
||||
|
@ -176,13 +176,13 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
attributes=access_attributes,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Token introspection request timed out")
|
||||
log.exception("Token introspection request timed out")
|
||||
raise
|
||||
except ValueError:
|
||||
# Re-raise ValueError exceptions to preserve their message
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error during token introspection")
|
||||
log.exception("Error during token introspection")
|
||||
raise ValueError("Token introspection error") from e
|
||||
|
||||
async def close(self):
|
||||
|
@ -273,7 +273,7 @@ class CustomAuthProvider(AuthProvider):
|
|||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Authentication failed with status code: {response.status_code}")
|
||||
log.warning(f"Authentication failed with status code: {response.status_code}")
|
||||
raise ValueError(f"Authentication failed: {response.status_code}")
|
||||
|
||||
# Parse and validate the auth response
|
||||
|
@ -282,17 +282,17 @@ class CustomAuthProvider(AuthProvider):
|
|||
auth_response = AuthResponse(**response_data)
|
||||
return User(principal=auth_response.principal, attributes=auth_response.attributes)
|
||||
except Exception as e:
|
||||
logger.exception("Error parsing authentication response")
|
||||
log.exception("Error parsing authentication response")
|
||||
raise ValueError("Invalid authentication response format") from e
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Authentication request timed out")
|
||||
log.exception("Authentication request timed out")
|
||||
raise
|
||||
except ValueError:
|
||||
# Re-raise ValueError exceptions to preserve their message
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error during authentication")
|
||||
log.exception("Error during authentication")
|
||||
raise ValueError("Authentication service error") from e
|
||||
|
||||
async def close(self):
|
||||
|
@ -329,7 +329,7 @@ class GitHubTokenAuthProvider(AuthProvider):
|
|||
try:
|
||||
user_info = await _get_github_user_info(token, self.config.github_api_base_url)
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.warning(f"GitHub token validation failed: {e}")
|
||||
log.warning(f"GitHub token validation failed: {e}")
|
||||
raise ValueError("GitHub token validation failed. Please check your token and try again.") from e
|
||||
|
||||
principal = user_info["user"]["login"]
|
||||
|
|
|
@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
|
|||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
|
||||
|
||||
logger = get_logger(name=__name__, category="quota")
|
||||
log = get_logger(name=__name__, category="quota")
|
||||
|
||||
|
||||
class QuotaMiddleware:
|
||||
|
@ -46,7 +46,7 @@ class QuotaMiddleware:
|
|||
self.window_seconds = window_seconds
|
||||
|
||||
if isinstance(self.kv_config, SqliteKVStoreConfig):
|
||||
logger.warning(
|
||||
log.warning(
|
||||
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
|
||||
f"window_seconds={self.window_seconds}"
|
||||
)
|
||||
|
@ -84,11 +84,11 @@ class QuotaMiddleware:
|
|||
else:
|
||||
await kv.set(key, str(count))
|
||||
except Exception:
|
||||
logger.exception("Failed to access KV store for quota")
|
||||
log.exception("Failed to access KV store for quota")
|
||||
return await self._send_error(send, 500, "Quota service error")
|
||||
|
||||
if count > limit:
|
||||
logger.warning(
|
||||
log.warning(
|
||||
"Quota exceeded for client %s: %d/%d",
|
||||
key_id,
|
||||
count,
|
||||
|
|
|
@ -9,7 +9,7 @@ import asyncio
|
|||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import logging # allow-direct-logging
|
||||
import os
|
||||
import ssl
|
||||
import sys
|
||||
|
@ -80,7 +80,7 @@ from .quota import QuotaMiddleware
|
|||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
logger = get_logger(name=__name__, category="server")
|
||||
log = get_logger(name=__name__, category="server")
|
||||
|
||||
|
||||
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
|
||||
|
@ -157,9 +157,9 @@ async def shutdown(app):
|
|||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logger.info("Starting up")
|
||||
log.info("Starting up")
|
||||
yield
|
||||
logger.info("Shutting down")
|
||||
log.info("Shutting down")
|
||||
await shutdown(app)
|
||||
|
||||
|
||||
|
@ -182,11 +182,11 @@ async def sse_generator(event_gen_coroutine):
|
|||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Generator cancelled")
|
||||
log.info("Generator cancelled")
|
||||
if event_gen:
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
logger.exception("Error in sse_generator")
|
||||
log.exception("Error in sse_generator")
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
|
@ -206,11 +206,11 @@ async def log_request_pre_validation(request: Request):
|
|||
log_output = rich.pretty.pretty_repr(parsed_body)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
log_output = repr(body_bytes)
|
||||
logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
|
||||
log.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
|
||||
else:
|
||||
logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
|
||||
log.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
|
||||
log.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
||||
|
@ -238,10 +238,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
result.url = route
|
||||
return result
|
||||
except Exception as e:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.exception(f"Error executing endpoint {route=} {method=}")
|
||||
if log.isEnabledFor(logging.DEBUG):
|
||||
log.exception(f"Error executing endpoint {route=} {method=}")
|
||||
else:
|
||||
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
|
||||
log.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
|
||||
raise translate_exception(e) from e
|
||||
|
||||
sig = inspect.signature(func)
|
||||
|
@ -291,7 +291,7 @@ class TracingMiddleware:
|
|||
# Check if the path is a FastAPI built-in path
|
||||
if path.startswith(self.fastapi_paths):
|
||||
# Pass through to FastAPI's built-in handlers
|
||||
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
||||
log.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
|
@ -303,7 +303,7 @@ class TracingMiddleware:
|
|||
)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
||||
log.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
trace_attributes = {"__location__": "server", "raw_path": path}
|
||||
|
@ -404,15 +404,15 @@ def main(args: argparse.Namespace | None = None):
|
|||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
logger = get_logger(name=__name__, category="server", config=logger_config)
|
||||
log = get_logger(name=__name__, category="server", config=logger_config)
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting CLI environment variable {key} => {value}")
|
||||
log.info(f"Setting CLI environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
log.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
config = replace_env_vars(config_contents)
|
||||
config = StackRunConfig(**cast_image_name_to_string(config))
|
||||
|
@ -438,16 +438,16 @@ def main(args: argparse.Namespace | None = None):
|
|||
impls = loop.run_until_complete(construct_stack(config))
|
||||
|
||||
except InvalidProviderError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
log.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if config.server.auth:
|
||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
|
||||
log.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
|
||||
else:
|
||||
if config.server.quota:
|
||||
quota = config.server.quota
|
||||
logger.warning(
|
||||
log.warning(
|
||||
"Configured authenticated_max_requests (%d) but no auth is enabled; "
|
||||
"falling back to anonymous_max_requests (%d) for all the requests",
|
||||
quota.authenticated_max_requests,
|
||||
|
@ -455,7 +455,7 @@ def main(args: argparse.Namespace | None = None):
|
|||
)
|
||||
|
||||
if config.server.quota:
|
||||
logger.info("Enabling quota middleware for authenticated and anonymous clients")
|
||||
log.info("Enabling quota middleware for authenticated and anonymous clients")
|
||||
|
||||
quota = config.server.quota
|
||||
anonymous_max_requests = quota.anonymous_max_requests
|
||||
|
@ -516,7 +516,7 @@ def main(args: argparse.Namespace | None = None):
|
|||
if not available_methods:
|
||||
raise ValueError(f"No methods found for {route.name} on {impl}")
|
||||
method = available_methods[0]
|
||||
logger.debug(f"{method} {route.path}")
|
||||
log.debug(f"{method} {route.path}")
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
|
||||
|
@ -528,7 +528,7 @@ def main(args: argparse.Namespace | None = None):
|
|||
)
|
||||
)
|
||||
|
||||
logger.debug(f"serving APIs: {apis_to_serve}")
|
||||
log.debug(f"serving APIs: {apis_to_serve}")
|
||||
|
||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
|
@ -553,21 +553,21 @@ def main(args: argparse.Namespace | None = None):
|
|||
if config.server.tls_cafile:
|
||||
ssl_config["ssl_ca_certs"] = config.server.tls_cafile
|
||||
ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
|
||||
logger.info(
|
||||
log.info(
|
||||
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
log.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
|
||||
listen_host = config.server.host or ["::", "0.0.0.0"]
|
||||
logger.info(f"Listening on {listen_host}:{port}")
|
||||
log.info(f"Listening on {listen_host}:{port}")
|
||||
|
||||
uvicorn_config = {
|
||||
"app": app,
|
||||
"host": listen_host,
|
||||
"port": port,
|
||||
"lifespan": "on",
|
||||
"log_level": logger.getEffectiveLevel(),
|
||||
"log_level": log.getEffectiveLevel(),
|
||||
"log_config": logger_config,
|
||||
}
|
||||
if ssl_config:
|
||||
|
@ -586,19 +586,19 @@ def main(args: argparse.Namespace | None = None):
|
|||
try:
|
||||
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("Received interrupt signal, shutting down gracefully...")
|
||||
log.info("Received interrupt signal, shutting down gracefully...")
|
||||
finally:
|
||||
if not loop.is_closed():
|
||||
logger.debug("Closing event loop")
|
||||
log.debug("Closing event loop")
|
||||
loop.close()
|
||||
|
||||
|
||||
def _log_run_config(run_config: StackRunConfig):
|
||||
"""Logs the run config with redacted fields and disabled providers removed."""
|
||||
logger.info("Run configuration:")
|
||||
log.info("Run configuration:")
|
||||
safe_config = redact_sensitive_fields(run_config.model_dump(mode="json"))
|
||||
clean_config = remove_disabled_providers(safe_config)
|
||||
logger.info(yaml.dump(clean_config, indent=2))
|
||||
log.info(yaml.dump(clean_config, indent=2))
|
||||
|
||||
|
||||
def extract_path_params(route: str) -> list[str]:
|
||||
|
|
|
@ -45,7 +45,7 @@ from llama_stack.core.utils.dynamic import instantiate_class_type
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class LlamaStack(
|
||||
|
@ -105,11 +105,11 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
|||
|
||||
method = getattr(impls[api], register_method)
|
||||
for obj in objects:
|
||||
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
|
||||
log.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
|
||||
|
||||
# Do not register models on disabled providers
|
||||
if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"):
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
|
||||
log.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
|
||||
continue
|
||||
|
||||
# we want to maintain the type information in arguments to method.
|
||||
|
@ -123,7 +123,7 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
|||
objects_to_process = response.data if hasattr(response, "data") else response
|
||||
|
||||
for obj in objects_to_process:
|
||||
logger.debug(
|
||||
log.debug(
|
||||
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
|
||||
)
|
||||
|
||||
|
@ -160,7 +160,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
try:
|
||||
resolved_provider_id = replace_env_vars(v["provider_id"], f"{path}[{i}].provider_id")
|
||||
if resolved_provider_id == "__disabled__":
|
||||
logger.debug(
|
||||
log.debug(
|
||||
f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}"
|
||||
)
|
||||
# Create a copy with resolved provider_id but original config
|
||||
|
@ -315,7 +315,7 @@ async def construct_stack(
|
|||
TEST_RECORDING_CONTEXT = setup_inference_recording()
|
||||
if TEST_RECORDING_CONTEXT:
|
||||
TEST_RECORDING_CONTEXT.__enter__()
|
||||
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
|
||||
log.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
|
||||
|
||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||
policy = run_config.server.auth.access_policy if run_config.server.auth else []
|
||||
|
@ -337,12 +337,12 @@ async def construct_stack(
|
|||
import traceback
|
||||
|
||||
if task.cancelled():
|
||||
logger.error("Model refresh task cancelled")
|
||||
log.error("Model refresh task cancelled")
|
||||
elif task.exception():
|
||||
logger.error(f"Model refresh task failed: {task.exception()}")
|
||||
log.error(f"Model refresh task failed: {task.exception()}")
|
||||
traceback.print_exception(task.exception())
|
||||
else:
|
||||
logger.debug("Model refresh task completed")
|
||||
log.debug("Model refresh task completed")
|
||||
|
||||
REGISTRY_REFRESH_TASK.add_done_callback(cb)
|
||||
return impls
|
||||
|
@ -351,23 +351,23 @@ async def construct_stack(
|
|||
async def shutdown_stack(impls: dict[Api, Any]):
|
||||
for impl in impls.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info(f"Shutting down {impl_name}")
|
||||
log.info(f"Shutting down {impl_name}")
|
||||
try:
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning(f"No shutdown method for {impl_name}")
|
||||
log.warning(f"No shutdown method for {impl_name}")
|
||||
except TimeoutError:
|
||||
logger.exception(f"Shutdown timeout for {impl_name}")
|
||||
log.exception(f"Shutdown timeout for {impl_name}")
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
logger.exception(f"Failed to shutdown {impl_name}: {e}")
|
||||
log.exception(f"Failed to shutdown {impl_name}: {e}")
|
||||
|
||||
global TEST_RECORDING_CONTEXT
|
||||
if TEST_RECORDING_CONTEXT:
|
||||
try:
|
||||
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference recording cleanup: {e}")
|
||||
log.error(f"Error during inference recording cleanup: {e}")
|
||||
|
||||
global REGISTRY_REFRESH_TASK
|
||||
if REGISTRY_REFRESH_TASK:
|
||||
|
@ -375,14 +375,14 @@ async def shutdown_stack(impls: dict[Api, Any]):
|
|||
|
||||
|
||||
async def refresh_registry_once(impls: dict[Api, Any]):
|
||||
logger.debug("refreshing registry")
|
||||
log.debug("refreshing registry")
|
||||
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
|
||||
for routing_table in routing_tables:
|
||||
await routing_table.refresh()
|
||||
|
||||
|
||||
async def refresh_registry_task(impls: dict[Api, Any]):
|
||||
logger.info("starting registry refresh task")
|
||||
log.info("starting registry refresh task")
|
||||
while True:
|
||||
await refresh_registry_once(impls)
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from pathlib import Path
|
|||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="config_resolution")
|
||||
log = get_logger(name=__name__, category="config_resolution")
|
||||
|
||||
|
||||
DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"
|
||||
|
@ -42,25 +42,25 @@ def resolve_config_or_distro(
|
|||
# Strategy 1: Try as file path first
|
||||
config_path = Path(config_or_distro)
|
||||
if config_path.exists() and config_path.is_file():
|
||||
logger.info(f"Using file path: {config_path}")
|
||||
log.info(f"Using file path: {config_path}")
|
||||
return config_path.resolve()
|
||||
|
||||
# Strategy 2: Try as distribution name (if no .yaml extension)
|
||||
if not config_or_distro.endswith(".yaml"):
|
||||
distro_config = _get_distro_config_path(config_or_distro, mode)
|
||||
if distro_config.exists():
|
||||
logger.info(f"Using distribution: {distro_config}")
|
||||
log.info(f"Using distribution: {distro_config}")
|
||||
return distro_config
|
||||
|
||||
# Strategy 3: Try as built distribution name
|
||||
distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
|
||||
if distrib_config.exists():
|
||||
logger.info(f"Using built distribution: {distrib_config}")
|
||||
log.info(f"Using built distribution: {distrib_config}")
|
||||
return distrib_config
|
||||
|
||||
distrib_config = DISTRIBS_BASE_DIR / f"{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
|
||||
if distrib_config.exists():
|
||||
logger.info(f"Using built distribution: {distrib_config}")
|
||||
log.info(f"Using built distribution: {distrib_config}")
|
||||
return distrib_config
|
||||
|
||||
# Strategy 4: Failed - provide helpful error
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import importlib
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
|
@ -12,9 +12,9 @@ import sys
|
|||
|
||||
from termcolor import cprint
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
import importlib
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def formulate_run_args(image_type: str, image_name: str) -> list:
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Literal, Union, get_args, get_origin
|
||||
|
||||
|
@ -14,7 +13,9 @@ from pydantic import BaseModel
|
|||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefinedType
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def is_list_of_primitives(field_type):
|
||||
|
|
|
@ -4,11 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import logging # allow-direct-logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from logging.config import dictConfig
|
||||
from logging.config import dictConfig # allow-direct-logging
|
||||
|
||||
from rich.console import Console
|
||||
from rich.errors import MarkupError
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
|
||||
# Copyright (c) Meta Platforms, Inc. and its affiliates.
|
||||
import math
|
||||
from logging import getLogger
|
||||
from logging import getLogger # allow-direct-logging
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from logging import getLogger
|
||||
from logging import getLogger # allow-direct-logging
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
|
@ -22,6 +21,8 @@ from PIL import Image as PIL_Image
|
|||
from torch import Tensor, nn
|
||||
from torch.distributed import _functional_collectives as funcol
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis
|
||||
from .encoder_utils import (
|
||||
build_encoder_attention_mask,
|
||||
|
@ -34,9 +35,10 @@ from .encoder_utils import (
|
|||
from .image_transform import VariableSizeImageTransform
|
||||
from .utils import get_negative_inf_value, to_2tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
MP_SCALE = 8
|
||||
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def reduce_from_tensor_model_parallel_region(input_):
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
|
@ -415,7 +417,7 @@ class VisionEncoder(nn.Module):
|
|||
)
|
||||
state_dict[prefix + "gated_positional_embedding"] = global_pos_embed
|
||||
state_dict[prefix + "gated_positional_embedding_gate"] = torch.zeros(1, dtype=global_pos_embed.dtype)
|
||||
logger.info(f"Initialized global positional embedding with size {global_pos_embed.size()}")
|
||||
log.info(f"Initialized global positional embedding with size {global_pos_embed.size()}")
|
||||
else:
|
||||
global_pos_embed = resize_global_position_embedding(
|
||||
state_dict[prefix + "gated_positional_embedding"],
|
||||
|
@ -423,7 +425,7 @@ class VisionEncoder(nn.Module):
|
|||
self.max_num_tiles,
|
||||
self.max_num_tiles,
|
||||
)
|
||||
logger.info(
|
||||
log.info(
|
||||
f"Resized global positional embedding from {state_dict[prefix + 'gated_positional_embedding'].size()} to {global_pos_embed.size()}"
|
||||
)
|
||||
state_dict[prefix + "gated_positional_embedding"] = global_pos_embed
|
||||
|
@ -771,7 +773,7 @@ class TilePositionEmbedding(nn.Module):
|
|||
if embed is not None:
|
||||
# reshape the weights to the correct shape
|
||||
nt_old, nt_old, _, w = embed.shape
|
||||
logging.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
|
||||
log.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
|
||||
embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles)
|
||||
# assign the weights to the module
|
||||
state_dict[prefix + "embedding"] = embed_new
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Collection, Iterator, Sequence, Set
|
||||
from logging import getLogger
|
||||
from logging import getLogger # allow-direct-logging
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Literal,
|
||||
|
|
|
@ -11,7 +11,7 @@ from llama_stack.log import get_logger
|
|||
|
||||
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
|
||||
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
|
||||
|
@ -215,7 +215,7 @@ class ToolUtils:
|
|||
# FIXME: Enable multiple tool calls
|
||||
return function_calls[0]
|
||||
else:
|
||||
logger.debug(f"Did not parse tool call from message body: {message_body}")
|
||||
log.debug(f"Did not parse tool call from message body: {message_body}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
|
||||
|
@ -13,11 +12,13 @@ from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
|
|||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from ...datatypes import QuantizationMode
|
||||
from ..model import Transformer, TransformerBlock
|
||||
from ..moe import MoE
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def swiglu_wrapper_no_reduce(
|
||||
|
@ -186,7 +187,7 @@ def logging_callbacks(
|
|||
if use_rich_progress:
|
||||
console.print(message)
|
||||
elif rank == 0: # Only log from rank 0 for non-rich logging
|
||||
log.info(message)
|
||||
logger.info(message)
|
||||
|
||||
total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block))
|
||||
progress = None
|
||||
|
@ -220,6 +221,6 @@ def logging_callbacks(
|
|||
if completed is not None:
|
||||
progress.update(task_id, completed=completed)
|
||||
elif rank == 0 and completed and completed % 10 == 0:
|
||||
log.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed")
|
||||
logger.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed")
|
||||
|
||||
return progress, log_status, update_status
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Collection, Iterator, Sequence, Set
|
||||
from logging import getLogger
|
||||
from logging import getLogger # allow-direct-logging
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Literal,
|
||||
|
|
|
@ -6,16 +6,17 @@
|
|||
|
||||
# type: ignore
|
||||
import collections
|
||||
import logging
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
try:
|
||||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||
|
||||
log.info("Using efficient FP8 or INT4 operators in FBGEMM.")
|
||||
logger.info("Using efficient FP8 or INT4 operators in FBGEMM.")
|
||||
except ImportError:
|
||||
log.error("No efficient FP8 or INT4 operators. Please install FBGEMM.")
|
||||
logger.error("No efficient FP8 or INT4 operators. Please install FBGEMM.")
|
||||
raise
|
||||
|
||||
import torch
|
||||
|
|
|
@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
|
|||
WEB_SEARCH_TOOL = "web_search"
|
||||
RAG_TOOL_GROUP = "builtin::rag"
|
||||
|
||||
logger = get_logger(name=__name__, category="agents")
|
||||
log = get_logger(name=__name__, category="agents")
|
||||
|
||||
|
||||
class ChatAgent(ShieldRunnerMixin):
|
||||
|
@ -612,7 +612,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
|
||||
if n_iter >= self.agent_config.max_infer_iters:
|
||||
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
|
||||
log.info(f"done with MAX iterations ({n_iter}), exiting.")
|
||||
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
|
||||
# Do not continue the tool call loop after this point
|
||||
message.stop_reason = StopReason.end_of_turn
|
||||
|
@ -620,7 +620,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
break
|
||||
|
||||
if stop_reason == StopReason.out_of_tokens:
|
||||
logger.info("out of token budget, exiting.")
|
||||
log.info("out of token budget, exiting.")
|
||||
yield message
|
||||
break
|
||||
|
||||
|
@ -634,7 +634,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
message.content = [message.content] + output_attachments
|
||||
yield message
|
||||
else:
|
||||
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
|
||||
log.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
|
||||
input_messages = input_messages + [message]
|
||||
else:
|
||||
input_messages = input_messages + [message]
|
||||
|
@ -889,7 +889,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
tool_name_str = tool_name
|
||||
|
||||
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
|
||||
log.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=tool_name_str,
|
||||
kwargs={
|
||||
|
@ -899,7 +899,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
**self.tool_name_to_args.get(tool_name_str, {}),
|
||||
},
|
||||
)
|
||||
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
|
||||
log.debug(f"tool call {tool_name_str} completed with result: {result}")
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
|
@ -42,6 +41,7 @@ from llama_stack.apis.safety import Safety
|
|||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||
from llama_stack.providers.utils.pagination import paginate_records
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
|
|||
from .openai_responses import OpenAIResponsesImpl
|
||||
from .persistence import AgentInfo
|
||||
|
||||
logger = logging.getLogger()
|
||||
log = get_logger(name=__name__, category="agents")
|
||||
|
||||
|
||||
class MetaReferenceAgentsImpl(Agents):
|
||||
|
@ -268,7 +268,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
# Get the agent info using the key
|
||||
agent_info_json = await self.persistence_store.get(agent_key)
|
||||
if not agent_info_json:
|
||||
logger.error(f"Could not find agent info for key {agent_key}")
|
||||
log.error(f"Could not find agent info for key {agent_key}")
|
||||
continue
|
||||
|
||||
try:
|
||||
|
@ -281,7 +281,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing agent info for {agent_id}: {e}")
|
||||
log.error(f"Error parsing agent info for {agent_id}: {e}")
|
||||
continue
|
||||
|
||||
# Convert Agent objects to dictionaries
|
||||
|
|
|
@ -75,7 +75,7 @@ from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefiniti
|
|||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
||||
logger = get_logger(name=__name__, category="openai_responses")
|
||||
log = get_logger(name=__name__, category="openai_responses")
|
||||
|
||||
OPENAI_RESPONSES_PREFIX = "openai_responses:"
|
||||
|
||||
|
@ -544,12 +544,12 @@ class OpenAIResponsesImpl:
|
|||
break
|
||||
|
||||
if function_tool_calls:
|
||||
logger.info("Exiting inference loop since there is a function (client-side) tool call")
|
||||
log.info("Exiting inference loop since there is a function (client-side) tool call")
|
||||
break
|
||||
|
||||
n_iter += 1
|
||||
if n_iter >= max_infer_iters:
|
||||
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}")
|
||||
log.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}")
|
||||
break
|
||||
|
||||
messages = next_turn_messages
|
||||
|
@ -698,7 +698,7 @@ class OpenAIResponsesImpl:
|
|||
)
|
||||
return search_response.data
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
|
||||
log.warning(f"Failed to search vector store {vector_store_id}: {e}")
|
||||
return []
|
||||
|
||||
# Run all searches in parallel using gather
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
@ -15,9 +14,10 @@ from llama_stack.core.access_control.access_control import AccessDeniedError, is
|
|||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="agents")
|
||||
|
||||
|
||||
class AgentSessionInfo(Session):
|
||||
|
|
|
@ -5,13 +5,13 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="agents")
|
||||
|
||||
|
||||
class SafetyException(Exception): # noqa: N818
|
||||
|
|
|
@ -73,11 +73,12 @@ from .config import MetaReferenceInferenceConfig
|
|||
from .generators import LlamaGenerator
|
||||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
||||
log = get_logger(__name__, category="inference")
|
||||
# there's a single model parallel process running serving the model. for now,
|
||||
# we don't support multiple concurrent requests to this process.
|
||||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
||||
logger = get_logger(__name__, category="inference")
|
||||
|
||||
|
||||
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
|
||||
return LlamaGenerator(config, model_id, llama_model)
|
||||
|
@ -144,7 +145,7 @@ class MetaReferenceInferenceImpl(
|
|||
return model
|
||||
|
||||
async def load_model(self, model_id, llama_model) -> None:
|
||||
log.info(f"Loading model `{model_id}`")
|
||||
logger.info(f"Loading model `{model_id}`")
|
||||
|
||||
builder_params = [self.config, model_id, llama_model]
|
||||
|
||||
|
@ -166,7 +167,7 @@ class MetaReferenceInferenceImpl(
|
|||
self.model_id = model_id
|
||||
self.llama_model = llama_model
|
||||
|
||||
log.info("Warming up...")
|
||||
logger.info("Warming up...")
|
||||
await self.completion(
|
||||
model_id=model_id,
|
||||
content="Hello, world!",
|
||||
|
@ -177,7 +178,7 @@ class MetaReferenceInferenceImpl(
|
|||
messages=[UserMessage(content="Hi how are you?")],
|
||||
sampling_params=SamplingParams(max_tokens=20),
|
||||
)
|
||||
log.info("Warmed up!")
|
||||
logger.info("Warmed up!")
|
||||
|
||||
def check_model(self, request) -> None:
|
||||
if self.model_id is None or self.llama_model is None:
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import tempfile
|
||||
|
@ -32,13 +31,14 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
from pydantic import BaseModel, Field
|
||||
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import GenerationResult
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class ProcessingMessageName(str, Enum):
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
|
@ -32,8 +31,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
|
||||
from .config import SentenceTransformersInferenceConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SentenceTransformersInferenceImpl(
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import gc
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
@ -28,6 +27,7 @@ from llama_stack.apis.post_training import (
|
|||
LoraFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
from ..config import HuggingFacePostTrainingConfig
|
||||
|
@ -44,7 +44,7 @@ from ..utils import (
|
|||
split_dataset,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class HFFinetuningSingleDevice:
|
||||
|
@ -69,14 +69,14 @@ class HFFinetuningSingleDevice:
|
|||
try:
|
||||
messages = json.loads(row["chat_completion_input"])
|
||||
if not isinstance(messages, list) or len(messages) != 1:
|
||||
logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
|
||||
log.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
|
||||
return None, None
|
||||
if "content" not in messages[0]:
|
||||
logger.warning(f"Message missing content: {messages[0]}")
|
||||
log.warning(f"Message missing content: {messages[0]}")
|
||||
return None, None
|
||||
return messages[0]["content"], row["expected_answer"]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
|
||||
log.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
|
||||
return None, None
|
||||
return None, None
|
||||
|
||||
|
@ -86,13 +86,13 @@ class HFFinetuningSingleDevice:
|
|||
try:
|
||||
dialog = json.loads(row["dialog"])
|
||||
if not isinstance(dialog, list) or len(dialog) < 2:
|
||||
logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
|
||||
log.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
|
||||
return None, None
|
||||
if dialog[0].get("role") != "user":
|
||||
logger.warning(f"First message must be from user: {dialog[0]}")
|
||||
log.warning(f"First message must be from user: {dialog[0]}")
|
||||
return None, None
|
||||
if not any(msg.get("role") == "assistant" for msg in dialog):
|
||||
logger.warning("Dialog must have at least one assistant message")
|
||||
log.warning("Dialog must have at least one assistant message")
|
||||
return None, None
|
||||
|
||||
# Convert to human/gpt format
|
||||
|
@ -100,14 +100,14 @@ class HFFinetuningSingleDevice:
|
|||
conversations = []
|
||||
for msg in dialog:
|
||||
if "role" not in msg or "content" not in msg:
|
||||
logger.warning(f"Message missing role or content: {msg}")
|
||||
log.warning(f"Message missing role or content: {msg}")
|
||||
continue
|
||||
conversations.append({"from": role_map[msg["role"]], "value": msg["content"]})
|
||||
|
||||
# Format as a single conversation
|
||||
return conversations[0]["value"], conversations[1]["value"]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse dialog: {row['dialog']}")
|
||||
log.warning(f"Failed to parse dialog: {row['dialog']}")
|
||||
return None, None
|
||||
return None, None
|
||||
|
||||
|
@ -198,7 +198,7 @@ class HFFinetuningSingleDevice:
|
|||
"""
|
||||
import asyncio
|
||||
|
||||
logger.info("Starting training process with async wrapper")
|
||||
log.info("Starting training process with async wrapper")
|
||||
asyncio.run(
|
||||
self._run_training(
|
||||
model=model,
|
||||
|
@ -228,14 +228,14 @@ class HFFinetuningSingleDevice:
|
|||
raise ValueError("DataConfig is required for training")
|
||||
|
||||
# Load dataset
|
||||
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
|
||||
log.info(f"Loading dataset: {config.data_config.dataset_id}")
|
||||
rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
|
||||
if not self.validate_dataset_format(rows):
|
||||
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
|
||||
logger.info(f"Loaded {len(rows)} rows from dataset")
|
||||
log.info(f"Loaded {len(rows)} rows from dataset")
|
||||
|
||||
# Initialize tokenizer
|
||||
logger.info(f"Initializing tokenizer for model: {model}")
|
||||
log.info(f"Initializing tokenizer for model: {model}")
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
|
||||
|
||||
|
@ -257,16 +257,16 @@ class HFFinetuningSingleDevice:
|
|||
# This ensures consistent sequence lengths across the training process
|
||||
tokenizer.model_max_length = provider_config.max_seq_length
|
||||
|
||||
logger.info("Tokenizer initialized successfully")
|
||||
log.info("Tokenizer initialized successfully")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
|
||||
|
||||
# Create and preprocess dataset
|
||||
logger.info("Creating and preprocessing dataset")
|
||||
log.info("Creating and preprocessing dataset")
|
||||
try:
|
||||
ds = self._create_dataset(rows, config, provider_config)
|
||||
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
|
||||
logger.info(f"Dataset created with {len(ds)} examples")
|
||||
log.info(f"Dataset created with {len(ds)} examples")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to create dataset: {str(e)}") from e
|
||||
|
||||
|
@ -293,11 +293,11 @@ class HFFinetuningSingleDevice:
|
|||
Returns:
|
||||
Configured SFTConfig object
|
||||
"""
|
||||
logger.info("Configuring training arguments")
|
||||
log.info("Configuring training arguments")
|
||||
lr = 2e-5
|
||||
if config.optimizer_config:
|
||||
lr = config.optimizer_config.lr
|
||||
logger.info(f"Using custom learning rate: {lr}")
|
||||
log.info(f"Using custom learning rate: {lr}")
|
||||
|
||||
# Validate data config
|
||||
if not config.data_config:
|
||||
|
@ -350,17 +350,17 @@ class HFFinetuningSingleDevice:
|
|||
peft_config: Optional LoRA configuration
|
||||
output_dir_path: Path to save the model
|
||||
"""
|
||||
logger.info("Saving final model")
|
||||
log.info("Saving final model")
|
||||
model_obj.config.use_cache = True
|
||||
|
||||
if peft_config:
|
||||
logger.info("Merging LoRA weights with base model")
|
||||
log.info("Merging LoRA weights with base model")
|
||||
model_obj = trainer.model.merge_and_unload()
|
||||
else:
|
||||
model_obj = trainer.model
|
||||
|
||||
save_path = output_dir_path / "merged_model"
|
||||
logger.info(f"Saving model to {save_path}")
|
||||
log.info(f"Saving model to {save_path}")
|
||||
model_obj.save_pretrained(save_path)
|
||||
|
||||
async def _run_training(
|
||||
|
@ -380,13 +380,13 @@ class HFFinetuningSingleDevice:
|
|||
setup_signal_handlers()
|
||||
|
||||
# Convert config dicts back to objects
|
||||
logger.info("Initializing configuration objects")
|
||||
log.info("Initializing configuration objects")
|
||||
provider_config_obj = HuggingFacePostTrainingConfig(**provider_config)
|
||||
config_obj = TrainingConfig(**config)
|
||||
|
||||
# Initialize and validate device
|
||||
device = setup_torch_device(provider_config_obj.device)
|
||||
logger.info(f"Using device '{device}'")
|
||||
log.info(f"Using device '{device}'")
|
||||
|
||||
# Load dataset and tokenizer
|
||||
train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
|
||||
|
@ -409,7 +409,7 @@ class HFFinetuningSingleDevice:
|
|||
model_obj = load_model(model, device, provider_config_obj)
|
||||
|
||||
# Initialize trainer
|
||||
logger.info("Initializing SFTTrainer")
|
||||
log.info("Initializing SFTTrainer")
|
||||
trainer = SFTTrainer(
|
||||
model=model_obj,
|
||||
train_dataset=train_dataset,
|
||||
|
@ -420,9 +420,9 @@ class HFFinetuningSingleDevice:
|
|||
|
||||
try:
|
||||
# Train
|
||||
logger.info("Starting training")
|
||||
log.info("Starting training")
|
||||
trainer.train()
|
||||
logger.info("Training completed successfully")
|
||||
log.info("Training completed successfully")
|
||||
|
||||
# Save final model if output directory is provided
|
||||
if output_dir_path:
|
||||
|
@ -430,12 +430,12 @@ class HFFinetuningSingleDevice:
|
|||
|
||||
finally:
|
||||
# Clean up resources
|
||||
logger.info("Cleaning up resources")
|
||||
log.info("Cleaning up resources")
|
||||
if hasattr(trainer, "model"):
|
||||
evacuate_model_from_device(trainer.model, device.type)
|
||||
del trainer
|
||||
gc.collect()
|
||||
logger.info("Cleanup completed")
|
||||
log.info("Cleanup completed")
|
||||
|
||||
async def train(
|
||||
self,
|
||||
|
@ -449,7 +449,7 @@ class HFFinetuningSingleDevice:
|
|||
"""Train a model using HuggingFace's SFTTrainer"""
|
||||
# Initialize and validate device
|
||||
device = setup_torch_device(provider_config.device)
|
||||
logger.info(f"Using device '{device}'")
|
||||
log.info(f"Using device '{device}'")
|
||||
|
||||
output_dir_path = None
|
||||
if output_dir:
|
||||
|
@ -479,7 +479,7 @@ class HFFinetuningSingleDevice:
|
|||
raise ValueError("DataConfig is required for training")
|
||||
|
||||
# Train in a separate process
|
||||
logger.info("Starting training in separate process")
|
||||
log.info("Starting training in separate process")
|
||||
try:
|
||||
# Setup multiprocessing for device
|
||||
if device.type in ["cuda", "mps"]:
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
@ -24,6 +23,7 @@ from llama_stack.apis.post_training import (
|
|||
DPOAlignmentConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
from ..config import HuggingFacePostTrainingConfig
|
||||
|
@ -40,7 +40,7 @@ from ..utils import (
|
|||
split_dataset,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
class HFDPOAlignmentSingleDevice:
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
|
@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM
|
|||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.post_training import Checkpoint, TrainingConfig
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .config import HuggingFacePostTrainingConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def setup_environment():
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler
|
|||
from torchtune import modules, training
|
||||
from torchtune import utils as torchtune_utils
|
||||
from torchtune.data import padded_collate_sft
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
from torchtune.modules.loss import CEWithChunkedOutputLoss
|
||||
from torchtune.modules.peft import (
|
||||
get_adapter_params,
|
||||
|
@ -45,6 +45,7 @@ from llama_stack.apis.post_training import (
|
|||
)
|
||||
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
from llama_stack.providers.inline.post_training.torchtune.common import utils
|
||||
|
@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
|
|||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class LoraFinetuningSingleDevice:
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
|
@ -15,13 +14,14 @@ from llama_stack.apis.safety import (
|
|||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="safety")
|
||||
|
||||
ALLOWED_CODE_SCANNER_MODEL_IDS = [
|
||||
"CodeScanner",
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
@ -19,6 +18,7 @@ from llama_stack.apis.safety import (
|
|||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
|
@ -26,10 +26,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import PromptGuardConfig, PromptGuardType
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
|
||||
|
||||
log = get_logger(name=__name__, category="safety")
|
||||
|
||||
|
||||
class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: PromptGuardConfig, _deps) -> None:
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
import collections
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
|
@ -20,7 +19,9 @@ import nltk
|
|||
from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai
|
||||
from pythainlp.tokenize import word_tokenize as word_tokenize_thai
|
||||
|
||||
logger = logging.getLogger()
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
WORD_LIST = [
|
||||
"western",
|
||||
|
@ -1726,7 +1727,7 @@ def get_langid(text: str, lid_path: str | None = None) -> str:
|
|||
try:
|
||||
line_langs.append(langdetect.detect(line))
|
||||
except langdetect.LangDetectException as e:
|
||||
logger.info("Unable to detect language for text %s due to %s", line, e) # refex: disable=pytotw.037
|
||||
log.info("Unable to detect language for text %s due to %s", line, e) # refex: disable=pytotw.037
|
||||
|
||||
if len(line_langs) == 0:
|
||||
return "en"
|
||||
|
@ -1885,7 +1886,7 @@ class ResponseLanguageChecker(Instruction):
|
|||
return langdetect.detect(value) == self._language
|
||||
except langdetect.LangDetectException as e:
|
||||
# Count as instruction is followed.
|
||||
logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
|
||||
log.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
|
||||
return True
|
||||
|
||||
|
||||
|
@ -3110,7 +3111,7 @@ class CapitalLettersEnglishChecker(Instruction):
|
|||
return value.isupper() and langdetect.detect(value) == "en"
|
||||
except langdetect.LangDetectException as e:
|
||||
# Count as instruction is followed.
|
||||
logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
|
||||
log.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
|
||||
return True
|
||||
|
||||
|
||||
|
@ -3139,7 +3140,7 @@ class LowercaseLettersEnglishChecker(Instruction):
|
|||
return value.islower() and langdetect.detect(value) == "en"
|
||||
except langdetect.LangDetectException as e:
|
||||
# Count as instruction is followed.
|
||||
logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
|
||||
log.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
|
||||
return True
|
||||
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import secrets
|
||||
import string
|
||||
from typing import Any
|
||||
|
@ -32,6 +31,7 @@ from llama_stack.apis.tools import (
|
|||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
|
@ -42,7 +42,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
from .config import RagToolRuntimeConfig
|
||||
from .context_retriever import generate_rag_query
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="tools")
|
||||
|
||||
|
||||
def make_random_string(length: int = 8):
|
||||
|
|
|
@ -8,7 +8,6 @@ import asyncio
|
|||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import faiss
|
||||
|
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
|
|||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
|
@ -39,7 +39,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import FaissVectorIOConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
|
||||
|
@ -83,7 +83,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False))
|
||||
self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()]
|
||||
except Exception as e:
|
||||
logger.debug(e, exc_info=True)
|
||||
log.debug(e, exc_info=True)
|
||||
raise ValueError(
|
||||
"Error deserializing Faiss index from storage. If you recently upgraded your Llama Stack, Faiss, "
|
||||
"or NumPy versions, you may need to delete the index and re-create it again or downgrade versions.\n"
|
||||
|
@ -262,7 +262,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
assert self.kvstore is not None
|
||||
|
||||
if vector_db_id not in self.cache:
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
log.warning(f"Vector DB {vector_db_id} not found")
|
||||
return
|
||||
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import sqlite3
|
||||
import struct
|
||||
|
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
|
|||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
|
@ -35,7 +35,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
VectorDBWithIndex,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
# Specifying search mode is dependent on the VectorIO provider.
|
||||
VECTOR_SEARCH = "vector"
|
||||
|
@ -257,7 +257,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
|
||||
except sqlite3.Error as e:
|
||||
connection.rollback()
|
||||
logger.error(f"Error inserting into {self.vector_table}: {e}")
|
||||
log.error(f"Error inserting into {self.vector_table}: {e}")
|
||||
raise
|
||||
|
||||
finally:
|
||||
|
@ -306,7 +306,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
try:
|
||||
chunk = Chunk.model_validate_json(chunk_json)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
||||
log.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
||||
continue
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
@ -352,7 +352,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
try:
|
||||
chunk = Chunk.model_validate_json(chunk_json)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
||||
log.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
||||
continue
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
@ -447,7 +447,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
connection.commit()
|
||||
except Exception as e:
|
||||
connection.rollback()
|
||||
logger.error(f"Error deleting chunk {chunk_id}: {e}")
|
||||
log.error(f"Error deleting chunk {chunk_id}: {e}")
|
||||
raise
|
||||
finally:
|
||||
cur.close()
|
||||
|
@ -530,7 +530,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
log.warning(f"Vector DB {vector_db_id} not found")
|
||||
return
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
|
|
@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import FireworksImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
|
@ -256,7 +256,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
"stream": bool(request.stream),
|
||||
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
|
||||
}
|
||||
logger.debug(f"params to fireworks: {params}")
|
||||
log.debug(f"params to fireworks: {params}")
|
||||
|
||||
return params
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import logging
|
||||
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
@ -11,8 +10,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
|||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
"""
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
|
@ -33,6 +32,7 @@ from llama_stack.apis.inference import (
|
|||
ToolChoice,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
|
@ -54,7 +54,7 @@ from .openai_utils import (
|
|||
)
|
||||
from .utils import _is_nvidia_hosted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
||||
|
@ -75,7 +75,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
|||
# TODO(mf): filter by available models
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
|
||||
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
|
||||
log.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
|
||||
|
||||
if _is_nvidia_hosted(config):
|
||||
if not config.api_key:
|
||||
|
|
|
@ -4,13 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from . import NVIDIAConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
|
||||
|
@ -44,7 +45,7 @@ async def check_health(config: NVIDIAConfig) -> None:
|
|||
RuntimeError: If the server is not running or ready
|
||||
"""
|
||||
if not _is_nvidia_hosted(config):
|
||||
logger.info("Checking NVIDIA NIM health...")
|
||||
log.info("Checking NVIDIA NIM health...")
|
||||
try:
|
||||
is_live, is_ready = await _get_health(config.url)
|
||||
if not is_live:
|
||||
|
|
|
@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(
|
||||
|
@ -117,10 +117,10 @@ class OllamaInferenceAdapter(
|
|||
return self._openai_client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
|
||||
log.info(f"checking connectivity to Ollama at `{self.config.url}`...")
|
||||
health_response = await self.health()
|
||||
if health_response["status"] == HealthStatus.ERROR:
|
||||
logger.warning(
|
||||
log.warning(
|
||||
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
|
||||
)
|
||||
|
||||
|
@ -339,7 +339,7 @@ class OllamaInferenceAdapter(
|
|||
"options": sampling_options,
|
||||
"stream": request.stream,
|
||||
}
|
||||
logger.debug(f"params to ollama: {params}")
|
||||
log.debug(f"params to ollama: {params}")
|
||||
|
||||
return params
|
||||
|
||||
|
@ -437,7 +437,7 @@ class OllamaInferenceAdapter(
|
|||
if provider_resource_id not in available_models:
|
||||
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
|
||||
if provider_resource_id in available_models_latest:
|
||||
logger.warning(
|
||||
log.warning(
|
||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
||||
)
|
||||
return model
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
@ -12,8 +11,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
|||
from .config import OpenAIConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
#
|
||||
# This OpenAI adapter implements Inference methods using two mixins -
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
|
@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
|
@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def build_hf_repo_model_entries():
|
||||
|
@ -307,7 +307,7 @@ class TGIAdapter(_HfAdapter):
|
|||
async def initialize(self, config: TGIImplConfig) -> None:
|
||||
if not config.url:
|
||||
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
|
||||
log.info(f"Initializing TGI client with url={config.url}")
|
||||
logger.info(f"Initializing TGI client with url={config.url}")
|
||||
self.client = AsyncInferenceClient(
|
||||
model=config.url,
|
||||
)
|
||||
|
|
|
@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import TogetherImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
|
@ -232,7 +232,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
"stream": request.stream,
|
||||
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
|
||||
}
|
||||
logger.debug(f"params to together: {params}")
|
||||
log.debug(f"params to together: {params}")
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
|
@ -15,8 +14,6 @@ from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefa
|
|||
|
||||
from .config import NvidiaPostTrainingConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:
|
||||
keys = set(config_dict.__annotations__.keys()) if isinstance(config_dict, BaseModel) else config_dict.keys()
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
|
@ -16,12 +15,13 @@ from llama_stack.apis.safety import (
|
|||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||
|
||||
from .config import BedrockSafetyConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="safety")
|
||||
|
||||
|
||||
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||
|
@ -76,13 +76,13 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
"""
|
||||
|
||||
shield_params = shield.params
|
||||
logger.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||
log.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||
|
||||
# - convert the messages into format Bedrock expects
|
||||
content_messages = []
|
||||
for message in messages:
|
||||
content_messages.append({"text": {"text": message.content}})
|
||||
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
||||
log.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
||||
|
||||
response = self.bedrock_runtime_client.apply_guardrail(
|
||||
guardrailIdentifier=shield.provider_resource_id,
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
@ -17,8 +16,6 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
|
|||
|
||||
from .config import NVIDIASafetyConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: NVIDIASafetyConfig) -> None:
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
|
@ -20,12 +19,13 @@ from llama_stack.apis.safety import (
|
|||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
|
||||
|
||||
from .config import SambaNovaSafetyConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="safety")
|
||||
|
||||
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||
|
||||
|
@ -66,7 +66,7 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
|
|||
"guard" not in shield.provider_resource_id.lower()
|
||||
or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models
|
||||
):
|
||||
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
|
||||
log.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
|
||||
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
pass
|
||||
|
@ -79,9 +79,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
|
|||
raise ValueError(f"Shield {shield_id} not found")
|
||||
|
||||
shield_params = shield.params
|
||||
logger.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||
log.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||
content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
|
||||
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
||||
log.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
||||
|
||||
response = litellm.completion(
|
||||
model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
@ -20,6 +19,7 @@ from llama_stack.apis.vector_io import (
|
|||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
@ -32,8 +32,6 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
|
||||
|
||||
VERSION = "v3"
|
||||
|
@ -43,6 +41,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:chroma:{VERSION}::"
|
|||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:chroma:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:chroma:{VERSION}::"
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
# this is a helper to allow us to use async and non-async chroma clients interchangeably
|
||||
async def maybe_await(result):
|
||||
|
@ -92,7 +92,7 @@ class ChromaIndex(EmbeddingIndex):
|
|||
doc = json.loads(doc)
|
||||
chunk = Chunk(**doc)
|
||||
except Exception:
|
||||
log.exception(f"Failed to parse document: {doc}")
|
||||
logger.exception(f"Failed to parse document: {doc}")
|
||||
continue
|
||||
|
||||
score = 1.0 / float(dist) if dist != 0 else float("inf")
|
||||
|
@ -137,7 +137,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
inference_api: Api.inference,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
|
||||
logger.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.client = None
|
||||
|
@ -150,7 +150,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.vector_db_store = self.kvstore
|
||||
|
||||
if isinstance(self.config, RemoteChromaVectorIOConfig):
|
||||
log.info(f"Connecting to Chroma server at: {self.config.url}")
|
||||
logger.info(f"Connecting to Chroma server at: {self.config.url}")
|
||||
url = self.config.url.rstrip("/")
|
||||
parsed = urlparse(url)
|
||||
|
||||
|
@ -159,7 +159,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
self.client = await chromadb.AsyncHttpClient(host=parsed.hostname, port=parsed.port)
|
||||
else:
|
||||
log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
|
||||
logger.info(f"Connecting to Chroma local db at: {self.config.db_path}")
|
||||
self.client = chromadb.PersistentClient(path=self.config.db_path)
|
||||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
|
||||
|
@ -182,7 +182,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
log.warning(f"Vector DB {vector_db_id} not found")
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
return
|
||||
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
|
@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import (
|
|||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
|
|||
|
||||
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
|
||||
|
@ -68,7 +68,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
)
|
||||
|
||||
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
|
||||
log.info(f"Creating new collection {self.collection_name} with nullable sparse field")
|
||||
# Create schema for vector search
|
||||
schema = self.client.create_schema()
|
||||
schema.add_field(
|
||||
|
@ -147,7 +147,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
data=data,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
|
||||
log.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
|
||||
raise e
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
|
@ -203,7 +203,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error performing BM25 search: {e}")
|
||||
log.error(f"Error performing BM25 search: {e}")
|
||||
# Fallback to simple text search
|
||||
return await self._fallback_keyword_search(query_string, k, score_threshold)
|
||||
|
||||
|
@ -247,7 +247,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
|
||||
log.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
|
@ -288,10 +288,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
logger.info(f"Connecting to Milvus server at {self.config.uri}")
|
||||
log.info(f"Connecting to Milvus server at {self.config.uri}")
|
||||
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
|
||||
else:
|
||||
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
|
||||
log.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
|
||||
uri = os.path.expanduser(self.config.db_path)
|
||||
self.client = MilvusClient(uri=uri)
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import psycopg2
|
||||
|
@ -22,6 +21,7 @@ from llama_stack.apis.vector_io import (
|
|||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
|
@ -33,8 +33,6 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import PGVectorVectorIOConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
|
||||
|
@ -42,6 +40,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"
|
|||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:pgvector:{VERSION}::"
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def check_extension_version(cur):
|
||||
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
|
||||
|
@ -187,7 +187,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
self.metadatadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
|
||||
logger.info(f"Initializing PGVector memory adapter with config: {self.config}")
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
||||
|
@ -203,7 +203,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
version = check_extension_version(cur)
|
||||
if version:
|
||||
log.info(f"Vector extension version: {version}")
|
||||
logger.info(f"Vector extension version: {version}")
|
||||
else:
|
||||
raise RuntimeError("Vector extension is not installed.")
|
||||
|
||||
|
@ -216,13 +216,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception("Could not connect to PGVector database server")
|
||||
logger.exception("Could not connect to PGVector database server")
|
||||
raise RuntimeError("Could not connect to PGVector database server") from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.conn is not None:
|
||||
self.conn.close()
|
||||
log.info("Connection to PGVector database server closed")
|
||||
logger.info("Connection to PGVector database server closed")
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
# Persist vector DB metadata in the KV store
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
|
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreChunkingStrategy,
|
||||
VectorStoreFileObject,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
|
@ -35,13 +35,14 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
CHUNK_ID_KEY = "_chunk_id"
|
||||
|
||||
# KV store prefixes for vector databases
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:qdrant:{VERSION}::"
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def convert_id(_id: str) -> str:
|
||||
"""
|
||||
|
@ -96,7 +97,7 @@ class QdrantIndex(EmbeddingIndex):
|
|||
points_selector=models.PointIdsList(points=[convert_id(chunk_id)]),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}")
|
||||
logger.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}")
|
||||
raise
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
|
@ -118,7 +119,7 @@ class QdrantIndex(EmbeddingIndex):
|
|||
try:
|
||||
chunk = Chunk(**point.payload["chunk_content"])
|
||||
except Exception:
|
||||
log.exception("Failed to parse chunk")
|
||||
logger.exception("Failed to parse chunk")
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import weaviate
|
||||
|
@ -19,6 +18,7 @@ from llama_stack.apis.files.files import Files
|
|||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
|
@ -33,8 +33,6 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
|
|||
|
||||
from .config import WeaviateVectorIOConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
|
||||
|
@ -42,6 +40,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::"
|
|||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:weaviate:{VERSION}::"
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
class WeaviateIndex(EmbeddingIndex):
|
||||
def __init__(
|
||||
|
@ -102,7 +102,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
chunk_dict = json.loads(chunk_json)
|
||||
chunk = Chunk(**chunk_dict)
|
||||
except Exception:
|
||||
log.exception(f"Failed to parse document: {chunk_json}")
|
||||
logger.exception(f"Failed to parse document: {chunk_json}")
|
||||
continue
|
||||
|
||||
score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf")
|
||||
|
@ -171,7 +171,7 @@ class WeaviateVectorIOAdapter(
|
|||
|
||||
def _get_client(self) -> weaviate.Client:
|
||||
if "localhost" in self.config.weaviate_cluster_url:
|
||||
log.info("using Weaviate locally in container")
|
||||
logger.info("using Weaviate locally in container")
|
||||
host, port = self.config.weaviate_cluster_url.split(":")
|
||||
key = "local_test"
|
||||
client = weaviate.connect_to_local(
|
||||
|
@ -179,7 +179,7 @@ class WeaviateVectorIOAdapter(
|
|||
port=port,
|
||||
)
|
||||
else:
|
||||
log.info("Using Weaviate remote cluster with URL")
|
||||
logger.info("Using Weaviate remote cluster with URL")
|
||||
key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
|
||||
if key in self.client_cache:
|
||||
return self.client_cache[key]
|
||||
|
@ -197,7 +197,7 @@ class WeaviateVectorIOAdapter(
|
|||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
else:
|
||||
self.kvstore = None
|
||||
log.info("No kvstore configured, registry will not persist across restarts")
|
||||
logger.info("No kvstore configured, registry will not persist across restarts")
|
||||
|
||||
# Load existing vector DB definitions
|
||||
if self.kvstore is not None:
|
||||
|
@ -254,7 +254,7 @@ class WeaviateVectorIOAdapter(
|
|||
client = self._get_client()
|
||||
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
|
||||
if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False:
|
||||
log.warning(f"Vector DB {sanitized_collection_name} not found")
|
||||
logger.warning(f"Vector DB {sanitized_collection_name} not found")
|
||||
return
|
||||
client.collections.delete(sanitized_collection_name)
|
||||
await self.cache[sanitized_collection_name].index.delete()
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import struct
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
@ -27,7 +26,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
|
|||
EMBEDDING_MODELS = {}
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class SentenceTransformerEmbeddingMixin:
|
||||
|
|
|
@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class LiteLLMOpenAIMixin(
|
||||
|
@ -157,7 +157,7 @@ class LiteLLMOpenAIMixin(
|
|||
params = await self._get_params(request)
|
||||
params["model"] = self.get_litellm_model_name(params["model"])
|
||||
|
||||
logger.debug(f"params to litellm (openai compat): {params}")
|
||||
log.debug(f"params to litellm (openai compat): {params}")
|
||||
# see https://docs.litellm.ai/docs/completion/stream#async-completion
|
||||
response = await litellm.acompletion(**params)
|
||||
if stream:
|
||||
|
@ -460,7 +460,7 @@ class LiteLLMOpenAIMixin(
|
|||
:return: True if the model is available dynamically, False otherwise.
|
||||
"""
|
||||
if self.litellm_provider_name not in litellm.models_by_provider:
|
||||
logger.error(f"Provider {self.litellm_provider_name} is not registered in litellm.")
|
||||
log.error(f"Provider {self.litellm_provider_name} is not registered in litellm.")
|
||||
return False
|
||||
|
||||
return model in litellm.models_by_provider[self.litellm_provider_name]
|
||||
|
|
|
@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference import (
|
|||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class RemoteInferenceProviderConfig(BaseModel):
|
||||
|
@ -135,7 +135,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
"""
|
||||
logger.info(
|
||||
log.info(
|
||||
f"check_model_availability is not implemented for {self.__class__.__name__}. Returning False by default."
|
||||
)
|
||||
return False
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import struct
|
||||
import time
|
||||
import uuid
|
||||
|
@ -116,6 +115,7 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.apis.inference import (
|
||||
OpenAIChoice as OpenAIChatCompletionChoice,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
StopReason,
|
||||
|
@ -128,7 +128,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
decode_assistant_message,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class OpenAICompatCompletionChoiceDelta(BaseModel):
|
||||
|
@ -316,7 +316,7 @@ def process_chat_completion_response(
|
|||
if t.tool_name in request_tools:
|
||||
new_tool_calls.append(t)
|
||||
else:
|
||||
logger.warning(f"Tool {t.tool_name} not found in request tools")
|
||||
log.warning(f"Tool {t.tool_name} not found in request tools")
|
||||
|
||||
if len(new_tool_calls) < len(raw_message.tool_calls):
|
||||
raw_message.tool_calls = new_tool_calls
|
||||
|
@ -477,7 +477,7 @@ async def process_chat_completion_stream_response(
|
|||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Tool {tool_call.tool_name} not found in request tools")
|
||||
log.warning(f"Tool {tool_call.tool_name} not found in request tools")
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
|
@ -1198,7 +1198,7 @@ async def convert_openai_chat_completion_stream(
|
|||
)
|
||||
|
||||
for idx, buffer in tool_call_idx_to_buffer.items():
|
||||
logger.debug(f"toolcall_buffer[{idx}]: {buffer}")
|
||||
log.debug(f"toolcall_buffer[{idx}]: {buffer}")
|
||||
if buffer["name"]:
|
||||
delta = ")"
|
||||
buffer["content"] += delta
|
||||
|
|
|
@ -25,7 +25,7 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class OpenAIMixin(ABC):
|
||||
|
@ -125,9 +125,9 @@ class OpenAIMixin(ABC):
|
|||
Direct OpenAI completion API call.
|
||||
"""
|
||||
if guided_choice is not None:
|
||||
logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
|
||||
log.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
|
||||
if prompt_logprobs is not None:
|
||||
logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
|
||||
log.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
|
||||
|
||||
# TODO: fix openai_completion to return type compatible with OpenAI's API response
|
||||
return await self.client.completions.create( # type: ignore[no-any-return]
|
||||
|
@ -267,6 +267,6 @@ class OpenAIMixin(ABC):
|
|||
pass
|
||||
except Exception as e:
|
||||
# All other errors (auth, rate limit, network, etc.)
|
||||
logger.warning(f"Failed to check model availability for {model}: {e}")
|
||||
log.warning(f"Failed to check model availability for {model}: {e}")
|
||||
|
||||
return False
|
||||
|
|
|
@ -4,16 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from pymongo import AsyncMongoClient
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
from ..config import MongoDBKVStoreConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class MongoDBKVStoreImpl(KVStore):
|
||||
|
|
|
@ -4,16 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import DictCursor
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from ..api import KVStore
|
||||
from ..config import PostgresKVStoreConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class PostgresKVStoreImpl(KVStore):
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import time
|
||||
import uuid
|
||||
|
@ -37,10 +36,11 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreSearchResponse,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
# Constants for OpenAI vector stores
|
||||
CHUNK_MULTIPLIER = 5
|
||||
|
@ -378,7 +378,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
try:
|
||||
await self.unregister_vector_db(vector_store_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
|
||||
log.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
|
||||
|
||||
return VectorStoreDeleteResponse(
|
||||
id=vector_store_id,
|
||||
|
@ -460,7 +460,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching vector store {vector_store_id}: {e}")
|
||||
log.error(f"Error searching vector store {vector_store_id}: {e}")
|
||||
# Return empty results on error
|
||||
return VectorStoreSearchResponsePage(
|
||||
search_query=search_query,
|
||||
|
@ -614,7 +614,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
)
|
||||
vector_store_file_object.status = "completed"
|
||||
except Exception as e:
|
||||
logger.error(f"Error attaching file to vector store: {e}")
|
||||
log.error(f"Error attaching file to vector store: {e}")
|
||||
vector_store_file_object.status = "failed"
|
||||
vector_store_file_object.last_error = VectorStoreFileLastError(
|
||||
code="server_error",
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
@ -25,6 +24,7 @@ from llama_stack.apis.common.content_types import (
|
|||
from llama_stack.apis.tools import RAGDocument
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
|
@ -32,12 +32,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Constants for reranker types
|
||||
RERANKER_TYPE_RRF = "rrf"
|
||||
RERANKER_TYPE_WEIGHTED = "weighted"
|
||||
|
||||
log = get_logger(name=__name__, category="memory")
|
||||
|
||||
|
||||
def parse_pdf(data: bytes) -> str:
|
||||
# For PDF and DOC/DOCX files, we can't reliably convert to string
|
||||
|
|
|
@ -17,7 +17,7 @@ from pydantic import BaseModel
|
|||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="scheduler")
|
||||
log = get_logger(name=__name__, category="scheduler")
|
||||
|
||||
|
||||
# TODO: revisit the list of possible statuses when defining a more coherent
|
||||
|
@ -186,7 +186,7 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
|
|||
except Exception as e:
|
||||
on_log_message_cb(str(e))
|
||||
job.status = JobStatus.failed
|
||||
logger.exception(f"Job {job.id} failed.")
|
||||
log.exception(f"Job {job.id} failed.")
|
||||
|
||||
asyncio.run_coroutine_threadsafe(do(), self._loop)
|
||||
|
||||
|
@ -222,7 +222,7 @@ class Scheduler:
|
|||
msg = (datetime.now(UTC), message)
|
||||
# At least for the time being, until there's a better way to expose
|
||||
# logs to users, log messages on console
|
||||
logger.info(f"Job {job.id}: {message}")
|
||||
log.info(f"Job {job.id}: {message}")
|
||||
job.append_log(msg)
|
||||
self._backend.on_log_message_cb(job, msg)
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ from llama_stack.log import get_logger
|
|||
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
|
||||
from .sqlstore import SqlStoreType
|
||||
|
||||
logger = get_logger(name=__name__, category="authorized_sqlstore")
|
||||
log = get_logger(name=__name__, category="authorized_sqlstore")
|
||||
|
||||
# Hardcoded copy of the default policy that our SQL filtering implements
|
||||
# WARNING: If default_policy() changes, this constant must be updated accordingly
|
||||
|
@ -81,7 +81,7 @@ class AuthorizedSqlStore:
|
|||
actual_default = default_policy()
|
||||
|
||||
if SQL_OPTIMIZED_POLICY != actual_default:
|
||||
logger.warning(
|
||||
log.warning(
|
||||
f"SQL_OPTIMIZED_POLICY does not match default_policy(). "
|
||||
f"SQL filtering will use conservative mode. "
|
||||
f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}",
|
||||
|
|
|
@ -29,7 +29,7 @@ from llama_stack.log import get_logger
|
|||
from .api import ColumnDefinition, ColumnType, SqlStore
|
||||
from .sqlstore import SqlAlchemySqlStoreConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="sqlstore")
|
||||
log = get_logger(name=__name__, category="sqlstore")
|
||||
|
||||
TYPE_MAPPING: dict[ColumnType, Any] = {
|
||||
ColumnType.INTEGER: Integer,
|
||||
|
@ -280,5 +280,5 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
except Exception as e:
|
||||
# If any error occurs during migration, log it but don't fail
|
||||
# The table creation will handle adding the column
|
||||
logger.error(f"Error adding column {column_name} to table {table}: {e}")
|
||||
log.error(f"Error adding column {column_name} to table {table}: {e}")
|
||||
pass
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import logging
|
||||
import logging # allow-direct-logging
|
||||
import queue
|
||||
import random
|
||||
import threading
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
|
@ -19,10 +18,9 @@ from llama_stack.apis.post_training import (
|
|||
LoraFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
skip_because_resource_intensive = pytest.mark.skip(
|
||||
|
@ -71,14 +69,14 @@ class TestPostTraining:
|
|||
)
|
||||
@pytest.mark.timeout(360) # 6 minutes timeout
|
||||
def test_supervised_fine_tune(self, llama_stack_client, purpose, source):
|
||||
logger.info("Starting supervised fine-tuning test")
|
||||
log.info("Starting supervised fine-tuning test")
|
||||
|
||||
# register dataset to train
|
||||
dataset = llama_stack_client.datasets.register(
|
||||
purpose=purpose,
|
||||
source=source,
|
||||
)
|
||||
logger.info(f"Registered dataset with ID: {dataset.identifier}")
|
||||
log.info(f"Registered dataset with ID: {dataset.identifier}")
|
||||
|
||||
algorithm_config = LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
|
@ -105,7 +103,7 @@ class TestPostTraining:
|
|||
)
|
||||
|
||||
job_uuid = f"test-job{uuid.uuid4()}"
|
||||
logger.info(f"Starting training job with UUID: {job_uuid}")
|
||||
log.info(f"Starting training job with UUID: {job_uuid}")
|
||||
|
||||
# train with HF trl SFTTrainer as the default
|
||||
_ = llama_stack_client.post_training.supervised_fine_tune(
|
||||
|
@ -121,21 +119,21 @@ class TestPostTraining:
|
|||
while True:
|
||||
status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
|
||||
if not status:
|
||||
logger.error("Job not found")
|
||||
log.error("Job not found")
|
||||
break
|
||||
|
||||
logger.info(f"Current status: {status}")
|
||||
log.info(f"Current status: {status}")
|
||||
assert status.status in ["scheduled", "in_progress", "completed"]
|
||||
if status.status == "completed":
|
||||
break
|
||||
|
||||
logger.info("Waiting for job to complete...")
|
||||
log.info("Waiting for job to complete...")
|
||||
time.sleep(10) # Increased sleep time to reduce polling frequency
|
||||
|
||||
artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
|
||||
logger.info(f"Job artifacts: {artifacts}")
|
||||
log.info(f"Job artifacts: {artifacts}")
|
||||
|
||||
logger.info(f"Registered dataset with ID: {dataset.identifier}")
|
||||
log.info(f"Registered dataset with ID: {dataset.identifier}")
|
||||
|
||||
# TODO: Fix these tests to properly represent the Jobs API in training
|
||||
#
|
||||
|
@ -181,17 +179,21 @@ class TestPostTraining:
|
|||
)
|
||||
@pytest.mark.timeout(360)
|
||||
def test_preference_optimize(self, llama_stack_client, purpose, source):
|
||||
logger.info("Starting DPO preference optimization test")
|
||||
log.info("Starting DPO preference optimization test")
|
||||
|
||||
# register preference dataset to train
|
||||
dataset = llama_stack_client.datasets.register(
|
||||
purpose=purpose,
|
||||
source=source,
|
||||
)
|
||||
logger.info(f"Registered preference dataset with ID: {dataset.identifier}")
|
||||
log.info(f"Registered preference dataset with ID: {dataset.identifier}")
|
||||
|
||||
# DPO algorithm configuration
|
||||
algorithm_config = DPOAlignmentConfig(
|
||||
reward_scale=1.0,
|
||||
reward_clip=10.0,
|
||||
epsilon=1e-8,
|
||||
gamma=0.99,
|
||||
beta=0.1,
|
||||
loss_type=DPOLossType.sigmoid, # Default loss type
|
||||
)
|
||||
|
@ -211,7 +213,7 @@ class TestPostTraining:
|
|||
)
|
||||
|
||||
job_uuid = f"test-dpo-job-{uuid.uuid4()}"
|
||||
logger.info(f"Starting DPO training job with UUID: {job_uuid}")
|
||||
log.info(f"Starting DPO training job with UUID: {job_uuid}")
|
||||
|
||||
# train with HuggingFace DPO implementation
|
||||
_ = llama_stack_client.post_training.preference_optimize(
|
||||
|
@ -226,15 +228,15 @@ class TestPostTraining:
|
|||
while True:
|
||||
status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
|
||||
if not status:
|
||||
logger.error("DPO job not found")
|
||||
log.error("DPO job not found")
|
||||
break
|
||||
|
||||
logger.info(f"Current DPO status: {status}")
|
||||
log.info(f"Current DPO status: {status}")
|
||||
if status.status == "completed":
|
||||
break
|
||||
|
||||
logger.info("Waiting for DPO job to complete...")
|
||||
log.info("Waiting for DPO job to complete...")
|
||||
time.sleep(10) # Increased sleep time to reduce polling frequency
|
||||
|
||||
artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
|
||||
logger.info(f"DPO job artifacts: {artifacts}")
|
||||
log.info(f"DPO job artifacts: {artifacts}")
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from io import BytesIO
|
||||
|
||||
|
@ -13,8 +12,10 @@ from llama_stack_client import BadRequestError, LlamaStackClient
|
|||
from openai import BadRequestError as OpenAIBadRequestError
|
||||
|
||||
from llama_stack.apis.vector_io import Chunk
|
||||
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="vector-io")
|
||||
|
||||
|
||||
def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
|
||||
|
@ -99,7 +100,7 @@ def compat_client_with_empty_stores(compat_client):
|
|||
compat_client.vector_stores.delete(vector_store_id=store.id)
|
||||
except Exception:
|
||||
# If the API is not available or fails, just continue
|
||||
logger.warning("Failed to clear vector stores")
|
||||
log.warning("Failed to clear vector stores")
|
||||
pass
|
||||
|
||||
def clear_files():
|
||||
|
@ -109,7 +110,7 @@ def compat_client_with_empty_stores(compat_client):
|
|||
compat_client.files.delete(file_id=file.id)
|
||||
except Exception:
|
||||
# If the API is not available or fails, just continue
|
||||
logger.warning("Failed to clear files")
|
||||
log.warning("Failed to clear files")
|
||||
pass
|
||||
|
||||
clear_vector_stores()
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import logging # allow-direct-logging
|
||||
import threading
|
||||
import time
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue