Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-15 14:08:00 +00:00
feat: add auto-generated CI documentation pre-commit hook (#2890)
Our CI is entirely undocumented; this commit adds a README.md file with a table of the current CI workflows and what each does.

Signed-off-by: Nathan Weinberg <nweinber@redhat.com>
Parent: 7f834339ba
Commit: b381ed6d64
93 changed files with 495 additions and 477 deletions
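
Most of the diff below migrates modules from the stdlib logging module to the project's get_logger helper, which is the convention the new check-log-usage pre-commit hook enforces. As a point of reference (not part of the commit), a minimal module following that convention might look like this, using the category values that appear in this diff ("core", "cli", "server"):

    from llama_stack.log import get_logger

    # Module-level logger obtained through the project helper rather than logging.getLogger().
    log = get_logger(name=__name__, category="core")

    def do_work() -> None:
        # Emitted through the project's logging configuration.
        log.debug("do_work: starting")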
@@ -145,6 +145,24 @@ repos:
     pass_filenames: false
     require_serial: true
     files: ^.github/workflows/.*$
+  - id: check-log-usage
+    name: Check for proper log usage (use llama_stack.log instead)
+    entry: bash
+    language: system
+    types: [python]
+    pass_filenames: true
+    args:
+      - -c
+      - |
+        matches=$(grep -EnH '^[^#]*\b(import logging|from logging\b)' "$@" | grep -v '# allow-direct-logging' || true)
+        if [ -n "$matches" ]; then
+          # GitHub Actions annotation format
+          while IFS=: read -r file line_num rest; do
+            echo "::error file=$file,line=$line_num::Do not use 'import logging' or 'from logging import' in $file. Use the custom log instead: from llama_stack.log import get_logger; logger = get_logger(). If direct logging is truly needed, add: # allow-direct-logging"
+          done <<< "$matches"
+          exit 1
+        fi
+        exit 0

 ci:
   autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
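
A note on the escape hatch, again as illustration rather than part of the commit: the grep pipeline above only reports import lines that do not carry the # allow-direct-logging comment, so code that genuinely needs the stdlib module (for example, the root-logger cleanup in the library client later in this diff) can opt out on the offending line:

    # A bare "import logging" would be flagged by check-log-usage;
    # adding the marker on the same line suppresses the error:
    import logging  # allow-direct-logging

    # e.g. to detach handlers from the root logger
    root_logger = logging.getLogger()
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
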
@@ -15,7 +15,7 @@ from llama_stack.log import get_logger

 REPO_ROOT = Path(__file__).parent.parent.parent.parent

-logger = get_logger(name=__name__, category="server")
+log = get_logger(name=__name__, category="server")


 class StackRun(Subcommand):

@@ -126,7 +126,7 @@ class StackRun(Subcommand):
             self.parser.error("Config file is required for venv environment")

         if config_file:
-            logger.info(f"Using run configuration: {config_file}")
+            log.info(f"Using run configuration: {config_file}")

             try:
                 config_dict = yaml.safe_load(config_file.read_text())

@@ -145,7 +145,7 @@ class StackRun(Subcommand):
         # If neither image type nor image name is provided, assume the server should be run directly
         # using the current environment packages.
         if not image_type and not image_name:
-            logger.info("No image type or image name provided. Assuming environment packages.")
+            log.info("No image type or image name provided. Assuming environment packages.")
             from llama_stack.core.server.server import main as server_main

             # Build the server args from the current args passed to the CLI

@@ -185,11 +185,11 @@ class StackRun(Subcommand):
             run_command(run_args)

     def _start_ui_development_server(self, stack_server_port: int):
-        logger.info("Attempting to start UI development server...")
+        log.info("Attempting to start UI development server...")
         # Check if npm is available
         npm_check = subprocess.run(["npm", "--version"], capture_output=True, text=True, check=False)
         if npm_check.returncode != 0:
-            logger.warning(
+            log.warning(
                 f"'npm' command not found or not executable. UI development server will not be started. Error: {npm_check.stderr}"
             )
             return

@@ -214,13 +214,13 @@ class StackRun(Subcommand):
                 stderr=stderr_log_file,
                 env={**os.environ, "NEXT_PUBLIC_LLAMA_STACK_BASE_URL": f"http://localhost:{stack_server_port}"},
             )
-            logger.info(f"UI development server process started in {ui_dir} with PID {process.pid}.")
-            logger.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}")
-            logger.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}")
+            log.info(f"UI development server process started in {ui_dir} with PID {process.pid}.")
+            log.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}")
+            log.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}")

         except FileNotFoundError:
-            logger.error(
+            log.error(
                 "Failed to start UI development server: 'npm' command not found. Make sure npm is installed and in your PATH."
             )
         except Exception as e:
-            logger.error(f"Failed to start UI development server in {ui_dir}: {e}")
+            log.error(f"Failed to start UI development server in {ui_dir}: {e}")
@@ -8,7 +8,7 @@ import argparse

 from llama_stack.log import get_logger

-logger = get_logger(name=__name__, category="cli")
+log = get_logger(name=__name__, category="cli")


 # TODO: this can probably just be inlined now?
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import importlib.resources
-import logging
 import sys

 from pydantic import BaseModel

@@ -17,10 +16,9 @@ from llama_stack.core.external import load_external_apis
 from llama_stack.core.utils.exec import run_command
 from llama_stack.core.utils.image_types import LlamaStackImageType
 from llama_stack.distributions.template import DistributionTemplate
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api

-log = logging.getLogger(__name__)
-
 # These are the dependencies needed by the distribution server.
 # `llama-stack` is automatically installed by the installation script.
 SERVER_DEPENDENCIES = [

@@ -33,6 +31,8 @@ SERVER_DEPENDENCIES = [
     "opentelemetry-exporter-otlp-proto-http",
 ]

+log = get_logger(name=__name__, category="core")
+

 class ApiInput(BaseModel):
     api: Api
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import logging
 import textwrap
 from typing import Any


@@ -21,9 +20,10 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
 from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
 from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.core.utils.prompt_for_config import prompt_for_config
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, ProviderSpec

-logger = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="core")


 def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:

@@ -49,7 +49,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
     is_nux = len(config.providers) == 0

     if is_nux:
-        logger.info(
+        log.info(
             textwrap.dedent(
                 """
                 Llama Stack is composed of several APIs working together. For each API served by the Stack,

@@ -75,12 +75,12 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec

         existing_providers = config.providers.get(api_str, [])
         if existing_providers:
-            logger.info(f"Re-configuring existing providers for API `{api_str}`...")
+            log.info(f"Re-configuring existing providers for API `{api_str}`...")
             updated_providers = []
             for p in existing_providers:
-                logger.info(f"> Configuring provider `({p.provider_type})`")
+                log.info(f"> Configuring provider `({p.provider_type})`")
                 updated_providers.append(configure_single_provider(provider_registry[api], p))
-                logger.info("")
+                log.info("")
         else:
             # we are newly configuring this API
             plist = build_spec.providers.get(api_str, [])

@@ -89,17 +89,17 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
             if not plist:
                 raise ValueError(f"No provider configured for API {api_str}?")

-            logger.info(f"Configuring API `{api_str}`...")
+            log.info(f"Configuring API `{api_str}`...")
             updated_providers = []
             for i, provider in enumerate(plist):
                 if i >= 1:
                     others = ", ".join(p.provider_type for p in plist[i:])
-                    logger.info(
+                    log.info(
                         f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
                     )
                     break

-                logger.info(f"> Configuring provider `({provider.provider_type})`")
+                log.info(f"> Configuring provider `({provider.provider_type})`")
                 pid = provider.provider_type.split("::")[-1]
                 updated_providers.append(
                     configure_single_provider(

@@ -111,7 +111,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
                     ),
                 )
             )
-            logger.info("")
+            log.info("")

         config.providers[api_str] = updated_providers

@@ -169,7 +169,7 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
         return StackRunConfig(**cast_image_name_to_string(processed_config_dict))

     if "routing_table" in config_dict:
-        logger.info("Upgrading config...")
+        log.info("Upgrading config...")
         config_dict = upgrade_from_routing_table(config_dict)

     config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
@@ -23,7 +23,7 @@ from llama_stack.providers.datatypes import (
     remote_provider_spec,
 )

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 def stack_apis() -> list[Api]:

@@ -141,18 +141,18 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
     registry: dict[Api, dict[str, ProviderSpec]] = {}
     for api in providable_apis():
         name = api.name.lower()
-        logger.debug(f"Importing module {name}")
+        log.debug(f"Importing module {name}")
         try:
             module = importlib.import_module(f"llama_stack.providers.registry.{name}")
             registry[api] = {a.provider_type: a for a in module.available_providers()}
         except ImportError as e:
-            logger.warning(f"Failed to import module {name}: {e}")
+            log.warning(f"Failed to import module {name}: {e}")

     # Refresh providable APIs with external APIs if any
     external_apis = load_external_apis(config)
     for api, api_spec in external_apis.items():
         name = api_spec.name.lower()
-        logger.info(f"Importing external API {name} module {api_spec.module}")
+        log.info(f"Importing external API {name} module {api_spec.module}")
         try:
             module = importlib.import_module(api_spec.module)
             registry[api] = {a.provider_type: a for a in module.available_providers()}

@@ -161,7 +161,7 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
             # This assume that the in-tree provider(s) are not available for this API which means
             # that users will need to use external providers for this API.
             registry[api] = {}
-            logger.error(
+            log.error(
                 f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
                 "Install the API package to load any in-tree providers for this API."
             )

@@ -183,13 +183,13 @@ def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
 def get_external_providers_from_dir(
     registry: dict[Api, dict[str, ProviderSpec]], config
 ) -> dict[Api, dict[str, ProviderSpec]]:
-    logger.warning(
+    log.warning(
         "Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead."
     )
     external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
     if not os.path.exists(external_providers_dir):
         raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
-    logger.info(f"Loading external providers from {external_providers_dir}")
+    log.info(f"Loading external providers from {external_providers_dir}")

     for api in providable_apis():
         api_name = api.name.lower()

@@ -198,13 +198,13 @@ def get_external_providers_from_dir(
         for provider_type in ["remote", "inline"]:
             api_dir = os.path.join(external_providers_dir, provider_type, api_name)
             if not os.path.exists(api_dir):
-                logger.debug(f"No {provider_type} provider directory found for {api_name}")
+                log.debug(f"No {provider_type} provider directory found for {api_name}")
                 continue

             # Look for provider spec files in the API directory
             for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
                 provider_name = os.path.splitext(os.path.basename(spec_path))[0]
-                logger.info(f"Loading {provider_type} provider spec from {spec_path}")
+                log.info(f"Loading {provider_type} provider spec from {spec_path}")

                 try:
                     with open(spec_path) as f:

@@ -217,16 +217,16 @@ def get_external_providers_from_dir(
                         spec = _load_inline_provider_spec(spec_data, api, provider_name)
                         provider_type_key = f"inline::{provider_name}"

-                    logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
+                    log.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
                     if provider_type_key in registry[api]:
-                        logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
+                        log.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
                     registry[api][provider_type_key] = spec
-                    logger.info(f"Successfully loaded external provider {provider_type_key}")
+                    log.info(f"Successfully loaded external provider {provider_type_key}")
                 except yaml.YAMLError as yaml_err:
-                    logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
+                    log.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
                     raise yaml_err
                 except Exception as e:
-                    logger.error(f"Failed to load provider spec from {spec_path}: {e}")
+                    log.error(f"Failed to load provider spec from {spec_path}: {e}")
                     raise e

     return registry

@@ -241,7 +241,7 @@ def get_external_providers_from_module(
     else:
         provider_list = config.providers.items()
     if provider_list is None:
-        logger.warning("Could not get list of providers from config")
+        log.warning("Could not get list of providers from config")
         return registry
     for provider_api, providers in provider_list:
         for provider in providers:

@@ -272,6 +272,6 @@ def get_external_providers_from_module(
                 "get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
             ) from exc
         except Exception as e:
-            logger.error(f"Failed to load provider spec from module {provider.module}: {e}")
+            log.error(f"Failed to load provider spec from module {provider.module}: {e}")
             raise e
     return registry
@@ -11,7 +11,7 @@ from llama_stack.apis.datatypes import Api, ExternalApiSpec
 from llama_stack.core.datatypes import BuildConfig, StackRunConfig
 from llama_stack.log import get_logger

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:

@@ -28,10 +28,10 @@ def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api,

     external_apis_dir = config.external_apis_dir.expanduser().resolve()
     if not external_apis_dir.is_dir():
-        logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
+        log.error(f"External APIs directory is not a directory: {external_apis_dir}")
         return {}

-    logger.info(f"Loading external APIs from {external_apis_dir}")
+    log.info(f"Loading external APIs from {external_apis_dir}")
     external_apis: dict[Api, ExternalApiSpec] = {}

     # Look for YAML files in the external APIs directory

@@ -42,13 +42,13 @@ def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api,

             spec = ExternalApiSpec(**spec_data)
             api = Api.add(spec.name)
-            logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
+            log.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
             external_apis[api] = spec
         except yaml.YAMLError as yaml_err:
-            logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
+            log.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
             raise
         except Exception:
-            logger.exception(f"Failed to load external API spec from {yaml_path}")
+            log.exception(f"Failed to load external API spec from {yaml_path}")
             raise

     return external_apis
@@ -7,7 +7,6 @@
 import asyncio
 import inspect
 import json
-import logging
 import os
 import sys
 from concurrent.futures import ThreadPoolExecutor

@@ -48,6 +47,7 @@ from llama_stack.core.stack import (
 from llama_stack.core.utils.config import redact_sensitive_fields
 from llama_stack.core.utils.context import preserve_contexts_async_generator
 from llama_stack.core.utils.exec import in_notebook
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.telemetry.tracing import (
     CURRENT_TRACE_CONTEXT,
     end_trace,

@@ -55,7 +55,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
     start_trace,
 )

-logger = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="core")

 T = TypeVar("T")


@@ -84,7 +84,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
        try:
            return [convert_to_pydantic(item_type, item) for item in value]
        except Exception:
-            logger.error(f"Error converting list {value} into {item_type}")
+            log.error(f"Error converting list {value} into {item_type}")
            return value

    elif origin is dict:

@@ -92,7 +92,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
        try:
            return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
        except Exception:
-            logger.error(f"Error converting dict {value} into {val_type}")
+            log.error(f"Error converting dict {value} into {val_type}")
            return value

    try:

@@ -108,7 +108,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
                return convert_to_pydantic(union_type, value)
            except Exception:
                continue
-        logger.warning(
+        log.warning(
            f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
        )
        raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e

@@ -171,13 +171,15 @@ class LlamaStackAsLibraryClient(LlamaStackClient):

     def _remove_root_logger_handlers(self):
         """
-        Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
+        Remove all handlers from the root log. Needed to avoid polluting the console with logs.
         """
+        import logging  # allow-direct-logging
+
         root_logger = logging.getLogger()

         for handler in root_logger.handlers[:]:
             root_logger.removeHandler(handler)
-            logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
+            log.info(f"Removed handler {handler.__class__.__name__} from root log")

     def request(self, *args, **kwargs):
         loop = self.loop
@@ -16,7 +16,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus
 from .datatypes import StackRunConfig
 from .utils.config import redact_sensitive_fields

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 class ProviderImplConfig(BaseModel):

@@ -38,7 +38,7 @@ class ProviderImpl(Providers):
         pass

     async def shutdown(self) -> None:
-        logger.debug("ProviderImpl.shutdown")
+        log.debug("ProviderImpl.shutdown")
         pass

     async def list_providers(self) -> ListProvidersResponse:
@@ -6,19 +6,19 @@

 import contextvars
 import json
-import logging
 from contextlib import AbstractContextManager
 from typing import Any

 from llama_stack.core.datatypes import User
+from llama_stack.log import get_logger

 from .utils.dynamic import instantiate_class_type

-log = logging.getLogger(__name__)
-
 # Context variable for request provider data and auth attributes
 PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)

+log = get_logger(name=__name__, category="core")
+

 class RequestProviderDataContext(AbstractContextManager):
     """Context manager for request provider data"""
@@ -54,7 +54,7 @@ from llama_stack.providers.datatypes import (
     VectorDBsProtocolPrivate,
 )

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 class InvalidProviderError(Exception):

@@ -101,7 +101,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->

             protocols[api] = api_class
         except (ImportError, AttributeError):
-            logger.exception(f"Failed to load external API {api_spec.name}")
+            log.exception(f"Failed to load external API {api_spec.name}")

     return protocols


@@ -223,7 +223,7 @@ def validate_and_prepare_providers(
         specs = {}
         for provider in providers:
             if not provider.provider_id or provider.provider_id == "__disabled__":
-                logger.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
+                log.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
                 continue

             validate_provider(provider, api, provider_registry)

@@ -245,10 +245,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR

     p = provider_registry[api][provider.provider_type]
     if p.deprecation_error:
-        logger.error(p.deprecation_error)
+        log.error(p.deprecation_error)
         raise InvalidProviderError(p.deprecation_error)
     elif p.deprecation_warning:
-        logger.warning(
+        log.warning(
             f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
         )

@@ -261,9 +261,9 @@ def sort_providers_by_deps(
         {k: list(v.values()) for k, v in providers_with_specs.items()}
     )

-    logger.debug(f"Resolved {len(sorted_providers)} providers")
+    log.debug(f"Resolved {len(sorted_providers)} providers")
     for api_str, provider in sorted_providers:
-        logger.debug(f"  {api_str} => {provider.provider_id}")
+        log.debug(f"  {api_str} => {provider.provider_id}")
     return sorted_providers


@@ -348,7 +348,7 @@ async def instantiate_provider(
     if not hasattr(provider_spec, "module") or provider_spec.module is None:
         raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")

-    logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
+    log.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
     module = importlib.import_module(provider_spec.module)
     args = []
     if isinstance(provider_spec, RemoteProviderSpec):

@@ -418,7 +418,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
             obj_params = set(obj_sig.parameters)
             obj_params.discard("self")
             if not (proto_params <= obj_params):
-                logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
+                log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
                 missing_methods.append((name, "signature_mismatch"))
             else:
                 # Check if the method has a concrete implementation (not just a protocol stub)
@@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 class DatasetIORouter(DatasetIO):

@@ -20,15 +20,15 @@ class DatasetIORouter(DatasetIO):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logger.debug("Initializing DatasetIORouter")
+        log.debug("Initializing DatasetIORouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
-        logger.debug("DatasetIORouter.initialize")
+        log.debug("DatasetIORouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logger.debug("DatasetIORouter.shutdown")
+        log.debug("DatasetIORouter.shutdown")
         pass

     async def register_dataset(

@@ -38,7 +38,7 @@ class DatasetIORouter(DatasetIO):
         metadata: dict[str, Any] | None = None,
         dataset_id: str | None = None,
     ) -> None:
-        logger.debug(
+        log.debug(
             f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
         )
         await self.routing_table.register_dataset(

@@ -54,7 +54,7 @@ class DatasetIORouter(DatasetIO):
         start_index: int | None = None,
         limit: int | None = None,
     ) -> PaginatedResponse:
-        logger.debug(
+        log.debug(
             f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
         )
         provider = await self.routing_table.get_provider_impl(dataset_id)

@@ -65,7 +65,7 @@ class DatasetIORouter(DatasetIO):
         )

     async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
-        logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
+        log.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
         provider = await self.routing_table.get_provider_impl(dataset_id)
         return await provider.append_rows(
             dataset_id=dataset_id,
@@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 class ScoringRouter(Scoring):

@@ -24,15 +24,15 @@ class ScoringRouter(Scoring):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logger.debug("Initializing ScoringRouter")
+        log.debug("Initializing ScoringRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
-        logger.debug("ScoringRouter.initialize")
+        log.debug("ScoringRouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logger.debug("ScoringRouter.shutdown")
+        log.debug("ScoringRouter.shutdown")
         pass

     async def score_batch(

@@ -41,7 +41,7 @@ class ScoringRouter(Scoring):
         scoring_functions: dict[str, ScoringFnParams | None] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
-        logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
+        log.debug(f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
             provider = await self.routing_table.get_provider_impl(fn_identifier)

@@ -63,7 +63,7 @@ class ScoringRouter(Scoring):
         input_rows: list[dict[str, Any]],
         scoring_functions: dict[str, ScoringFnParams | None] = None,
     ) -> ScoreResponse:
-        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
+        log.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():

@@ -82,15 +82,15 @@ class EvalRouter(Eval):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logger.debug("Initializing EvalRouter")
+        log.debug("Initializing EvalRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
-        logger.debug("EvalRouter.initialize")
+        log.debug("EvalRouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logger.debug("EvalRouter.shutdown")
+        log.debug("EvalRouter.shutdown")
         pass

     async def run_eval(

@@ -98,7 +98,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         benchmark_config: BenchmarkConfig,
     ) -> Job:
-        logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
+        log.debug(f"EvalRouter.run_eval: {benchmark_id}")
         provider = await self.routing_table.get_provider_impl(benchmark_id)
         return await provider.run_eval(
             benchmark_id=benchmark_id,

@@ -112,7 +112,7 @@ class EvalRouter(Eval):
         scoring_functions: list[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
-        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
+        log.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
         provider = await self.routing_table.get_provider_impl(benchmark_id)
         return await provider.evaluate_rows(
             benchmark_id=benchmark_id,

@@ -126,7 +126,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> Job:
-        logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
+        log.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
         provider = await self.routing_table.get_provider_impl(benchmark_id)
         return await provider.job_status(benchmark_id, job_id)

@@ -135,7 +135,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> None:
-        logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
+        log.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
         provider = await self.routing_table.get_provider_impl(benchmark_id)
         await provider.job_cancel(
             benchmark_id,

@@ -147,7 +147,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> EvaluateResponse:
-        logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
+        log.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
         provider = await self.routing_table.get_provider_impl(benchmark_id)
         return await provider.job_result(
             benchmark_id,
@@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.inference_store import InferenceStore
 from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
 from llama_stack.providers.utils.telemetry.tracing import get_current_span

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 class InferenceRouter(Inference):

@@ -70,7 +70,7 @@ class InferenceRouter(Inference):
         telemetry: Telemetry | None = None,
         store: InferenceStore | None = None,
     ) -> None:
-        logger.debug("Initializing InferenceRouter")
+        log.debug("Initializing InferenceRouter")
         self.routing_table = routing_table
         self.telemetry = telemetry
         self.store = store

@@ -79,10 +79,10 @@ class InferenceRouter(Inference):
         self.formatter = ChatFormat(self.tokenizer)

     async def initialize(self) -> None:
-        logger.debug("InferenceRouter.initialize")
+        log.debug("InferenceRouter.initialize")

     async def shutdown(self) -> None:
-        logger.debug("InferenceRouter.shutdown")
+        log.debug("InferenceRouter.shutdown")

     async def register_model(
         self,

@@ -92,7 +92,7 @@ class InferenceRouter(Inference):
         metadata: dict[str, Any] | None = None,
         model_type: ModelType | None = None,
     ) -> None:
-        logger.debug(
+        log.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

@@ -117,7 +117,7 @@ class InferenceRouter(Inference):
         """
         span = get_current_span()
         if span is None:
-            logger.warning("No span found for token usage metrics")
+            log.warning("No span found for token usage metrics")
             return []
         metrics = [
             ("prompt_tokens", prompt_tokens),

@@ -182,7 +182,7 @@ class InferenceRouter(Inference):
         logprobs: LogProbConfig | None = None,
         tool_config: ToolConfig | None = None,
     ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
-        logger.debug(
+        log.debug(
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
         if sampling_params is None:

@@ -288,7 +288,7 @@ class InferenceRouter(Inference):
         response_format: ResponseFormat | None = None,
         logprobs: LogProbConfig | None = None,
     ) -> BatchChatCompletionResponse:
-        logger.debug(
+        log.debug(
             f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
         )
         provider = await self.routing_table.get_provider_impl(model_id)

@@ -313,7 +313,7 @@ class InferenceRouter(Inference):
     ) -> AsyncGenerator:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        logger.debug(
+        log.debug(
             f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
         )
         model = await self.routing_table.get_model(model_id)

@@ -374,7 +374,7 @@ class InferenceRouter(Inference):
         response_format: ResponseFormat | None = None,
         logprobs: LogProbConfig | None = None,
     ) -> BatchCompletionResponse:
-        logger.debug(
+        log.debug(
             f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
         )
         provider = await self.routing_table.get_provider_impl(model_id)

@@ -388,7 +388,7 @@ class InferenceRouter(Inference):
         output_dimension: int | None = None,
         task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
-        logger.debug(f"InferenceRouter.embeddings: {model_id}")
+        log.debug(f"InferenceRouter.embeddings: {model_id}")
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ModelNotFoundError(model_id)

@@ -426,7 +426,7 @@ class InferenceRouter(Inference):
         prompt_logprobs: int | None = None,
         suffix: str | None = None,
     ) -> OpenAICompletion:
-        logger.debug(
+        log.debug(
             f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
         )
         model_obj = await self.routing_table.get_model(model)

@@ -487,7 +487,7 @@ class InferenceRouter(Inference):
         top_p: float | None = None,
         user: str | None = None,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        logger.debug(
+        log.debug(
             f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
         )
         model_obj = await self.routing_table.get_model(model)

@@ -558,7 +558,7 @@ class InferenceRouter(Inference):
         dimensions: int | None = None,
         user: str | None = None,
     ) -> OpenAIEmbeddingsResponse:
-        logger.debug(
+        log.debug(
             f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
         )
         model_obj = await self.routing_table.get_model(model)
@@ -14,7 +14,7 @@ from llama_stack.apis.shields import Shield
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 class SafetyRouter(Safety):
@@ -22,15 +22,15 @@ class SafetyRouter(Safety):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logger.debug("Initializing SafetyRouter")
+        log.debug("Initializing SafetyRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
-        logger.debug("SafetyRouter.initialize")
+        log.debug("SafetyRouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logger.debug("SafetyRouter.shutdown")
+        log.debug("SafetyRouter.shutdown")
         pass

     async def register_shield(
@@ -40,7 +40,7 @@ class SafetyRouter(Safety):
         provider_id: str | None = None,
         params: dict[str, Any] | None = None,
     ) -> Shield:
-        logger.debug(f"SafetyRouter.register_shield: {shield_id}")
+        log.debug(f"SafetyRouter.register_shield: {shield_id}")
         return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

     async def unregister_shield(self, identifier: str) -> None:
@@ -53,7 +53,7 @@ class SafetyRouter(Safety):
         messages: list[Message],
         params: dict[str, Any] = None,
     ) -> RunShieldResponse:
-        logger.debug(f"SafetyRouter.run_shield: {shield_id}")
+        log.debug(f"SafetyRouter.run_shield: {shield_id}")
         provider = await self.routing_table.get_provider_impl(shield_id)
         return await provider.run_shield(
             shield_id=shield_id,

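For reference, a minimal sketch of the logging pattern these hunks converge on. The get_logger call and the "core" category are taken from the hunks above; the standalone module and the do_work helper below are illustrative only, not part of the commit.

# Illustrative module: uses llama_stack.log instead of a direct `import logging`.
from llama_stack.log import get_logger

log = get_logger(name=__name__, category="core")


def do_work(item_id: str) -> None:
    # Hypothetical function; the call style matches the routers in this diff.
    log.debug(f"do_work: {item_id}")
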
@@ -22,7 +22,7 @@ from llama_stack.log import get_logger

 from ..routing_tables.toolgroups import ToolGroupsRoutingTable

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 class ToolRuntimeRouter(ToolRuntime):
@@ -31,7 +31,7 @@ class ToolRuntimeRouter(ToolRuntime):
             self,
             routing_table: ToolGroupsRoutingTable,
         ) -> None:
-            logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
+            log.debug("Initializing ToolRuntimeRouter.RagToolImpl")
             self.routing_table = routing_table

         async def query(
@@ -40,7 +40,7 @@ class ToolRuntimeRouter(ToolRuntime):
             vector_db_ids: list[str],
             query_config: RAGQueryConfig | None = None,
         ) -> RAGQueryResult:
-            logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
+            log.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
             provider = await self.routing_table.get_provider_impl("knowledge_search")
             return await provider.query(content, vector_db_ids, query_config)

@@ -50,7 +50,7 @@ class ToolRuntimeRouter(ToolRuntime):
             vector_db_id: str,
             chunk_size_in_tokens: int = 512,
         ) -> None:
-            logger.debug(
+            log.debug(
                 f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
             )
             provider = await self.routing_table.get_provider_impl("insert_into_memory")
@@ -60,7 +60,7 @@ class ToolRuntimeRouter(ToolRuntime):
         self,
         routing_table: ToolGroupsRoutingTable,
     ) -> None:
-        logger.debug("Initializing ToolRuntimeRouter")
+        log.debug("Initializing ToolRuntimeRouter")
         self.routing_table = routing_table

         # HACK ALERT this should be in sync with "get_all_api_endpoints()"
@@ -69,15 +69,15 @@ class ToolRuntimeRouter(ToolRuntime):
             setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

     async def initialize(self) -> None:
-        logger.debug("ToolRuntimeRouter.initialize")
+        log.debug("ToolRuntimeRouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logger.debug("ToolRuntimeRouter.shutdown")
+        log.debug("ToolRuntimeRouter.shutdown")
         pass

     async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
-        logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
+        log.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
         provider = await self.routing_table.get_provider_impl(tool_name)
         return await provider.invoke_tool(
             tool_name=tool_name,
@@ -87,5 +87,5 @@ class ToolRuntimeRouter(ToolRuntime):
     async def list_runtime_tools(
         self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
     ) -> ListToolsResponse:
-        logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
+        log.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
         return await self.routing_table.list_tools(tool_group_id)

@@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 class VectorIORouter(VectorIO):
@@ -40,15 +40,15 @@ class VectorIORouter(VectorIO):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        logger.debug("Initializing VectorIORouter")
+        log.debug("Initializing VectorIORouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
-        logger.debug("VectorIORouter.initialize")
+        log.debug("VectorIORouter.initialize")
         pass

     async def shutdown(self) -> None:
-        logger.debug("VectorIORouter.shutdown")
+        log.debug("VectorIORouter.shutdown")
         pass

     async def _get_first_embedding_model(self) -> tuple[str, int] | None:
@@ -70,10 +70,10 @@ class VectorIORouter(VectorIO):
                     raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
                 return embedding_models[0].identifier, dimension
             else:
-                logger.warning("No embedding models found in the routing table")
+                log.warning("No embedding models found in the routing table")
                 return None
         except Exception as e:
-            logger.error(f"Error getting embedding models: {e}")
+            log.error(f"Error getting embedding models: {e}")
             return None

     async def register_vector_db(
@@ -85,7 +85,7 @@ class VectorIORouter(VectorIO):
         vector_db_name: str | None = None,
         provider_vector_db_id: str | None = None,
     ) -> None:
-        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
+        log.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
         await self.routing_table.register_vector_db(
             vector_db_id,
             embedding_model,
@@ -101,7 +101,7 @@ class VectorIORouter(VectorIO):
         chunks: list[Chunk],
         ttl_seconds: int | None = None,
     ) -> None:
-        logger.debug(
+        log.debug(
            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
         )
         provider = await self.routing_table.get_provider_impl(vector_db_id)
@@ -113,7 +113,7 @@ class VectorIORouter(VectorIO):
         query: InterleavedContent,
         params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse:
-        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
+        log.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
         provider = await self.routing_table.get_provider_impl(vector_db_id)
         return await provider.query_chunks(vector_db_id, query, params)

@@ -129,7 +129,7 @@ class VectorIORouter(VectorIO):
         embedding_dimension: int | None = None,
         provider_id: str | None = None,
     ) -> VectorStoreObject:
-        logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
+        log.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")

         # If no embedding model is provided, use the first available one
         if embedding_model is None:
@@ -137,7 +137,7 @@ class VectorIORouter(VectorIO):
             if embedding_model_info is None:
                 raise ValueError("No embedding model provided and no embedding models available in the system")
             embedding_model, embedding_dimension = embedding_model_info
-            logger.info(f"No embedding model specified, using first available: {embedding_model}")
+            log.info(f"No embedding model specified, using first available: {embedding_model}")

         vector_db_id = f"vs_{uuid.uuid4()}"
         registered_vector_db = await self.routing_table.register_vector_db(
@@ -168,7 +168,7 @@ class VectorIORouter(VectorIO):
         after: str | None = None,
         before: str | None = None,
     ) -> VectorStoreListResponse:
-        logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
+        log.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
         # Route to default provider for now - could aggregate from all providers in the future
         # call retrieve on each vector dbs to get list of vector stores
         vector_dbs = await self.routing_table.get_all_with_type("vector_db")
@@ -179,7 +179,7 @@ class VectorIORouter(VectorIO):
                 vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
                 all_stores.append(vector_store)
             except Exception as e:
-                logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
+                log.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
                 continue

         # Sort by created_at
@@ -215,7 +215,7 @@ class VectorIORouter(VectorIO):
         self,
         vector_store_id: str,
     ) -> VectorStoreObject:
-        logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
+        log.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
         return await self.routing_table.openai_retrieve_vector_store(vector_store_id)

     async def openai_update_vector_store(
@@ -225,7 +225,7 @@ class VectorIORouter(VectorIO):
         expires_after: dict[str, Any] | None = None,
         metadata: dict[str, Any] | None = None,
     ) -> VectorStoreObject:
-        logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
+        log.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
         return await self.routing_table.openai_update_vector_store(
             vector_store_id=vector_store_id,
             name=name,
@@ -237,7 +237,7 @@ class VectorIORouter(VectorIO):
         self,
         vector_store_id: str,
     ) -> VectorStoreDeleteResponse:
-        logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
+        log.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
         return await self.routing_table.openai_delete_vector_store(vector_store_id)

     async def openai_search_vector_store(
@@ -250,7 +250,7 @@ class VectorIORouter(VectorIO):
         rewrite_query: bool | None = False,
         search_mode: str | None = "vector",
     ) -> VectorStoreSearchResponsePage:
-        logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
+        log.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
         return await self.routing_table.openai_search_vector_store(
             vector_store_id=vector_store_id,
             query=query,
@@ -268,7 +268,7 @@ class VectorIORouter(VectorIO):
         attributes: dict[str, Any] | None = None,
         chunking_strategy: VectorStoreChunkingStrategy | None = None,
     ) -> VectorStoreFileObject:
-        logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
+        log.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
         return await self.routing_table.openai_attach_file_to_vector_store(
             vector_store_id=vector_store_id,
             file_id=file_id,
@@ -285,7 +285,7 @@ class VectorIORouter(VectorIO):
         before: str | None = None,
         filter: VectorStoreFileStatus | None = None,
     ) -> list[VectorStoreFileObject]:
-        logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
+        log.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
         return await self.routing_table.openai_list_files_in_vector_store(
             vector_store_id=vector_store_id,
             limit=limit,
@@ -300,7 +300,7 @@ class VectorIORouter(VectorIO):
         vector_store_id: str,
         file_id: str,
     ) -> VectorStoreFileObject:
-        logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
+        log.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
         return await self.routing_table.openai_retrieve_vector_store_file(
             vector_store_id=vector_store_id,
             file_id=file_id,
@@ -311,7 +311,7 @@ class VectorIORouter(VectorIO):
         vector_store_id: str,
         file_id: str,
     ) -> VectorStoreFileContentsResponse:
-        logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
+        log.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
         return await self.routing_table.openai_retrieve_vector_store_file_contents(
             vector_store_id=vector_store_id,
             file_id=file_id,
@@ -323,7 +323,7 @@ class VectorIORouter(VectorIO):
         file_id: str,
         attributes: dict[str, Any],
     ) -> VectorStoreFileObject:
-        logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
+        log.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
         return await self.routing_table.openai_update_vector_store_file(
             vector_store_id=vector_store_id,
             file_id=file_id,
@@ -335,7 +335,7 @@ class VectorIORouter(VectorIO):
         vector_store_id: str,
         file_id: str,
     ) -> VectorStoreFileDeleteResponse:
-        logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
+        log.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
         return await self.routing_table.openai_delete_vector_store_file(
             vector_store_id=vector_store_id,
             file_id=file_id,

@@ -14,7 +14,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):

@@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, RoutingTable

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 def get_impl_api(p: Any) -> Api:
@@ -177,7 +177,7 @@ class CommonRoutingTableImpl(RoutingTable):

         # Check if user has permission to access this object
         if not is_action_allowed(self.policy, "read", obj, get_authenticated_user()):
-            logger.debug(f"Access denied to {type} '{identifier}'")
+            log.debug(f"Access denied to {type} '{identifier}'")
             return None

         return obj
@@ -205,7 +205,7 @@ class CommonRoutingTableImpl(RoutingTable):
             raise AccessDeniedError("create", obj, creator)
         if creator:
             obj.owner = creator
-            logger.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")
+            log.info(f"Setting owner for {obj.type} '{obj.identifier}' to {obj.owner.principal}")

         registered_obj = await register_object_with_provider(obj, p)
         # TODO: This needs to be fixed for all APIs once they return the registered object
@@ -250,7 +250,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
     if model is not None:
         return model

-    logger.warning(
+    log.warning(
         f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to "
         "searching in all providers. This is only for backwards compatibility and will stop working "
         "soon. Migrate your calls to use fully scoped `provider_id/model_id` names."

@@ -26,7 +26,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

@@ -17,7 +17,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl, lookup_model

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 class ModelsRoutingTable(CommonRoutingTableImpl, Models):
@@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             try:
                 models = await provider.list_models()
             except Exception as e:
-                logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
+                log.exception(f"Model refresh failed for provider {provider_id}: {e}")
                 continue

             self.listed_providers.add(provider_id)
@@ -132,7 +132,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
                 model_ids[model.provider_resource_id] = model.identifier
                 continue

-            logger.debug(f"unregistering model {model.identifier}")
+            log.debug(f"unregistering model {model.identifier}")
             await self.unregister_object(model)

         for model in models:
@@ -143,7 +143,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             if model.identifier == model.provider_resource_id:
                 model.identifier = f"{provider_id}/{model.provider_resource_id}"

-            logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
+            log.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
             await self.register_object(
                 ModelWithOwner(
                     identifier=model.identifier,

@@ -19,7 +19,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):

@@ -15,7 +15,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):

@@ -14,7 +14,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:

@@ -30,7 +30,7 @@ from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl, lookup_model

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
@@ -57,7 +57,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         if len(self.impls_by_provider_id) > 0:
             provider_id = list(self.impls_by_provider_id.keys())[0]
             if len(self.impls_by_provider_id) > 1:
-                logger.warning(
+                log.warning(
                     f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
                 )
         else:

@@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
 from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
 from llama_stack.log import get_logger

-logger = get_logger(name=__name__, category="auth")
+log = get_logger(name=__name__, category="auth")


 class AuthenticationMiddleware:
@@ -105,13 +105,13 @@ class AuthenticationMiddleware:
             try:
                 validation_result = await self.auth_provider.validate_token(token, scope)
             except httpx.TimeoutException:
-                logger.exception("Authentication request timed out")
+                log.exception("Authentication request timed out")
                 return await self._send_auth_error(send, "Authentication service timeout")
             except ValueError as e:
-                logger.exception("Error during authentication")
+                log.exception("Error during authentication")
                 return await self._send_auth_error(send, str(e))
             except Exception:
-                logger.exception("Error during authentication")
+                log.exception("Error during authentication")
                 return await self._send_auth_error(send, "Authentication service error")

             # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
@@ -122,7 +122,7 @@ class AuthenticationMiddleware:
             scope["principal"] = validation_result.principal
             if validation_result.attributes:
                 scope["user_attributes"] = validation_result.attributes
-            logger.debug(
+            log.debug(
                 f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
             )

@@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
 )
 from llama_stack.log import get_logger

-logger = get_logger(name=__name__, category="auth")
+log = get_logger(name=__name__, category="auth")


 class AuthResponse(BaseModel):
@@ -163,7 +163,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
                     timeout=10.0,  # Add a reasonable timeout
                 )
                 if response.status_code != 200:
-                    logger.warning(f"Token introspection failed with status code: {response.status_code}")
+                    log.warning(f"Token introspection failed with status code: {response.status_code}")
                     raise ValueError(f"Token introspection failed: {response.status_code}")

                 fields = response.json()
@@ -176,13 +176,13 @@ class OAuth2TokenAuthProvider(AuthProvider):
                     attributes=access_attributes,
                 )
         except httpx.TimeoutException:
-            logger.exception("Token introspection request timed out")
+            log.exception("Token introspection request timed out")
             raise
         except ValueError:
             # Re-raise ValueError exceptions to preserve their message
             raise
         except Exception as e:
-            logger.exception("Error during token introspection")
+            log.exception("Error during token introspection")
             raise ValueError("Token introspection error") from e

     async def close(self):
@@ -273,7 +273,7 @@ class CustomAuthProvider(AuthProvider):
                     timeout=10.0,  # Add a reasonable timeout
                 )
                 if response.status_code != 200:
-                    logger.warning(f"Authentication failed with status code: {response.status_code}")
+                    log.warning(f"Authentication failed with status code: {response.status_code}")
                     raise ValueError(f"Authentication failed: {response.status_code}")

                 # Parse and validate the auth response
@@ -282,17 +282,17 @@ class CustomAuthProvider(AuthProvider):
                     auth_response = AuthResponse(**response_data)
                     return User(principal=auth_response.principal, attributes=auth_response.attributes)
                 except Exception as e:
-                    logger.exception("Error parsing authentication response")
+                    log.exception("Error parsing authentication response")
                     raise ValueError("Invalid authentication response format") from e

         except httpx.TimeoutException:
-            logger.exception("Authentication request timed out")
+            log.exception("Authentication request timed out")
             raise
         except ValueError:
             # Re-raise ValueError exceptions to preserve their message
             raise
         except Exception as e:
-            logger.exception("Error during authentication")
+            log.exception("Error during authentication")
             raise ValueError("Authentication service error") from e

     async def close(self):
@@ -329,7 +329,7 @@ class GitHubTokenAuthProvider(AuthProvider):
         try:
             user_info = await _get_github_user_info(token, self.config.github_api_base_url)
         except httpx.HTTPStatusError as e:
-            logger.warning(f"GitHub token validation failed: {e}")
+            log.warning(f"GitHub token validation failed: {e}")
             raise ValueError("GitHub token validation failed. Please check your token and try again.") from e

         principal = user_info["user"]["login"]

@@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl

-logger = get_logger(name=__name__, category="quota")
+log = get_logger(name=__name__, category="quota")


 class QuotaMiddleware:
@@ -46,7 +46,7 @@ class QuotaMiddleware:
         self.window_seconds = window_seconds

         if isinstance(self.kv_config, SqliteKVStoreConfig):
-            logger.warning(
+            log.warning(
                 "QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
                 f"window_seconds={self.window_seconds}"
             )
@@ -84,11 +84,11 @@ class QuotaMiddleware:
             else:
                 await kv.set(key, str(count))
         except Exception:
-            logger.exception("Failed to access KV store for quota")
+            log.exception("Failed to access KV store for quota")
             return await self._send_error(send, 500, "Quota service error")

         if count > limit:
-            logger.warning(
+            log.warning(
                 "Quota exceeded for client %s: %d/%d",
                 key_id,
                 count,

@@ -9,7 +9,7 @@ import asyncio
 import functools
 import inspect
 import json
-import logging
+import logging  # allow-direct-logging
 import os
 import ssl
 import sys
@@ -80,7 +80,7 @@ from .quota import QuotaMiddleware

 REPO_ROOT = Path(__file__).parent.parent.parent.parent

-logger = get_logger(name=__name__, category="server")
+log = get_logger(name=__name__, category="server")


 def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@@ -157,9 +157,9 @@ async def shutdown(app):

 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    logger.info("Starting up")
+    log.info("Starting up")
     yield
-    logger.info("Shutting down")
+    log.info("Shutting down")
     await shutdown(app)


@@ -182,11 +182,11 @@ async def sse_generator(event_gen_coroutine):
             yield create_sse_event(item)
             await asyncio.sleep(0.01)
     except asyncio.CancelledError:
-        logger.info("Generator cancelled")
+        log.info("Generator cancelled")
         if event_gen:
             await event_gen.aclose()
     except Exception as e:
-        logger.exception("Error in sse_generator")
+        log.exception("Error in sse_generator")
         yield create_sse_event(
             {
                 "error": {
@@ -206,11 +206,11 @@ async def log_request_pre_validation(request: Request):
                 log_output = rich.pretty.pretty_repr(parsed_body)
             except (json.JSONDecodeError, UnicodeDecodeError):
                 log_output = repr(body_bytes)
-            logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
+            log.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
         else:
-            logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
+            log.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
     except Exception as e:
-        logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
+        log.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")


 def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
@@ -238,10 +238,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
             result.url = route
             return result
         except Exception as e:
-            if logger.isEnabledFor(logging.DEBUG):
-                logger.exception(f"Error executing endpoint {route=} {method=}")
+            if log.isEnabledFor(logging.DEBUG):
+                log.exception(f"Error executing endpoint {route=} {method=}")
             else:
-                logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
+                log.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
             raise translate_exception(e) from e

     sig = inspect.signature(func)
@@ -291,7 +291,7 @@ class TracingMiddleware:
         # Check if the path is a FastAPI built-in path
         if path.startswith(self.fastapi_paths):
             # Pass through to FastAPI's built-in handlers
-            logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
+            log.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
             return await self.app(scope, receive, send)

         if not hasattr(self, "route_impls"):
@@ -303,7 +303,7 @@ class TracingMiddleware:
             )
         except ValueError:
             # If no matching endpoint is found, pass through to FastAPI
-            logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
+            log.debug(f"No matching route found for path: {path}, falling back to FastAPI")
             return await self.app(scope, receive, send)

         trace_attributes = {"__location__": "server", "raw_path": path}
@@ -404,15 +404,15 @@ def main(args: argparse.Namespace | None = None):
         config_contents = yaml.safe_load(fp)
         if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
             logger_config = LoggingConfig(**cfg)
-            logger = get_logger(name=__name__, category="server", config=logger_config)
+            log = get_logger(name=__name__, category="server", config=logger_config)
     if args.env:
         for env_pair in args.env:
             try:
                 key, value = validate_env_pair(env_pair)
-                logger.info(f"Setting CLI environment variable {key} => {value}")
+                log.info(f"Setting CLI environment variable {key} => {value}")
                 os.environ[key] = value
             except ValueError as e:
-                logger.error(f"Error: {str(e)}")
+                log.error(f"Error: {str(e)}")
                 sys.exit(1)
     config = replace_env_vars(config_contents)
     config = StackRunConfig(**cast_image_name_to_string(config))
@@ -438,16 +438,16 @@ def main(args: argparse.Namespace | None = None):
         impls = loop.run_until_complete(construct_stack(config))

     except InvalidProviderError as e:
-        logger.error(f"Error: {str(e)}")
+        log.error(f"Error: {str(e)}")
         sys.exit(1)

     if config.server.auth:
-        logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
+        log.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
         app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
     else:
         if config.server.quota:
             quota = config.server.quota
-            logger.warning(
+            log.warning(
                 "Configured authenticated_max_requests (%d) but no auth is enabled; "
                 "falling back to anonymous_max_requests (%d) for all the requests",
                 quota.authenticated_max_requests,
@@ -455,7 +455,7 @@ def main(args: argparse.Namespace | None = None):
             )

     if config.server.quota:
-        logger.info("Enabling quota middleware for authenticated and anonymous clients")
+        log.info("Enabling quota middleware for authenticated and anonymous clients")

         quota = config.server.quota
         anonymous_max_requests = quota.anonymous_max_requests
@@ -516,7 +516,7 @@ def main(args: argparse.Namespace | None = None):
         if not available_methods:
             raise ValueError(f"No methods found for {route.name} on {impl}")
         method = available_methods[0]
-        logger.debug(f"{method} {route.path}")
+        log.debug(f"{method} {route.path}")

         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
@@ -528,7 +528,7 @@ def main(args: argparse.Namespace | None = None):
                 )
             )

-    logger.debug(f"serving APIs: {apis_to_serve}")
+    log.debug(f"serving APIs: {apis_to_serve}")

     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
@@ -553,21 +553,21 @@ def main(args: argparse.Namespace | None = None):
         if config.server.tls_cafile:
             ssl_config["ssl_ca_certs"] = config.server.tls_cafile
             ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
-            logger.info(
+            log.info(
                 f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
             )
         else:
-            logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
+            log.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")

     listen_host = config.server.host or ["::", "0.0.0.0"]
-    logger.info(f"Listening on {listen_host}:{port}")
+    log.info(f"Listening on {listen_host}:{port}")

     uvicorn_config = {
         "app": app,
         "host": listen_host,
         "port": port,
         "lifespan": "on",
-        "log_level": logger.getEffectiveLevel(),
+        "log_level": log.getEffectiveLevel(),
         "log_config": logger_config,
     }
     if ssl_config:
@@ -586,19 +586,19 @@ def main(args: argparse.Namespace | None = None):
     try:
         loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
     except (KeyboardInterrupt, SystemExit):
-        logger.info("Received interrupt signal, shutting down gracefully...")
+        log.info("Received interrupt signal, shutting down gracefully...")
     finally:
         if not loop.is_closed():
-            logger.debug("Closing event loop")
+            log.debug("Closing event loop")
             loop.close()


 def _log_run_config(run_config: StackRunConfig):
     """Logs the run config with redacted fields and disabled providers removed."""
-    logger.info("Run configuration:")
+    log.info("Run configuration:")
     safe_config = redact_sensitive_fields(run_config.model_dump(mode="json"))
     clean_config = remove_disabled_providers(safe_config)
-    logger.info(yaml.dump(clean_config, indent=2))
+    log.info(yaml.dump(clean_config, indent=2))


 def extract_path_params(route: str) -> list[str]:

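One nuance visible in the hunks above: this module keeps a direct stdlib import, tagged # allow-direct-logging, only because it needs the DEBUG level constant for an isEnabledFor check. A minimal sketch of that branch, assuming the same logger setup; handle_error is a hypothetical helper, not code from the commit:

import logging  # allow-direct-logging: only the DEBUG constant is needed

from llama_stack.log import get_logger

log = get_logger(name=__name__, category="server")


def handle_error(route: str, method: str, exc: Exception) -> None:
    # Emit a full traceback only when debug logging is enabled,
    # otherwise log a single-line error, mirroring the diff above.
    if log.isEnabledFor(logging.DEBUG):
        log.exception(f"Error executing endpoint {route=} {method=}")
    else:
        log.error(f"Error executing endpoint {route=} {method=}: {str(exc)}")
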
@@ -45,7 +45,7 @@ from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 class LlamaStack(
@@ -105,11 +105,11 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):

     method = getattr(impls[api], register_method)
     for obj in objects:
-        logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
+        log.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")

         # Do not register models on disabled providers
         if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"):
-            logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
+            log.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
             continue

         # we want to maintain the type information in arguments to method.
@@ -123,7 +123,7 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
         objects_to_process = response.data if hasattr(response, "data") else response

         for obj in objects_to_process:
-            logger.debug(
+            log.debug(
                 f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
             )

@@ -160,7 +160,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
                 try:
                     resolved_provider_id = replace_env_vars(v["provider_id"], f"{path}[{i}].provider_id")
                     if resolved_provider_id == "__disabled__":
-                        logger.debug(
+                        log.debug(
                             f"Skipping config env variable expansion for disabled provider: {v.get('provider_id', '')}"
                         )
                         # Create a copy with resolved provider_id but original config
@@ -315,7 +315,7 @@ async def construct_stack(
         TEST_RECORDING_CONTEXT = setup_inference_recording()
         if TEST_RECORDING_CONTEXT:
             TEST_RECORDING_CONTEXT.__enter__()
-            logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
+            log.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")

     dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
     policy = run_config.server.auth.access_policy if run_config.server.auth else []
@@ -337,12 +337,12 @@ async def construct_stack(
             import traceback

             if task.cancelled():
-                logger.error("Model refresh task cancelled")
+                log.error("Model refresh task cancelled")
             elif task.exception():
-                logger.error(f"Model refresh task failed: {task.exception()}")
+                log.error(f"Model refresh task failed: {task.exception()}")
                 traceback.print_exception(task.exception())
             else:
-                logger.debug("Model refresh task completed")
+                log.debug("Model refresh task completed")

         REGISTRY_REFRESH_TASK.add_done_callback(cb)
     return impls
@@ -351,23 +351,23 @@ async def construct_stack(
 async def shutdown_stack(impls: dict[Api, Any]):
     for impl in impls.values():
         impl_name = impl.__class__.__name__
-        logger.info(f"Shutting down {impl_name}")
+        log.info(f"Shutting down {impl_name}")
         try:
             if hasattr(impl, "shutdown"):
                 await asyncio.wait_for(impl.shutdown(), timeout=5)
             else:
-                logger.warning(f"No shutdown method for {impl_name}")
+                log.warning(f"No shutdown method for {impl_name}")
         except TimeoutError:
-            logger.exception(f"Shutdown timeout for {impl_name}")
+            log.exception(f"Shutdown timeout for {impl_name}")
         except (Exception, asyncio.CancelledError) as e:
-            logger.exception(f"Failed to shutdown {impl_name}: {e}")
+            log.exception(f"Failed to shutdown {impl_name}: {e}")

     global TEST_RECORDING_CONTEXT
     if TEST_RECORDING_CONTEXT:
         try:
             TEST_RECORDING_CONTEXT.__exit__(None, None, None)
         except Exception as e:
-            logger.error(f"Error during inference recording cleanup: {e}")
+            log.error(f"Error during inference recording cleanup: {e}")

     global REGISTRY_REFRESH_TASK
     if REGISTRY_REFRESH_TASK:
@@ -375,14 +375,14 @@ async def shutdown_stack(impls: dict[Api, Any]):


 async def refresh_registry_once(impls: dict[Api, Any]):
-    logger.debug("refreshing registry")
+    log.debug("refreshing registry")
     routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
     for routing_table in routing_tables:
         await routing_table.refresh()


 async def refresh_registry_task(impls: dict[Api, Any]):
-    logger.info("starting registry refresh task")
+    log.info("starting registry refresh task")
     while True:
         await refresh_registry_once(impls)

@@ -10,7 +10,7 @@ from pathlib import Path
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger

-logger = get_logger(name=__name__, category="config_resolution")
+log = get_logger(name=__name__, category="config_resolution")


DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"
@@ -42,25 +42,25 @@ def resolve_config_or_distro(
    # Strategy 1: Try as file path first
    config_path = Path(config_or_distro)
    if config_path.exists() and config_path.is_file():
-        logger.info(f"Using file path: {config_path}")
+        log.info(f"Using file path: {config_path}")
        return config_path.resolve()

    # Strategy 2: Try as distribution name (if no .yaml extension)
    if not config_or_distro.endswith(".yaml"):
        distro_config = _get_distro_config_path(config_or_distro, mode)
        if distro_config.exists():
-            logger.info(f"Using distribution: {distro_config}")
+            log.info(f"Using distribution: {distro_config}")
            return distro_config

    # Strategy 3: Try as built distribution name
    distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
    if distrib_config.exists():
-        logger.info(f"Using built distribution: {distrib_config}")
+        log.info(f"Using built distribution: {distrib_config}")
        return distrib_config

    distrib_config = DISTRIBS_BASE_DIR / f"{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
    if distrib_config.exists():
-        logger.info(f"Using built distribution: {distrib_config}")
+        log.info(f"Using built distribution: {distrib_config}")
        return distrib_config

    # Strategy 4: Failed - provide helpful error

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-import logging
+import importlib
import os
import signal
import subprocess
@@ -12,9 +12,9 @@ import sys

from termcolor import cprint

-log = logging.getLogger(__name__)
+from llama_stack.log import get_logger

-import importlib
+log = get_logger(name=__name__, category="core")


def formulate_run_args(image_type: str, image_name: str) -> list:

@@ -6,7 +6,6 @@

import inspect
import json
-import logging
from enum import Enum
from typing import Annotated, Any, Literal, Union, get_args, get_origin

@@ -14,7 +13,9 @@ from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType

-log = logging.getLogger(__name__)
+from llama_stack.log import get_logger
+
+log = get_logger(name=__name__, category="core")


def is_list_of_primitives(field_type):

@@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-import logging
+import logging  # allow-direct-logging
import os
import re
import sys
-from logging.config import dictConfig
+from logging.config import dictConfig  # allow-direct-logging

from rich.console import Console
from rich.errors import MarkupError

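The hunks in this commit all converge on the same module-level pattern. As a minimal sketch of the convention being adopted (the category string and the log message below are illustrative, not taken from any one file in the diff):

    from llama_stack.log import get_logger

    # category mirrors the values used across the diff ("core", "agents", "inference", ...)
    log = get_logger(name=__name__, category="core")
    log.info("component initialized")  # illustrative message; real call sites log their own context

Files that genuinely need the standard library logger keep it, but mark the import with a trailing "# allow-direct-logging" comment, as the hunk above does.
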
@@ -13,7 +13,7 @@

# Copyright (c) Meta Platforms, Inc. and its affiliates.
import math
-from logging import getLogger
+from logging import getLogger  # allow-direct-logging

import torch
import torch.nn.functional as F

@@ -13,7 +13,7 @@

import math
from collections import defaultdict
-from logging import getLogger
+from logging import getLogger  # allow-direct-logging
from typing import Any

import torch

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-import logging
import math
from collections.abc import Callable
from functools import partial
@@ -22,6 +21,8 @@ from PIL import Image as PIL_Image
from torch import Tensor, nn
from torch.distributed import _functional_collectives as funcol

+from llama_stack.log import get_logger
+
from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis
from .encoder_utils import (
    build_encoder_attention_mask,
@@ -34,9 +35,10 @@ from .encoder_utils import (
from .image_transform import VariableSizeImageTransform
from .utils import get_negative_inf_value, to_2tuple

-logger = logging.getLogger(__name__)
MP_SCALE = 8

+log = get_logger(name=__name__, category="core")
+

def reduce_from_tensor_model_parallel_region(input_):
    """All-reduce the input tensor across model parallel group."""
@@ -415,7 +417,7 @@ class VisionEncoder(nn.Module):
            )
            state_dict[prefix + "gated_positional_embedding"] = global_pos_embed
            state_dict[prefix + "gated_positional_embedding_gate"] = torch.zeros(1, dtype=global_pos_embed.dtype)
-            logger.info(f"Initialized global positional embedding with size {global_pos_embed.size()}")
+            log.info(f"Initialized global positional embedding with size {global_pos_embed.size()}")
        else:
            global_pos_embed = resize_global_position_embedding(
                state_dict[prefix + "gated_positional_embedding"],
@@ -423,7 +425,7 @@ class VisionEncoder(nn.Module):
                self.max_num_tiles,
                self.max_num_tiles,
            )
-            logger.info(
+            log.info(
                f"Resized global positional embedding from {state_dict[prefix + 'gated_positional_embedding'].size()} to {global_pos_embed.size()}"
            )
            state_dict[prefix + "gated_positional_embedding"] = global_pos_embed
@@ -771,7 +773,7 @@ class TilePositionEmbedding(nn.Module):
        if embed is not None:
            # reshape the weights to the correct shape
            nt_old, nt_old, _, w = embed.shape
-            logging.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
+            log.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
            embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles)
            # assign the weights to the module
            state_dict[prefix + "embedding"] = embed_new

@@ -5,7 +5,7 @@
# the root directory of this source tree.

from collections.abc import Collection, Iterator, Sequence, Set
-from logging import getLogger
+from logging import getLogger  # allow-direct-logging
from pathlib import Path
from typing import (
    Literal,

@@ -11,7 +11,7 @@ from llama_stack.log import get_logger

from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat

-logger = get_logger(name=__name__, category="inference")
+log = get_logger(name=__name__, category="inference")

BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
@@ -215,7 +215,7 @@ class ToolUtils:
            # FIXME: Enable multiple tool calls
            return function_calls[0]
        else:
-            logger.debug(f"Did not parse tool call from message body: {message_body}")
+            log.debug(f"Did not parse tool call from message body: {message_body}")
            return None

    @staticmethod

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-import logging
import os
from collections.abc import Callable

@@ -13,11 +12,13 @@ from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from torch import Tensor, nn
from torch.nn import functional as F

+from llama_stack.log import get_logger
+
from ...datatypes import QuantizationMode
from ..model import Transformer, TransformerBlock
from ..moe import MoE

-log = logging.getLogger(__name__)
+logger = get_logger(__name__, category="core")


def swiglu_wrapper_no_reduce(
@@ -186,7 +187,7 @@ def logging_callbacks(
        if use_rich_progress:
            console.print(message)
        elif rank == 0:  # Only log from rank 0 for non-rich logging
-            log.info(message)
+            logger.info(message)

    total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block))
    progress = None
@@ -220,6 +221,6 @@ def logging_callbacks(
            if completed is not None:
                progress.update(task_id, completed=completed)
        elif rank == 0 and completed and completed % 10 == 0:
-            log.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed")
+            logger.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed")

    return progress, log_status, update_status

@@ -5,7 +5,7 @@
# the root directory of this source tree.

from collections.abc import Collection, Iterator, Sequence, Set
-from logging import getLogger
+from logging import getLogger  # allow-direct-logging
from pathlib import Path
from typing import (
    Literal,

@@ -6,16 +6,17 @@

# type: ignore
import collections
-import logging

-log = logging.getLogger(__name__)
+from llama_stack.log import get_logger
+
+logger = get_logger(__name__, category="core")

try:
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401

-    log.info("Using efficient FP8 or INT4 operators in FBGEMM.")
+    logger.info("Using efficient FP8 or INT4 operators in FBGEMM.")
except ImportError:
-    log.error("No efficient FP8 or INT4 operators. Please install FBGEMM.")
+    logger.error("No efficient FP8 or INT4 operators. Please install FBGEMM.")
    raise

import torch

@@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag"

-logger = get_logger(name=__name__, category="agents")
+log = get_logger(name=__name__, category="agents")


class ChatAgent(ShieldRunnerMixin):
@@ -612,7 +612,7 @@ class ChatAgent(ShieldRunnerMixin):
                )

                if n_iter >= self.agent_config.max_infer_iters:
-                    logger.info(f"done with MAX iterations ({n_iter}), exiting.")
+                    log.info(f"done with MAX iterations ({n_iter}), exiting.")
                    # NOTE: mark end_of_turn to indicate to client that we are done with the turn
                    # Do not continue the tool call loop after this point
                    message.stop_reason = StopReason.end_of_turn
@@ -620,7 +620,7 @@ class ChatAgent(ShieldRunnerMixin):
                    break

                if stop_reason == StopReason.out_of_tokens:
-                    logger.info("out of token budget, exiting.")
+                    log.info("out of token budget, exiting.")
                    yield message
                    break

@@ -634,7 +634,7 @@ class ChatAgent(ShieldRunnerMixin):
                        message.content = [message.content] + output_attachments
                    yield message
                else:
-                    logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
+                    log.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
                    input_messages = input_messages + [message]
            else:
                input_messages = input_messages + [message]
@@ -889,7 +889,7 @@ class ChatAgent(ShieldRunnerMixin):
        else:
            tool_name_str = tool_name

-        logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
+        log.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
        result = await self.tool_runtime_api.invoke_tool(
            tool_name=tool_name_str,
            kwargs={
@@ -899,7 +899,7 @@ class ChatAgent(ShieldRunnerMixin):
                **self.tool_name_to_args.get(tool_name_str, {}),
            },
        )
-        logger.debug(f"tool call {tool_name_str} completed with result: {result}")
+        log.debug(f"tool call {tool_name_str} completed with result: {result}")
        return result

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-import logging
import uuid
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
@@ -42,6 +41,7 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule
+from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
@@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
from .openai_responses import OpenAIResponsesImpl
from .persistence import AgentInfo

-logger = logging.getLogger()
+log = get_logger(name=__name__, category="agents")


class MetaReferenceAgentsImpl(Agents):
@@ -268,7 +268,7 @@ class MetaReferenceAgentsImpl(Agents):
            # Get the agent info using the key
            agent_info_json = await self.persistence_store.get(agent_key)
            if not agent_info_json:
-                logger.error(f"Could not find agent info for key {agent_key}")
+                log.error(f"Could not find agent info for key {agent_key}")
                continue

            try:
@@ -281,7 +281,7 @@ class MetaReferenceAgentsImpl(Agents):
                    )
                )
            except Exception as e:
-                logger.error(f"Error parsing agent info for {agent_id}: {e}")
+                log.error(f"Error parsing agent info for {agent_id}: {e}")
                continue

        # Convert Agent objects to dictionaries

@@ -75,7 +75,7 @@ from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefiniti
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
from llama_stack.providers.utils.responses.responses_store import ResponsesStore

-logger = get_logger(name=__name__, category="openai_responses")
+log = get_logger(name=__name__, category="openai_responses")

OPENAI_RESPONSES_PREFIX = "openai_responses:"

@@ -544,12 +544,12 @@ class OpenAIResponsesImpl:
                    break

            if function_tool_calls:
-                logger.info("Exiting inference loop since there is a function (client-side) tool call")
+                log.info("Exiting inference loop since there is a function (client-side) tool call")
                break

            n_iter += 1
            if n_iter >= max_infer_iters:
-                logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}")
+                log.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}")
                break

            messages = next_turn_messages
@@ -698,7 +698,7 @@ class OpenAIResponsesImpl:
                )
                return search_response.data
            except Exception as e:
-                logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
+                log.warning(f"Failed to search vector store {vector_store_id}: {e}")
                return []

        # Run all searches in parallel using gather

@@ -5,7 +5,6 @@
# the root directory of this source tree.

import json
-import logging
import uuid
from datetime import UTC, datetime

@@ -15,9 +14,10 @@ from llama_stack.core.access_control.access_control import AccessDeniedError, is
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
+from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="agents")


class AgentSessionInfo(Session):

@@ -5,13 +5,13 @@
# the root directory of this source tree.

import asyncio
-import logging

from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
+from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="agents")


class SafetyException(Exception):  # noqa: N818

@@ -73,11 +73,12 @@ from .config import MetaReferenceInferenceConfig
from .generators import LlamaGenerator
from .model_parallel import LlamaModelParallelGenerator

-log = get_logger(__name__, category="inference")
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)

+logger = get_logger(__name__, category="inference")


def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
    return LlamaGenerator(config, model_id, llama_model)
@@ -144,7 +145,7 @@ class MetaReferenceInferenceImpl(
        return model

    async def load_model(self, model_id, llama_model) -> None:
-        log.info(f"Loading model `{model_id}`")
+        logger.info(f"Loading model `{model_id}`")

        builder_params = [self.config, model_id, llama_model]

@@ -166,7 +167,7 @@ class MetaReferenceInferenceImpl(
        self.model_id = model_id
        self.llama_model = llama_model

-        log.info("Warming up...")
+        logger.info("Warming up...")
        await self.completion(
            model_id=model_id,
            content="Hello, world!",
@@ -177,7 +178,7 @@ class MetaReferenceInferenceImpl(
            messages=[UserMessage(content="Hi how are you?")],
            sampling_params=SamplingParams(max_tokens=20),
        )
-        log.info("Warmed up!")
+        logger.info("Warmed up!")

    def check_model(self, request) -> None:
        if self.model_id is None or self.llama_model is None:

@@ -12,7 +12,6 @@

import copy
import json
-import logging
import multiprocessing
import os
import tempfile
@@ -32,13 +31,14 @@ from fairscale.nn.model_parallel.initialize import (
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

+from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import GenerationResult
from llama_stack.providers.utils.inference.prompt_adapter import (
    ChatCompletionRequestWithRawContent,
    CompletionRequestWithRawContent,
)

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="inference")


class ProcessingMessageName(str, Enum):
@@ -236,7 +236,7 @@ def worker_process_entrypoint(
        except StopIteration:
            break

    log.info("[debug] worker process done")


def launch_dist_group(

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-import logging
from collections.abc import AsyncGenerator

from llama_stack.apis.inference import (
@@ -32,8 +31,6 @@ from llama_stack.providers.utils.inference.openai_compat import (

from .config import SentenceTransformersInferenceConfig

-log = logging.getLogger(__name__)
-

class SentenceTransformersInferenceImpl(
    OpenAIChatCompletionToLlamaStackMixin,

@@ -6,7 +6,6 @@

import gc
import json
-import logging
import multiprocessing
from pathlib import Path
from typing import Any
@@ -28,6 +27,7 @@ from llama_stack.apis.post_training import (
    LoraFinetuningConfig,
    TrainingConfig,
)
+from llama_stack.log import get_logger
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device

from ..config import HuggingFacePostTrainingConfig
@@ -44,7 +44,7 @@ from ..utils import (
    split_dataset,
)

-logger = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="core")


class HFFinetuningSingleDevice:
@@ -69,14 +69,14 @@ class HFFinetuningSingleDevice:
        try:
            messages = json.loads(row["chat_completion_input"])
            if not isinstance(messages, list) or len(messages) != 1:
-                logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
+                log.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
                return None, None
            if "content" not in messages[0]:
-                logger.warning(f"Message missing content: {messages[0]}")
+                log.warning(f"Message missing content: {messages[0]}")
                return None, None
            return messages[0]["content"], row["expected_answer"]
        except json.JSONDecodeError:
-            logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
+            log.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
            return None, None
        return None, None

@@ -86,13 +86,13 @@ class HFFinetuningSingleDevice:
        try:
            dialog = json.loads(row["dialog"])
            if not isinstance(dialog, list) or len(dialog) < 2:
-                logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
+                log.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
                return None, None
            if dialog[0].get("role") != "user":
-                logger.warning(f"First message must be from user: {dialog[0]}")
+                log.warning(f"First message must be from user: {dialog[0]}")
                return None, None
            if not any(msg.get("role") == "assistant" for msg in dialog):
-                logger.warning("Dialog must have at least one assistant message")
+                log.warning("Dialog must have at least one assistant message")
                return None, None

            # Convert to human/gpt format
@@ -100,14 +100,14 @@ class HFFinetuningSingleDevice:
            conversations = []
            for msg in dialog:
                if "role" not in msg or "content" not in msg:
-                    logger.warning(f"Message missing role or content: {msg}")
+                    log.warning(f"Message missing role or content: {msg}")
                    continue
                conversations.append({"from": role_map[msg["role"]], "value": msg["content"]})

            # Format as a single conversation
            return conversations[0]["value"], conversations[1]["value"]
        except json.JSONDecodeError:
-            logger.warning(f"Failed to parse dialog: {row['dialog']}")
+            log.warning(f"Failed to parse dialog: {row['dialog']}")
            return None, None
        return None, None

@@ -198,7 +198,7 @@ class HFFinetuningSingleDevice:
        """
        import asyncio

-        logger.info("Starting training process with async wrapper")
+        log.info("Starting training process with async wrapper")
        asyncio.run(
            self._run_training(
                model=model,
@@ -228,14 +228,14 @@ class HFFinetuningSingleDevice:
            raise ValueError("DataConfig is required for training")

        # Load dataset
-        logger.info(f"Loading dataset: {config.data_config.dataset_id}")
+        log.info(f"Loading dataset: {config.data_config.dataset_id}")
        rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
        if not self.validate_dataset_format(rows):
            raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
-        logger.info(f"Loaded {len(rows)} rows from dataset")
+        log.info(f"Loaded {len(rows)} rows from dataset")

        # Initialize tokenizer
-        logger.info(f"Initializing tokenizer for model: {model}")
+        log.info(f"Initializing tokenizer for model: {model}")
        try:
            tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)

@@ -257,16 +257,16 @@ class HFFinetuningSingleDevice:
            # This ensures consistent sequence lengths across the training process
            tokenizer.model_max_length = provider_config.max_seq_length

-            logger.info("Tokenizer initialized successfully")
+            log.info("Tokenizer initialized successfully")
        except Exception as e:
            raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e

        # Create and preprocess dataset
-        logger.info("Creating and preprocessing dataset")
+        log.info("Creating and preprocessing dataset")
        try:
            ds = self._create_dataset(rows, config, provider_config)
            ds = self._preprocess_dataset(ds, tokenizer, provider_config)
-            logger.info(f"Dataset created with {len(ds)} examples")
+            log.info(f"Dataset created with {len(ds)} examples")
        except Exception as e:
            raise ValueError(f"Failed to create dataset: {str(e)}") from e

@@ -293,11 +293,11 @@ class HFFinetuningSingleDevice:
        Returns:
            Configured SFTConfig object
        """
-        logger.info("Configuring training arguments")
+        log.info("Configuring training arguments")
        lr = 2e-5
        if config.optimizer_config:
            lr = config.optimizer_config.lr
-            logger.info(f"Using custom learning rate: {lr}")
+            log.info(f"Using custom learning rate: {lr}")

        # Validate data config
        if not config.data_config:
@@ -350,17 +350,17 @@ class HFFinetuningSingleDevice:
            peft_config: Optional LoRA configuration
            output_dir_path: Path to save the model
        """
-        logger.info("Saving final model")
+        log.info("Saving final model")
        model_obj.config.use_cache = True

        if peft_config:
-            logger.info("Merging LoRA weights with base model")
+            log.info("Merging LoRA weights with base model")
            model_obj = trainer.model.merge_and_unload()
        else:
            model_obj = trainer.model

        save_path = output_dir_path / "merged_model"
-        logger.info(f"Saving model to {save_path}")
+        log.info(f"Saving model to {save_path}")
        model_obj.save_pretrained(save_path)

    async def _run_training(
@@ -380,13 +380,13 @@ class HFFinetuningSingleDevice:
        setup_signal_handlers()

        # Convert config dicts back to objects
-        logger.info("Initializing configuration objects")
+        log.info("Initializing configuration objects")
        provider_config_obj = HuggingFacePostTrainingConfig(**provider_config)
        config_obj = TrainingConfig(**config)

        # Initialize and validate device
        device = setup_torch_device(provider_config_obj.device)
-        logger.info(f"Using device '{device}'")
+        log.info(f"Using device '{device}'")

        # Load dataset and tokenizer
        train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
@@ -409,7 +409,7 @@ class HFFinetuningSingleDevice:
        model_obj = load_model(model, device, provider_config_obj)

        # Initialize trainer
-        logger.info("Initializing SFTTrainer")
+        log.info("Initializing SFTTrainer")
        trainer = SFTTrainer(
            model=model_obj,
            train_dataset=train_dataset,
@@ -420,9 +420,9 @@ class HFFinetuningSingleDevice:

        try:
            # Train
-            logger.info("Starting training")
+            log.info("Starting training")
            trainer.train()
-            logger.info("Training completed successfully")
+            log.info("Training completed successfully")

            # Save final model if output directory is provided
            if output_dir_path:
@@ -430,12 +430,12 @@ class HFFinetuningSingleDevice:

        finally:
            # Clean up resources
-            logger.info("Cleaning up resources")
+            log.info("Cleaning up resources")
            if hasattr(trainer, "model"):
                evacuate_model_from_device(trainer.model, device.type)
            del trainer
            gc.collect()
-            logger.info("Cleanup completed")
+            log.info("Cleanup completed")

    async def train(
        self,
@@ -449,7 +449,7 @@ class HFFinetuningSingleDevice:
        """Train a model using HuggingFace's SFTTrainer"""
        # Initialize and validate device
        device = setup_torch_device(provider_config.device)
-        logger.info(f"Using device '{device}'")
+        log.info(f"Using device '{device}'")

        output_dir_path = None
        if output_dir:
@@ -479,7 +479,7 @@ class HFFinetuningSingleDevice:
            raise ValueError("DataConfig is required for training")

        # Train in a separate process
-        logger.info("Starting training in separate process")
+        log.info("Starting training in separate process")
        try:
            # Setup multiprocessing for device
            if device.type in ["cuda", "mps"]:

@@ -5,7 +5,6 @@
# the root directory of this source tree.

import gc
-import logging
import multiprocessing
from pathlib import Path
from typing import Any
@@ -24,6 +23,7 @@ from llama_stack.apis.post_training import (
    DPOAlignmentConfig,
    TrainingConfig,
)
+from llama_stack.log import get_logger
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device

from ..config import HuggingFacePostTrainingConfig
@@ -40,7 +40,7 @@ from ..utils import (
    split_dataset,
)

-logger = logging.getLogger(__name__)
+logger = get_logger(__name__, category="core")


class HFDPOAlignmentSingleDevice:

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-import logging
import os
import signal
import sys
@@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM

from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.post_training import Checkpoint, TrainingConfig
+from llama_stack.log import get_logger

from .config import HuggingFacePostTrainingConfig

-logger = logging.getLogger(__name__)
+logger = get_logger(__name__, category="core")


def setup_environment():

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-import logging
import os
import time
from datetime import UTC, datetime
@@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training
from torchtune import utils as torchtune_utils
from torchtune.data import padded_collate_sft
+from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
    get_adapter_params,
@@ -45,6 +45,7 @@ from llama_stack.apis.post_training import (
)
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.core.utils.model_utils import model_local_dir
+from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from llama_stack.providers.inline.post_training.torchtune.common import utils
@@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
)
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="core")

-from torchtune.models.llama3._tokenizer import Llama3Tokenizer
-

class LoraFinetuningSingleDevice:

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-import logging
from typing import Any

from llama_stack.apis.inference import Message
@@ -15,13 +14,14 @@ from llama_stack.apis.safety import (
    ViolationLevel,
)
from llama_stack.apis.shields import Shield
+from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)

from .config import CodeScannerConfig

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="safety")

ALLOWED_CODE_SCANNER_MODEL_IDS = [
    "CodeScanner",

@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-import logging
from typing import Any

import torch
@@ -19,6 +18,7 @@ from llama_stack.apis.safety import (
)
from llama_stack.apis.shields import Shield
from llama_stack.core.utils.model_utils import model_local_dir
+from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
@@ -26,10 +26,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

from .config import PromptGuardConfig, PromptGuardType

-log = logging.getLogger(__name__)
-
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"

+log = get_logger(name=__name__, category="safety")
+

class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
    def __init__(self, config: PromptGuardConfig, _deps) -> None:

@@ -7,7 +7,6 @@
import collections
import functools
import json
-import logging
import random
import re
import string
@@ -20,7 +19,9 @@ import nltk
from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai
from pythainlp.tokenize import word_tokenize as word_tokenize_thai

-logger = logging.getLogger()
+from llama_stack.log import get_logger
+
+log = get_logger(name=__name__, category="core")

WORD_LIST = [
    "western",
@@ -1726,7 +1727,7 @@ def get_langid(text: str, lid_path: str | None = None) -> str:
        try:
            line_langs.append(langdetect.detect(line))
        except langdetect.LangDetectException as e:
-            logger.info("Unable to detect language for text %s due to %s", line, e)  # refex: disable=pytotw.037
+            log.info("Unable to detect language for text %s due to %s", line, e)  # refex: disable=pytotw.037

    if len(line_langs) == 0:
        return "en"
@@ -1885,7 +1886,7 @@ class ResponseLanguageChecker(Instruction):
            return langdetect.detect(value) == self._language
        except langdetect.LangDetectException as e:
            # Count as instruction is followed.
-            logger.info("Unable to detect language for text %s due to %s", value, e)  # refex: disable=pytotw.037
+            log.info("Unable to detect language for text %s due to %s", value, e)  # refex: disable=pytotw.037
            return True


@@ -3110,7 +3111,7 @@ class CapitalLettersEnglishChecker(Instruction):
            return value.isupper() and langdetect.detect(value) == "en"
        except langdetect.LangDetectException as e:
            # Count as instruction is followed.
-            logger.info("Unable to detect language for text %s due to %s", value, e)  # refex: disable=pytotw.037
+            log.info("Unable to detect language for text %s due to %s", value, e)  # refex: disable=pytotw.037
            return True


@@ -3139,7 +3140,7 @@ class LowercaseLettersEnglishChecker(Instruction):
            return value.islower() and langdetect.detect(value) == "en"
        except langdetect.LangDetectException as e:
            # Count as instruction is followed.
-            logger.info("Unable to detect language for text %s due to %s", value, e)  # refex: disable=pytotw.037
+            log.info("Unable to detect language for text %s due to %s", value, e)  # refex: disable=pytotw.037
            return True


@ -5,7 +5,6 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
@ -32,6 +31,7 @@ from llama_stack.apis.tools import (
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
@ -42,7 +42,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
from .config import RagToolRuntimeConfig
|
from .config import RagToolRuntimeConfig
|
||||||
from .context_retriever import generate_rag_query
|
from .context_retriever import generate_rag_query
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = get_logger(name=__name__, category="tools")
|
||||||
|
|
||||||
|
|
||||||
def make_random_string(length: int = 8):
|
def make_random_string(length: int = 8):
|
||||||
|
|
|
@ -8,7 +8,6 @@ import asyncio
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import faiss
|
import faiss
|
||||||
|
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
|
||||||
QueryChunksResponse,
|
QueryChunksResponse,
|
||||||
VectorIO,
|
VectorIO,
|
||||||
)
|
)
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import (
|
from llama_stack.providers.datatypes import (
|
||||||
HealthResponse,
|
HealthResponse,
|
||||||
HealthStatus,
|
HealthStatus,
|
||||||
|
@ -39,7 +39,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
|
||||||
from .config import FaissVectorIOConfig
|
from .config import FaissVectorIOConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
log = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
VERSION = "v3"
|
VERSION = "v3"
|
||||||
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
|
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
|
||||||
|
@ -83,7 +83,7 @@ class FaissIndex(EmbeddingIndex):
|
||||||
self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False))
|
self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False))
|
||||||
self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()]
|
self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(e, exc_info=True)
|
log.debug(e, exc_info=True)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Error deserializing Faiss index from storage. If you recently upgraded your Llama Stack, Faiss, "
|
"Error deserializing Faiss index from storage. If you recently upgraded your Llama Stack, Faiss, "
|
||||||
"or NumPy versions, you may need to delete the index and re-create it again or downgrade versions.\n"
|
"or NumPy versions, you may need to delete the index and re-create it again or downgrade versions.\n"
|
||||||
|
@ -262,7 +262,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
||||||
assert self.kvstore is not None
|
assert self.kvstore is not None
|
||||||
|
|
||||||
if vector_db_id not in self.cache:
|
if vector_db_id not in self.cache:
|
||||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
log.warning(f"Vector DB {vector_db_id} not found")
|
||||||
return
|
return
|
||||||
|
|
||||||
await self.cache[vector_db_id].index.delete()
|
await self.cache[vector_db_id].index.delete()
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import struct
|
import struct
|
||||||
|
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
|
||||||
QueryChunksResponse,
|
QueryChunksResponse,
|
||||||
VectorIO,
|
VectorIO,
|
||||||
)
|
)
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||||
|
@ -35,7 +35,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
VectorDBWithIndex,
|
VectorDBWithIndex,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
log = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
# Specifying search mode is dependent on the VectorIO provider.
|
# Specifying search mode is dependent on the VectorIO provider.
|
||||||
VECTOR_SEARCH = "vector"
|
VECTOR_SEARCH = "vector"
|
||||||
|
@ -257,7 +257,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
|
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
connection.rollback()
|
connection.rollback()
|
||||||
logger.error(f"Error inserting into {self.vector_table}: {e}")
|
log.error(f"Error inserting into {self.vector_table}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
@ -306,7 +306,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
try:
|
try:
|
||||||
chunk = Chunk.model_validate_json(chunk_json)
|
chunk = Chunk.model_validate_json(chunk_json)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
log.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
||||||
continue
|
continue
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
scores.append(score)
|
scores.append(score)
|
||||||
|
@ -352,7 +352,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
try:
|
try:
|
||||||
chunk = Chunk.model_validate_json(chunk_json)
|
chunk = Chunk.model_validate_json(chunk_json)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
log.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
||||||
continue
|
continue
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
scores.append(score)
|
scores.append(score)
|
||||||
|
@ -447,7 +447,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
connection.commit()
|
connection.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
connection.rollback()
|
connection.rollback()
|
||||||
logger.error(f"Error deleting chunk {chunk_id}: {e}")
|
log.error(f"Error deleting chunk {chunk_id}: {e}")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
cur.close()
|
cur.close()
|
||||||
|
@ -530,7 +530,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
||||||
|
|
||||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||||
if vector_db_id not in self.cache:
|
if vector_db_id not in self.cache:
|
||||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
log.warning(f"Vector DB {vector_db_id} not found")
|
||||||
return
|
return
|
||||||
await self.cache[vector_db_id].index.delete()
|
await self.cache[vector_db_id].index.delete()
|
||||||
del self.cache[vector_db_id]
|
del self.cache[vector_db_id]
|
||||||
|
|
|
@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
from .config import FireworksImplConfig
|
from .config import FireworksImplConfig
|
||||||
from .models import MODEL_ENTRIES
|
from .models import MODEL_ENTRIES
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="inference")
|
log = get_logger(name=__name__, category="inference")
|
||||||
|
|
||||||
|
|
||||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||||
|
@ -256,7 +256,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
||||||
"stream": bool(request.stream),
|
"stream": bool(request.stream),
|
||||||
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
|
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
|
||||||
}
|
}
|
||||||
logger.debug(f"params to fireworks: {params}")
|
log.debug(f"params to fireworks: {params}")
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,6 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
import logging
|
|
||||||
|
|
||||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||||
|
@ -11,8 +10,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||||
|
|
||||||
from .models import MODEL_ENTRIES
|
from .models import MODEL_ENTRIES
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -4,7 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import logging
|
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
|
||||||
|
@ -33,6 +32,7 @@ from llama_stack.apis.inference import (
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
)
|
)
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
|
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
|
||||||
from llama_stack.providers.utils.inference.model_registry import (
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
ModelRegistryHelper,
|
ModelRegistryHelper,
|
||||||
|
@ -54,7 +54,7 @@ from .openai_utils import (
|
||||||
)
|
)
|
||||||
from .utils import _is_nvidia_hosted
|
from .utils import _is_nvidia_hosted
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
log = get_logger(name=__name__, category="inference")
|
||||||
|
|
||||||
|
|
||||||
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
||||||
|
@ -75,7 +75,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
||||||
# TODO(mf): filter by available models
|
# TODO(mf): filter by available models
|
||||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||||
|
|
||||||
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
|
log.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
|
||||||
|
|
||||||
if _is_nvidia_hosted(config):
|
if _is_nvidia_hosted(config):
|
||||||
if not config.api_key:
|
if not config.api_key:
|
||||||
|
|
|
@ -4,13 +4,14 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from . import NVIDIAConfig
|
from . import NVIDIAConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
log = get_logger(name=__name__, category="inference")
|
||||||
|
|
||||||
|
|
||||||
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
|
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
|
||||||
|
@ -44,7 +45,7 @@ async def check_health(config: NVIDIAConfig) -> None:
|
||||||
RuntimeError: If the server is not running or ready
|
RuntimeError: If the server is not running or ready
|
||||||
"""
|
"""
|
||||||
if not _is_nvidia_hosted(config):
|
if not _is_nvidia_hosted(config):
|
||||||
logger.info("Checking NVIDIA NIM health...")
|
log.info("Checking NVIDIA NIM health...")
|
||||||
try:
|
try:
|
||||||
is_live, is_ready = await _get_health(config.url)
|
is_live, is_ready = await _get_health(config.url)
|
||||||
if not is_live:
|
if not is_live:
|
||||||
|
|
|
@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
|
||||||
from .models import MODEL_ENTRIES
|
from .models import MODEL_ENTRIES
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="inference")
|
log = get_logger(name=__name__, category="inference")
|
||||||
|
|
||||||
|
|
||||||
class OllamaInferenceAdapter(
|
class OllamaInferenceAdapter(
|
||||||
|
@ -117,10 +117,10 @@ class OllamaInferenceAdapter(
|
||||||
return self._openai_client
|
return self._openai_client
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
|
log.info(f"checking connectivity to Ollama at `{self.config.url}`...")
|
||||||
health_response = await self.health()
|
health_response = await self.health()
|
||||||
if health_response["status"] == HealthStatus.ERROR:
|
if health_response["status"] == HealthStatus.ERROR:
|
||||||
logger.warning(
|
log.warning(
|
||||||
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
|
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -339,7 +339,7 @@ class OllamaInferenceAdapter(
|
||||||
"options": sampling_options,
|
"options": sampling_options,
|
||||||
"stream": request.stream,
|
"stream": request.stream,
|
||||||
}
|
}
|
||||||
logger.debug(f"params to ollama: {params}")
|
log.debug(f"params to ollama: {params}")
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
@ -437,7 +437,7 @@ class OllamaInferenceAdapter(
|
||||||
if provider_resource_id not in available_models:
|
if provider_resource_id not in available_models:
|
||||||
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
|
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
|
||||||
if provider_resource_id in available_models_latest:
|
if provider_resource_id in available_models_latest:
|
||||||
logger.warning(
|
log.warning(
|
||||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
|
@ -4,7 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||||
|
@ -12,8 +11,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||||
from .config import OpenAIConfig
|
from .config import OpenAIConfig
|
||||||
from .models import MODEL_ENTRIES
|
from .models import MODEL_ENTRIES
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# This OpenAI adapter implements Inference methods using two mixins -
|
# This OpenAI adapter implements Inference methods using two mixins -
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||||
|
@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.models.llama.sku_list import all_registered_models
|
from llama_stack.models.llama.sku_list import all_registered_models
|
||||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.model_registry import (
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
|
@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
|
||||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
logger = get_logger(__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
def build_hf_repo_model_entries():
|
def build_hf_repo_model_entries():
|
||||||
|
@ -307,7 +307,7 @@ class TGIAdapter(_HfAdapter):
|
||||||
async def initialize(self, config: TGIImplConfig) -> None:
|
async def initialize(self, config: TGIImplConfig) -> None:
|
||||||
if not config.url:
|
if not config.url:
|
||||||
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
|
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
|
||||||
log.info(f"Initializing TGI client with url={config.url}")
|
logger.info(f"Initializing TGI client with url={config.url}")
|
||||||
self.client = AsyncInferenceClient(
|
self.client = AsyncInferenceClient(
|
||||||
model=config.url,
|
model=config.url,
|
||||||
)
|
)
|
||||||
|
|
|
@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
from .config import TogetherImplConfig
|
from .config import TogetherImplConfig
|
||||||
from .models import MODEL_ENTRIES
|
from .models import MODEL_ENTRIES
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="inference")
|
log = get_logger(name=__name__, category="inference")
|
||||||
|
|
||||||
|
|
||||||
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||||
|
@ -232,7 +232,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
||||||
"stream": request.stream,
|
"stream": request.stream,
|
||||||
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
|
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
|
||||||
}
|
}
|
||||||
logger.debug(f"params to together: {params}")
|
log.debug(f"params to together: {params}")
|
||||||
return params
|
return params
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
|
|
|
@ -4,7 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import logging
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
@ -15,8 +14,6 @@ from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefa
|
||||||
|
|
||||||
from .config import NvidiaPostTrainingConfig
|
from .config import NvidiaPostTrainingConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:
|
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:
|
||||||
keys = set(config_dict.__annotations__.keys()) if isinstance(config_dict, BaseModel) else config_dict.keys()
|
keys = set(config_dict.__annotations__.keys()) if isinstance(config_dict, BaseModel) else config_dict.keys()
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.inference import Message
|
from llama_stack.apis.inference import Message
|
||||||
|
@ -16,12 +15,13 @@ from llama_stack.apis.safety import (
|
||||||
ViolationLevel,
|
ViolationLevel,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||||
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||||
|
|
||||||
from .config import BedrockSafetyConfig
|
from .config import BedrockSafetyConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
log = get_logger(name=__name__, category="safety")
|
||||||
|
|
||||||
|
|
||||||
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
|
@ -76,13 +76,13 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
shield_params = shield.params
|
shield_params = shield.params
|
||||||
logger.debug(f"run_shield::{shield_params}::messages={messages}")
|
log.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||||
|
|
||||||
# - convert the messages into format Bedrock expects
|
# - convert the messages into format Bedrock expects
|
||||||
content_messages = []
|
content_messages = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
content_messages.append({"text": {"text": message.content}})
|
content_messages.append({"text": {"text": message.content}})
|
||||||
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
log.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
||||||
|
|
||||||
response = self.bedrock_runtime_client.apply_guardrail(
|
response = self.bedrock_runtime_client.apply_guardrail(
|
||||||
guardrailIdentifier=shield.provider_resource_id,
|
guardrailIdentifier=shield.provider_resource_id,
|
||||||
|
|
|
@ -4,7 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
@ -17,8 +16,6 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
|
||||||
|
|
||||||
from .config import NVIDIASafetyConfig
|
from .config import NVIDIASafetyConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
def __init__(self, config: NVIDIASafetyConfig) -> None:
|
def __init__(self, config: NVIDIASafetyConfig) -> None:
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
@ -20,12 +19,13 @@ from llama_stack.apis.safety import (
|
||||||
)
|
)
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
|
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
|
||||||
|
|
||||||
from .config import SambaNovaSafetyConfig
|
from .config import SambaNovaSafetyConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
log = get_logger(name=__name__, category="safety")
|
||||||
|
|
||||||
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||||
|
|
||||||
|
@ -66,7 +66,7 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
|
||||||
"guard" not in shield.provider_resource_id.lower()
|
"guard" not in shield.provider_resource_id.lower()
|
||||||
or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models
|
or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models
|
||||||
):
|
):
|
||||||
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
|
log.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
|
||||||
|
|
||||||
async def unregister_shield(self, identifier: str) -> None:
|
async def unregister_shield(self, identifier: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
@ -79,9 +79,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
|
||||||
raise ValueError(f"Shield {shield_id} not found")
|
raise ValueError(f"Shield {shield_id} not found")
|
||||||
|
|
||||||
shield_params = shield.params
|
shield_params = shield.params
|
||||||
logger.debug(f"run_shield::{shield_params}::messages={messages}")
|
log.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||||
content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
|
content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
|
||||||
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
log.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
||||||
|
|
||||||
response = litellm.completion(
|
response = litellm.completion(
|
||||||
model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()
|
model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
@ -20,6 +19,7 @@ from llama_stack.apis.vector_io import (
|
||||||
QueryChunksResponse,
|
QueryChunksResponse,
|
||||||
VectorIO,
|
VectorIO,
|
||||||
)
|
)
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||||
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
|
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
|
||||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
|
@ -32,8 +32,6 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
|
||||||
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
|
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
|
||||||
|
|
||||||
VERSION = "v3"
|
VERSION = "v3"
|
||||||
|
@ -43,6 +41,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:chroma:{VERSION}::"
|
||||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:chroma:{VERSION}::"
|
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:chroma:{VERSION}::"
|
||||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:chroma:{VERSION}::"
|
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:chroma:{VERSION}::"
|
||||||
|
|
||||||
|
logger = get_logger(__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
# this is a helper to allow us to use async and non-async chroma clients interchangeably
|
# this is a helper to allow us to use async and non-async chroma clients interchangeably
|
||||||
async def maybe_await(result):
|
async def maybe_await(result):
|
||||||
|
@ -92,7 +92,7 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
doc = json.loads(doc)
|
doc = json.loads(doc)
|
||||||
chunk = Chunk(**doc)
|
chunk = Chunk(**doc)
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception(f"Failed to parse document: {doc}")
|
logger.exception(f"Failed to parse document: {doc}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
score = 1.0 / float(dist) if dist != 0 else float("inf")
|
score = 1.0 / float(dist) if dist != 0 else float("inf")
|
||||||
|
@ -137,7 +137,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
||||||
inference_api: Api.inference,
|
inference_api: Api.inference,
|
||||||
files_api: Files | None,
|
files_api: Files | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
|
logger.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
|
||||||
self.config = config
|
self.config = config
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
self.client = None
|
self.client = None
|
||||||
|
@ -150,7 +150,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
||||||
self.vector_db_store = self.kvstore
|
self.vector_db_store = self.kvstore
|
||||||
|
|
||||||
if isinstance(self.config, RemoteChromaVectorIOConfig):
|
if isinstance(self.config, RemoteChromaVectorIOConfig):
|
||||||
log.info(f"Connecting to Chroma server at: {self.config.url}")
|
logger.info(f"Connecting to Chroma server at: {self.config.url}")
|
||||||
url = self.config.url.rstrip("/")
|
url = self.config.url.rstrip("/")
|
||||||
parsed = urlparse(url)
|
parsed = urlparse(url)
|
||||||
|
|
||||||
|
@ -159,7 +159,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
||||||
|
|
||||||
self.client = await chromadb.AsyncHttpClient(host=parsed.hostname, port=parsed.port)
|
self.client = await chromadb.AsyncHttpClient(host=parsed.hostname, port=parsed.port)
|
||||||
else:
|
else:
|
||||||
log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
|
logger.info(f"Connecting to Chroma local db at: {self.config.db_path}")
|
||||||
self.client = chromadb.PersistentClient(path=self.config.db_path)
|
self.client = chromadb.PersistentClient(path=self.config.db_path)
|
||||||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||||
|
|
||||||
|
@ -182,7 +182,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
||||||
|
|
||||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||||
if vector_db_id not in self.cache:
|
if vector_db_id not in self.cache:
|
||||||
log.warning(f"Vector DB {vector_db_id} not found")
|
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||||
return
|
return
|
||||||
|
|
||||||
await self.cache[vector_db_id].index.delete()
|
await self.cache[vector_db_id].index.delete()
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import (
|
||||||
QueryChunksResponse,
|
QueryChunksResponse,
|
||||||
VectorIO,
|
VectorIO,
|
||||||
)
|
)
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||||
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
|
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
|
||||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
|
@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
|
||||||
|
|
||||||
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
log = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
VERSION = "v3"
|
VERSION = "v3"
|
||||||
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
|
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
|
||||||
|
@ -68,7 +68,7 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
)
|
)
|
||||||
|
|
||||||
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||||
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
|
log.info(f"Creating new collection {self.collection_name} with nullable sparse field")
|
||||||
# Create schema for vector search
|
# Create schema for vector search
|
||||||
schema = self.client.create_schema()
|
schema = self.client.create_schema()
|
||||||
schema.add_field(
|
schema.add_field(
|
||||||
|
@ -147,7 +147,7 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
|
log.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||||
|
@ -203,7 +203,7 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
|
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error performing BM25 search: {e}")
|
log.error(f"Error performing BM25 search: {e}")
|
||||||
# Fallback to simple text search
|
# Fallback to simple text search
|
||||||
return await self._fallback_keyword_search(query_string, k, score_threshold)
|
return await self._fallback_keyword_search(query_string, k, score_threshold)
|
||||||
|
|
||||||
|
@ -247,7 +247,7 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"'
|
self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"'
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
|
log.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@ -288,10 +288,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
||||||
)
|
)
|
||||||
self.cache[vector_db.identifier] = index
|
self.cache[vector_db.identifier] = index
|
||||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||||
logger.info(f"Connecting to Milvus server at {self.config.uri}")
|
log.info(f"Connecting to Milvus server at {self.config.uri}")
|
||||||
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
|
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
|
||||||
else:
|
else:
|
||||||
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
|
log.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
|
||||||
uri = os.path.expanduser(self.config.db_path)
|
uri = os.path.expanduser(self.config.db_path)
|
||||||
self.client = MilvusClient(uri=uri)
|
self.client = MilvusClient(uri=uri)
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import psycopg2
|
import psycopg2
|
||||||
|
@ -22,6 +21,7 @@ from llama_stack.apis.vector_io import (
|
||||||
QueryChunksResponse,
|
QueryChunksResponse,
|
||||||
VectorIO,
|
VectorIO,
|
||||||
)
|
)
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||||
|
@ -33,8 +33,6 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
|
||||||
from .config import PGVectorVectorIOConfig
|
from .config import PGVectorVectorIOConfig
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
VERSION = "v3"
|
VERSION = "v3"
|
||||||
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
|
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
|
||||||
VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
|
VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
|
||||||
|
@ -42,6 +40,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"
|
||||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
|
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
|
||||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:pgvector:{VERSION}::"
|
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:pgvector:{VERSION}::"
|
||||||
|
|
||||||
|
logger = get_logger(__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
def check_extension_version(cur):
|
def check_extension_version(cur):
|
||||||
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
|
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
|
||||||
|
@ -187,7 +187,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
||||||
self.metadatadata_collection_name = "openai_vector_stores_metadata"
|
self.metadatadata_collection_name = "openai_vector_stores_metadata"
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
|
logger.info(f"Initializing PGVector memory adapter with config: {self.config}")
|
||||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||||
await self.initialize_openai_vector_stores()
|
await self.initialize_openai_vector_stores()
|
||||||
|
|
||||||
|
@ -203,7 +203,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
||||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||||
version = check_extension_version(cur)
|
version = check_extension_version(cur)
|
||||||
if version:
|
if version:
|
||||||
log.info(f"Vector extension version: {version}")
|
logger.info(f"Vector extension version: {version}")
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Vector extension is not installed.")
|
raise RuntimeError("Vector extension is not installed.")
|
||||||
|
|
||||||
|
@ -216,13 +216,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception("Could not connect to PGVector database server")
|
logger.exception("Could not connect to PGVector database server")
|
||||||
raise RuntimeError("Could not connect to PGVector database server") from e
|
raise RuntimeError("Could not connect to PGVector database server") from e
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
if self.conn is not None:
|
if self.conn is not None:
|
||||||
self.conn.close()
|
self.conn.close()
|
||||||
log.info("Connection to PGVector database server closed")
|
logger.info("Connection to PGVector database server closed")
|
||||||
|
|
||||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||||
# Persist vector DB metadata in the KV store
|
# Persist vector DB metadata in the KV store
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
|
||||||
VectorStoreChunkingStrategy,
|
VectorStoreChunkingStrategy,
|
||||||
VectorStoreFileObject,
|
VectorStoreFileObject,
|
||||||
)
|
)
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||||
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
|
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
|
||||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||||
|
@ -35,13 +35,14 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
|
||||||
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
CHUNK_ID_KEY = "_chunk_id"
|
CHUNK_ID_KEY = "_chunk_id"
|
||||||
|
|
||||||
# KV store prefixes for vector databases
|
# KV store prefixes for vector databases
|
||||||
VERSION = "v3"
|
VERSION = "v3"
|
||||||
VECTOR_DBS_PREFIX = f"vector_dbs:qdrant:{VERSION}::"
|
VECTOR_DBS_PREFIX = f"vector_dbs:qdrant:{VERSION}::"
|
||||||
|
|
||||||
|
logger = get_logger(__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
def convert_id(_id: str) -> str:
|
def convert_id(_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -96,7 +97,7 @@ class QdrantIndex(EmbeddingIndex):
|
||||||
points_selector=models.PointIdsList(points=[convert_id(chunk_id)]),
|
points_selector=models.PointIdsList(points=[convert_id(chunk_id)]),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}")
|
logger.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||||
|
@ -118,7 +119,7 @@ class QdrantIndex(EmbeddingIndex):
|
||||||
try:
|
try:
|
||||||
chunk = Chunk(**point.payload["chunk_content"])
|
chunk = Chunk(**point.payload["chunk_content"])
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception("Failed to parse chunk")
|
logger.exception("Failed to parse chunk")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
|
|
|
@ -4,7 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import weaviate
|
import weaviate
|
||||||
|
@ -19,6 +18,7 @@ from llama_stack.apis.files.files import Files
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||||
|
@ -33,8 +33,6 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
|
||||||
|
|
||||||
from .config import WeaviateVectorIOConfig
|
from .config import WeaviateVectorIOConfig
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
VERSION = "v3"
|
VERSION = "v3"
|
||||||
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
|
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
|
||||||
VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
|
VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
|
||||||
|
@ -42,6 +40,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::"
|
||||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
|
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
|
||||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:weaviate:{VERSION}::"
|
 OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:weaviate:{VERSION}::"

+logger = get_logger(__name__, category="core")


 class WeaviateIndex(EmbeddingIndex):
     def __init__(
@@ -102,7 +102,7 @@ class WeaviateIndex(EmbeddingIndex):
                 chunk_dict = json.loads(chunk_json)
                 chunk = Chunk(**chunk_dict)
             except Exception:
-                log.exception(f"Failed to parse document: {chunk_json}")
+                logger.exception(f"Failed to parse document: {chunk_json}")
                 continue

             score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf")
@@ -171,7 +171,7 @@ class WeaviateVectorIOAdapter(

     def _get_client(self) -> weaviate.Client:
         if "localhost" in self.config.weaviate_cluster_url:
-            log.info("using Weaviate locally in container")
+            logger.info("using Weaviate locally in container")
             host, port = self.config.weaviate_cluster_url.split(":")
             key = "local_test"
             client = weaviate.connect_to_local(
@@ -179,7 +179,7 @@ class WeaviateVectorIOAdapter(
                 port=port,
             )
         else:
-            log.info("Using Weaviate remote cluster with URL")
+            logger.info("Using Weaviate remote cluster with URL")
             key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
         if key in self.client_cache:
             return self.client_cache[key]
@@ -197,7 +197,7 @@ class WeaviateVectorIOAdapter(
            self.kvstore = await kvstore_impl(self.config.kvstore)
        else:
            self.kvstore = None
-           log.info("No kvstore configured, registry will not persist across restarts")
+           logger.info("No kvstore configured, registry will not persist across restarts")

        # Load existing vector DB definitions
        if self.kvstore is not None:
@@ -254,7 +254,7 @@ class WeaviateVectorIOAdapter(
        client = self._get_client()
        sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
        if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False:
-           log.warning(f"Vector DB {sanitized_collection_name} not found")
+           logger.warning(f"Vector DB {sanitized_collection_name} not found")
            return
        client.collections.delete(sanitized_collection_name)
        await self.cache[sanitized_collection_name].index.delete()
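Read together, the Weaviate hunks above align the adapter's call sites with the module-level logger that the file obtains via get_logger (visible at the top of this fragment). A condensed sketch of the resulting pattern, assembled only from lines that appear in these hunks (the import shown is the one this changeset uses elsewhere):

    from llama_stack.log import get_logger

    logger = get_logger(__name__, category="core")

    # call sites then go through the module-level logger
    logger.info("Using Weaviate remote cluster with URL")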
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import base64
-import logging
 import struct
 from typing import TYPE_CHECKING

@@ -27,7 +26,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
 EMBEDDING_MODELS = {}


-log = logging.getLogger(__name__)
+from llama_stack.log import get_logger
+
+log = get_logger(name=__name__, category="inference")


 class SentenceTransformerEmbeddingMixin:
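The same mechanical conversion recurs throughout the changeset: drop the stdlib logger setup and obtain a named, categorized logger from llama_stack.log. A minimal before/after sketch based on the hunk above (the category string varies per module; "inference" is the one used here):

    # before
    import logging

    log = logging.getLogger(__name__)

    # after
    from llama_stack.log import get_logger

    log = get_logger(name=__name__, category="inference")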
@@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )

-logger = get_logger(name=__name__, category="inference")
+log = get_logger(name=__name__, category="inference")


 class LiteLLMOpenAIMixin(
@@ -157,7 +157,7 @@ class LiteLLMOpenAIMixin(
         params = await self._get_params(request)
         params["model"] = self.get_litellm_model_name(params["model"])

-        logger.debug(f"params to litellm (openai compat): {params}")
+        log.debug(f"params to litellm (openai compat): {params}")
         # see https://docs.litellm.ai/docs/completion/stream#async-completion
         response = await litellm.acompletion(**params)
         if stream:
@@ -460,7 +460,7 @@ class LiteLLMOpenAIMixin(
         :return: True if the model is available dynamically, False otherwise.
         """
         if self.litellm_provider_name not in litellm.models_by_provider:
-            logger.error(f"Provider {self.litellm_provider_name} is not registered in litellm.")
+            log.error(f"Provider {self.litellm_provider_name} is not registered in litellm.")
             return False

         return model in litellm.models_by_provider[self.litellm_provider_name]
@@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference import (
     ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
 )

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 class RemoteInferenceProviderConfig(BaseModel):
@@ -135,7 +135,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         :param model: The model identifier to check.
         :return: True if the model is available dynamically, False otherwise.
         """
-        logger.info(
+        log.info(
             f"check_model_availability is not implemented for {self.__class__.__name__}. Returning False by default."
         )
         return False
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 import base64
 import json
-import logging
 import struct
 import time
 import uuid
@@ -116,6 +115,7 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.inference import (
     OpenAIChoice as OpenAIChatCompletionChoice,
 )
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     StopReason,
@@ -128,7 +128,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     decode_assistant_message,
 )

-logger = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="inference")


 class OpenAICompatCompletionChoiceDelta(BaseModel):
@@ -316,7 +316,7 @@ def process_chat_completion_response(
         if t.tool_name in request_tools:
             new_tool_calls.append(t)
         else:
-            logger.warning(f"Tool {t.tool_name} not found in request tools")
+            log.warning(f"Tool {t.tool_name} not found in request tools")

     if len(new_tool_calls) < len(raw_message.tool_calls):
         raw_message.tool_calls = new_tool_calls
@@ -477,7 +477,7 @@ async def process_chat_completion_stream_response(
                     )
                 )
             else:
-                logger.warning(f"Tool {tool_call.tool_name} not found in request tools")
+                log.warning(f"Tool {tool_call.tool_name} not found in request tools")
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
@@ -1198,7 +1198,7 @@ async def convert_openai_chat_completion_stream(
        )

    for idx, buffer in tool_call_idx_to_buffer.items():
-       logger.debug(f"toolcall_buffer[{idx}]: {buffer}")
+       log.debug(f"toolcall_buffer[{idx}]: {buffer}")
        if buffer["name"]:
            delta = ")"
            buffer["content"] += delta
@@ -25,7 +25,7 @@ from llama_stack.apis.inference import (
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params

-logger = get_logger(name=__name__, category="core")
+log = get_logger(name=__name__, category="core")


 class OpenAIMixin(ABC):
@@ -125,9 +125,9 @@ class OpenAIMixin(ABC):
         Direct OpenAI completion API call.
         """
         if guided_choice is not None:
-            logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
+            log.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
         if prompt_logprobs is not None:
-            logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
+            log.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")

         # TODO: fix openai_completion to return type compatible with OpenAI's API response
         return await self.client.completions.create(  # type: ignore[no-any-return]
@@ -267,6 +267,6 @@ class OpenAIMixin(ABC):
                 pass
             except Exception as e:
                 # All other errors (auth, rate limit, network, etc.)
-                logger.warning(f"Failed to check model availability for {model}: {e}")
+                log.warning(f"Failed to check model availability for {model}: {e}")

         return False
@@ -4,16 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 from datetime import datetime

 from pymongo import AsyncMongoClient

+from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import KVStore

 from ..config import MongoDBKVStoreConfig

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="core")


 class MongoDBKVStoreImpl(KVStore):
@@ -4,16 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 from datetime import datetime

 import psycopg2
 from psycopg2.extras import DictCursor

+from llama_stack.log import get_logger
+
 from ..api import KVStore
 from ..config import PostgresKVStoreConfig

-log = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="core")


 class PostgresKVStoreImpl(KVStore):
@@ -6,7 +6,6 @@

 import asyncio
 import json
-import logging
 import mimetypes
 import time
 import uuid
@@ -37,10 +36,11 @@ from llama_stack.apis.vector_io import (
     VectorStoreSearchResponse,
     VectorStoreSearchResponsePage,
 )
+from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks

-logger = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="core")

 # Constants for OpenAI vector stores
 CHUNK_MULTIPLIER = 5
@@ -378,7 +378,7 @@ class OpenAIVectorStoreMixin(ABC):
         try:
             await self.unregister_vector_db(vector_store_id)
         except Exception as e:
-            logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
+            log.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")

         return VectorStoreDeleteResponse(
             id=vector_store_id,
@@ -460,7 +460,7 @@ class OpenAIVectorStoreMixin(ABC):
            )

        except Exception as e:
-           logger.error(f"Error searching vector store {vector_store_id}: {e}")
+           log.error(f"Error searching vector store {vector_store_id}: {e}")
            # Return empty results on error
            return VectorStoreSearchResponsePage(
                search_query=search_query,
@@ -614,7 +614,7 @@ class OpenAIVectorStoreMixin(ABC):
            )
            vector_store_file_object.status = "completed"
        except Exception as e:
-           logger.error(f"Error attaching file to vector store: {e}")
+           log.error(f"Error attaching file to vector store: {e}")
            vector_store_file_object.status = "failed"
            vector_store_file_object.last_error = VectorStoreFileLastError(
                code="server_error",
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 import base64
 import io
-import logging
 import re
 import time
 from abc import ABC, abstractmethod
@@ -25,6 +24,7 @@ from llama_stack.apis.common.content_types import (
 from llama_stack.apis.tools import RAGDocument
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
+from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import Api
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -32,12 +32,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )
 from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id

-log = logging.getLogger(__name__)
-
 # Constants for reranker types
 RERANKER_TYPE_RRF = "rrf"
 RERANKER_TYPE_WEIGHTED = "weighted"

+log = get_logger(name=__name__, category="memory")


 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string
@@ -17,7 +17,7 @@ from pydantic import BaseModel

 from llama_stack.log import get_logger

-logger = get_logger(name=__name__, category="scheduler")
+log = get_logger(name=__name__, category="scheduler")


 # TODO: revisit the list of possible statuses when defining a more coherent
@@ -186,7 +186,7 @@ class _NaiveSchedulerBackend(_SchedulerBackend):
             except Exception as e:
                 on_log_message_cb(str(e))
                 job.status = JobStatus.failed
-                logger.exception(f"Job {job.id} failed.")
+                log.exception(f"Job {job.id} failed.")

         asyncio.run_coroutine_threadsafe(do(), self._loop)

@@ -222,7 +222,7 @@ class Scheduler:
         msg = (datetime.now(UTC), message)
         # At least for the time being, until there's a better way to expose
         # logs to users, log messages on console
-        logger.info(f"Job {job.id}: {message}")
+        log.info(f"Job {job.id}: {message}")
         job.append_log(msg)
         self._backend.on_log_message_cb(job, msg)

@@ -17,7 +17,7 @@ from llama_stack.log import get_logger
 from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
 from .sqlstore import SqlStoreType

-logger = get_logger(name=__name__, category="authorized_sqlstore")
+log = get_logger(name=__name__, category="authorized_sqlstore")

 # Hardcoded copy of the default policy that our SQL filtering implements
 # WARNING: If default_policy() changes, this constant must be updated accordingly
@@ -81,7 +81,7 @@ class AuthorizedSqlStore:
         actual_default = default_policy()

         if SQL_OPTIMIZED_POLICY != actual_default:
-            logger.warning(
+            log.warning(
                 f"SQL_OPTIMIZED_POLICY does not match default_policy(). "
                 f"SQL filtering will use conservative mode. "
                 f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}",
@@ -29,7 +29,7 @@ from llama_stack.log import get_logger
 from .api import ColumnDefinition, ColumnType, SqlStore
 from .sqlstore import SqlAlchemySqlStoreConfig

-logger = get_logger(name=__name__, category="sqlstore")
+log = get_logger(name=__name__, category="sqlstore")

 TYPE_MAPPING: dict[ColumnType, Any] = {
     ColumnType.INTEGER: Integer,
@@ -280,5 +280,5 @@ class SqlAlchemySqlStoreImpl(SqlStore):
         except Exception as e:
             # If any error occurs during migration, log it but don't fail
             # The table creation will handle adding the column
-            logger.error(f"Error adding column {column_name} to table {table}: {e}")
+            log.error(f"Error adding column {column_name} to table {table}: {e}")
             pass
@@ -6,7 +6,7 @@

 import asyncio
 import contextvars
-import logging
+import logging # allow-direct-logging
 import queue
 import random
 import threading
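This hunk keeps the direct stdlib import but tags it with `# allow-direct-logging`, which reads as an explicit marker that using the raw `logging` module here is intentional rather than an oversight. A small sketch of the two forms a module ends up with in this changeset:

    # the common case: a categorized project logger
    from llama_stack.log import get_logger

    log = get_logger(name=__name__, category="core")

    # the exception: modules that genuinely need the stdlib module keep it, explicitly marked
    import logging # allow-direct-logging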
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import sys
 import time
 import uuid
@@ -19,10 +18,9 @@ from llama_stack.apis.post_training import (
     LoraFinetuningConfig,
     TrainingConfig,
 )
+from llama_stack.log import get_logger

-# Configure logging
-logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
-logger = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="core")

 skip_because_resource_intensive = pytest.mark.skip(
@@ -71,14 +69,14 @@ class TestPostTraining:
     )
     @pytest.mark.timeout(360)  # 6 minutes timeout
     def test_supervised_fine_tune(self, llama_stack_client, purpose, source):
-        logger.info("Starting supervised fine-tuning test")
+        log.info("Starting supervised fine-tuning test")

         # register dataset to train
         dataset = llama_stack_client.datasets.register(
             purpose=purpose,
             source=source,
         )
-        logger.info(f"Registered dataset with ID: {dataset.identifier}")
+        log.info(f"Registered dataset with ID: {dataset.identifier}")

         algorithm_config = LoraFinetuningConfig(
             type="LoRA",
@@ -105,7 +103,7 @@ class TestPostTraining:
         )

         job_uuid = f"test-job{uuid.uuid4()}"
-        logger.info(f"Starting training job with UUID: {job_uuid}")
+        log.info(f"Starting training job with UUID: {job_uuid}")

         # train with HF trl SFTTrainer as the default
         _ = llama_stack_client.post_training.supervised_fine_tune(
@@ -121,21 +119,21 @@ class TestPostTraining:
         while True:
             status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
             if not status:
-                logger.error("Job not found")
+                log.error("Job not found")
                 break

-            logger.info(f"Current status: {status}")
+            log.info(f"Current status: {status}")
             assert status.status in ["scheduled", "in_progress", "completed"]
             if status.status == "completed":
                 break

-            logger.info("Waiting for job to complete...")
+            log.info("Waiting for job to complete...")
             time.sleep(10)  # Increased sleep time to reduce polling frequency

         artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
-        logger.info(f"Job artifacts: {artifacts}")
+        log.info(f"Job artifacts: {artifacts}")

-        logger.info(f"Registered dataset with ID: {dataset.identifier}")
+        log.info(f"Registered dataset with ID: {dataset.identifier}")

         # TODO: Fix these tests to properly represent the Jobs API in training
         #
@@ -181,17 +179,21 @@ class TestPostTraining:
     )
     @pytest.mark.timeout(360)
     def test_preference_optimize(self, llama_stack_client, purpose, source):
-        logger.info("Starting DPO preference optimization test")
+        log.info("Starting DPO preference optimization test")

         # register preference dataset to train
         dataset = llama_stack_client.datasets.register(
             purpose=purpose,
             source=source,
         )
-        logger.info(f"Registered preference dataset with ID: {dataset.identifier}")
+        log.info(f"Registered preference dataset with ID: {dataset.identifier}")

         # DPO algorithm configuration
         algorithm_config = DPOAlignmentConfig(
+            reward_scale=1.0,
+            reward_clip=10.0,
+            epsilon=1e-8,
+            gamma=0.99,
             beta=0.1,
             loss_type=DPOLossType.sigmoid,  # Default loss type
         )
@@ -211,7 +213,7 @@ class TestPostTraining:
         )

         job_uuid = f"test-dpo-job-{uuid.uuid4()}"
-        logger.info(f"Starting DPO training job with UUID: {job_uuid}")
+        log.info(f"Starting DPO training job with UUID: {job_uuid}")

         # train with HuggingFace DPO implementation
         _ = llama_stack_client.post_training.preference_optimize(
@@ -226,15 +228,15 @@ class TestPostTraining:
         while True:
             status = llama_stack_client.post_training.job.status(job_uuid=job_uuid)
             if not status:
-                logger.error("DPO job not found")
+                log.error("DPO job not found")
                 break

-            logger.info(f"Current DPO status: {status}")
+            log.info(f"Current DPO status: {status}")
             if status.status == "completed":
                 break

-            logger.info("Waiting for DPO job to complete...")
+            log.info("Waiting for DPO job to complete...")
             time.sleep(10)  # Increased sleep time to reduce polling frequency

         artifacts = llama_stack_client.post_training.job.artifacts(job_uuid=job_uuid)
-        logger.info(f"DPO job artifacts: {artifacts}")
+        log.info(f"DPO job artifacts: {artifacts}")
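Besides the logger rename, the DPO test now spells out its alignment hyperparameters explicitly. Assembled from the added lines above, the configuration looks roughly like this (the import path is assumed to match the file's other post_training imports):

    from llama_stack.apis.post_training import DPOAlignmentConfig, DPOLossType

    algorithm_config = DPOAlignmentConfig(
        reward_scale=1.0,
        reward_clip=10.0,
        epsilon=1e-8,
        gamma=0.99,
        beta=0.1,
        loss_type=DPOLossType.sigmoid,  # Default loss type
    )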
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import logging
 import time
 from io import BytesIO

@@ -13,8 +12,10 @@ from llama_stack_client import BadRequestError, LlamaStackClient
 from openai import BadRequestError as OpenAIBadRequestError

 from llama_stack.apis.vector_io import Chunk
+from llama_stack.core.library_client import LlamaStackAsLibraryClient
+from llama_stack.log import get_logger

-logger = logging.getLogger(__name__)
+log = get_logger(name=__name__, category="vector-io")


 def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
@@ -99,7 +100,7 @@ def compat_client_with_empty_stores(compat_client):
                 compat_client.vector_stores.delete(vector_store_id=store.id)
         except Exception:
             # If the API is not available or fails, just continue
-            logger.warning("Failed to clear vector stores")
+            log.warning("Failed to clear vector stores")
             pass

     def clear_files():
@@ -109,7 +110,7 @@ def compat_client_with_empty_stores(compat_client):
                 compat_client.files.delete(file_id=file.id)
         except Exception:
             # If the API is not available or fails, just continue
-            logger.warning("Failed to clear files")
+            log.warning("Failed to clear files")
             pass

     clear_vector_stores()
@@ -6,7 +6,7 @@

 import asyncio
 import json
-import logging
+import logging # allow-direct-logging
 import threading
 import time
 from http.server import BaseHTTPRequestHandler, HTTPServer