mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-25 22:32:02 +00:00
feat: add auto-generated CI documentation pre-commit hook (#2890)
Our CI is entirely undocumented, this commit adds a README.md file with a table of the current CI and what is does --------- Signed-off-by: Nathan Weinberg <nweinber@redhat.com>
This commit is contained in:
parent
7f834339ba
commit
b381ed6d64
93 changed files with 495 additions and 477 deletions
|
|
@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
|
|||
WEB_SEARCH_TOOL = "web_search"
|
||||
RAG_TOOL_GROUP = "builtin::rag"
|
||||
|
||||
logger = get_logger(name=__name__, category="agents")
|
||||
log = get_logger(name=__name__, category="agents")
|
||||
|
||||
|
||||
class ChatAgent(ShieldRunnerMixin):
|
||||
|
|
@ -612,7 +612,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
|
||||
if n_iter >= self.agent_config.max_infer_iters:
|
||||
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
|
||||
log.info(f"done with MAX iterations ({n_iter}), exiting.")
|
||||
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
|
||||
# Do not continue the tool call loop after this point
|
||||
message.stop_reason = StopReason.end_of_turn
|
||||
|
|
@ -620,7 +620,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
break
|
||||
|
||||
if stop_reason == StopReason.out_of_tokens:
|
||||
logger.info("out of token budget, exiting.")
|
||||
log.info("out of token budget, exiting.")
|
||||
yield message
|
||||
break
|
||||
|
||||
|
|
@ -634,7 +634,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
message.content = [message.content] + output_attachments
|
||||
yield message
|
||||
else:
|
||||
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
|
||||
log.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
|
||||
input_messages = input_messages + [message]
|
||||
else:
|
||||
input_messages = input_messages + [message]
|
||||
|
|
@ -889,7 +889,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
tool_name_str = tool_name
|
||||
|
||||
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
|
||||
log.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=tool_name_str,
|
||||
kwargs={
|
||||
|
|
@ -899,7 +899,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
**self.tool_name_to_args.get(tool_name_str, {}),
|
||||
},
|
||||
)
|
||||
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
|
||||
log.debug(f"tool call {tool_name_str} completed with result: {result}")
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
|
|
@ -42,6 +41,7 @@ from llama_stack.apis.safety import Safety
|
|||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||
from llama_stack.providers.utils.pagination import paginate_records
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
|
@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
|
|||
from .openai_responses import OpenAIResponsesImpl
|
||||
from .persistence import AgentInfo
|
||||
|
||||
logger = logging.getLogger()
|
||||
log = get_logger(name=__name__, category="agents")
|
||||
|
||||
|
||||
class MetaReferenceAgentsImpl(Agents):
|
||||
|
|
@ -268,7 +268,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
# Get the agent info using the key
|
||||
agent_info_json = await self.persistence_store.get(agent_key)
|
||||
if not agent_info_json:
|
||||
logger.error(f"Could not find agent info for key {agent_key}")
|
||||
log.error(f"Could not find agent info for key {agent_key}")
|
||||
continue
|
||||
|
||||
try:
|
||||
|
|
@ -281,7 +281,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing agent info for {agent_id}: {e}")
|
||||
log.error(f"Error parsing agent info for {agent_id}: {e}")
|
||||
continue
|
||||
|
||||
# Convert Agent objects to dictionaries
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefiniti
|
|||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
||||
logger = get_logger(name=__name__, category="openai_responses")
|
||||
log = get_logger(name=__name__, category="openai_responses")
|
||||
|
||||
OPENAI_RESPONSES_PREFIX = "openai_responses:"
|
||||
|
||||
|
|
@ -544,12 +544,12 @@ class OpenAIResponsesImpl:
|
|||
break
|
||||
|
||||
if function_tool_calls:
|
||||
logger.info("Exiting inference loop since there is a function (client-side) tool call")
|
||||
log.info("Exiting inference loop since there is a function (client-side) tool call")
|
||||
break
|
||||
|
||||
n_iter += 1
|
||||
if n_iter >= max_infer_iters:
|
||||
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}")
|
||||
log.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}")
|
||||
break
|
||||
|
||||
messages = next_turn_messages
|
||||
|
|
@ -698,7 +698,7 @@ class OpenAIResponsesImpl:
|
|||
)
|
||||
return search_response.data
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
|
||||
log.warning(f"Failed to search vector store {vector_store_id}: {e}")
|
||||
return []
|
||||
|
||||
# Run all searches in parallel using gather
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
|
@ -15,9 +14,10 @@ from llama_stack.core.access_control.access_control import AccessDeniedError, is
|
|||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="agents")
|
||||
|
||||
|
||||
class AgentSessionInfo(Session):
|
||||
|
|
|
|||
|
|
@ -5,13 +5,13 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="agents")
|
||||
|
||||
|
||||
class SafetyException(Exception): # noqa: N818
|
||||
|
|
|
|||
|
|
@ -73,11 +73,12 @@ from .config import MetaReferenceInferenceConfig
|
|||
from .generators import LlamaGenerator
|
||||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
||||
log = get_logger(__name__, category="inference")
|
||||
# there's a single model parallel process running serving the model. for now,
|
||||
# we don't support multiple concurrent requests to this process.
|
||||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
||||
logger = get_logger(__name__, category="inference")
|
||||
|
||||
|
||||
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
|
||||
return LlamaGenerator(config, model_id, llama_model)
|
||||
|
|
@ -144,7 +145,7 @@ class MetaReferenceInferenceImpl(
|
|||
return model
|
||||
|
||||
async def load_model(self, model_id, llama_model) -> None:
|
||||
log.info(f"Loading model `{model_id}`")
|
||||
logger.info(f"Loading model `{model_id}`")
|
||||
|
||||
builder_params = [self.config, model_id, llama_model]
|
||||
|
||||
|
|
@ -166,7 +167,7 @@ class MetaReferenceInferenceImpl(
|
|||
self.model_id = model_id
|
||||
self.llama_model = llama_model
|
||||
|
||||
log.info("Warming up...")
|
||||
logger.info("Warming up...")
|
||||
await self.completion(
|
||||
model_id=model_id,
|
||||
content="Hello, world!",
|
||||
|
|
@ -177,7 +178,7 @@ class MetaReferenceInferenceImpl(
|
|||
messages=[UserMessage(content="Hi how are you?")],
|
||||
sampling_params=SamplingParams(max_tokens=20),
|
||||
)
|
||||
log.info("Warmed up!")
|
||||
logger.info("Warmed up!")
|
||||
|
||||
def check_model(self, request) -> None:
|
||||
if self.model_id is None or self.llama_model is None:
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@
|
|||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import tempfile
|
||||
|
|
@ -32,13 +31,14 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
from pydantic import BaseModel, Field
|
||||
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import GenerationResult
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class ProcessingMessageName(str, Enum):
|
||||
|
|
@ -236,7 +236,7 @@ def worker_process_entrypoint(
|
|||
except StopIteration:
|
||||
break
|
||||
|
||||
log.info("[debug] worker process done")
|
||||
log.info("[debug] worker process done")
|
||||
|
||||
|
||||
def launch_dist_group(
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
|
|
@ -32,8 +31,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
|
||||
from .config import SentenceTransformersInferenceConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SentenceTransformersInferenceImpl(
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import gc
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
|
@ -28,6 +27,7 @@ from llama_stack.apis.post_training import (
|
|||
LoraFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
from ..config import HuggingFacePostTrainingConfig
|
||||
|
|
@ -44,7 +44,7 @@ from ..utils import (
|
|||
split_dataset,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class HFFinetuningSingleDevice:
|
||||
|
|
@ -69,14 +69,14 @@ class HFFinetuningSingleDevice:
|
|||
try:
|
||||
messages = json.loads(row["chat_completion_input"])
|
||||
if not isinstance(messages, list) or len(messages) != 1:
|
||||
logger.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
|
||||
log.warning(f"Invalid chat_completion_input format: {row['chat_completion_input']}")
|
||||
return None, None
|
||||
if "content" not in messages[0]:
|
||||
logger.warning(f"Message missing content: {messages[0]}")
|
||||
log.warning(f"Message missing content: {messages[0]}")
|
||||
return None, None
|
||||
return messages[0]["content"], row["expected_answer"]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
|
||||
log.warning(f"Failed to parse chat_completion_input: {row['chat_completion_input']}")
|
||||
return None, None
|
||||
return None, None
|
||||
|
||||
|
|
@ -86,13 +86,13 @@ class HFFinetuningSingleDevice:
|
|||
try:
|
||||
dialog = json.loads(row["dialog"])
|
||||
if not isinstance(dialog, list) or len(dialog) < 2:
|
||||
logger.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
|
||||
log.warning(f"Dialog must have at least 2 messages: {row['dialog']}")
|
||||
return None, None
|
||||
if dialog[0].get("role") != "user":
|
||||
logger.warning(f"First message must be from user: {dialog[0]}")
|
||||
log.warning(f"First message must be from user: {dialog[0]}")
|
||||
return None, None
|
||||
if not any(msg.get("role") == "assistant" for msg in dialog):
|
||||
logger.warning("Dialog must have at least one assistant message")
|
||||
log.warning("Dialog must have at least one assistant message")
|
||||
return None, None
|
||||
|
||||
# Convert to human/gpt format
|
||||
|
|
@ -100,14 +100,14 @@ class HFFinetuningSingleDevice:
|
|||
conversations = []
|
||||
for msg in dialog:
|
||||
if "role" not in msg or "content" not in msg:
|
||||
logger.warning(f"Message missing role or content: {msg}")
|
||||
log.warning(f"Message missing role or content: {msg}")
|
||||
continue
|
||||
conversations.append({"from": role_map[msg["role"]], "value": msg["content"]})
|
||||
|
||||
# Format as a single conversation
|
||||
return conversations[0]["value"], conversations[1]["value"]
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse dialog: {row['dialog']}")
|
||||
log.warning(f"Failed to parse dialog: {row['dialog']}")
|
||||
return None, None
|
||||
return None, None
|
||||
|
||||
|
|
@ -198,7 +198,7 @@ class HFFinetuningSingleDevice:
|
|||
"""
|
||||
import asyncio
|
||||
|
||||
logger.info("Starting training process with async wrapper")
|
||||
log.info("Starting training process with async wrapper")
|
||||
asyncio.run(
|
||||
self._run_training(
|
||||
model=model,
|
||||
|
|
@ -228,14 +228,14 @@ class HFFinetuningSingleDevice:
|
|||
raise ValueError("DataConfig is required for training")
|
||||
|
||||
# Load dataset
|
||||
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
|
||||
log.info(f"Loading dataset: {config.data_config.dataset_id}")
|
||||
rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
|
||||
if not self.validate_dataset_format(rows):
|
||||
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
|
||||
logger.info(f"Loaded {len(rows)} rows from dataset")
|
||||
log.info(f"Loaded {len(rows)} rows from dataset")
|
||||
|
||||
# Initialize tokenizer
|
||||
logger.info(f"Initializing tokenizer for model: {model}")
|
||||
log.info(f"Initializing tokenizer for model: {model}")
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
|
||||
|
||||
|
|
@ -257,16 +257,16 @@ class HFFinetuningSingleDevice:
|
|||
# This ensures consistent sequence lengths across the training process
|
||||
tokenizer.model_max_length = provider_config.max_seq_length
|
||||
|
||||
logger.info("Tokenizer initialized successfully")
|
||||
log.info("Tokenizer initialized successfully")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
|
||||
|
||||
# Create and preprocess dataset
|
||||
logger.info("Creating and preprocessing dataset")
|
||||
log.info("Creating and preprocessing dataset")
|
||||
try:
|
||||
ds = self._create_dataset(rows, config, provider_config)
|
||||
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
|
||||
logger.info(f"Dataset created with {len(ds)} examples")
|
||||
log.info(f"Dataset created with {len(ds)} examples")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to create dataset: {str(e)}") from e
|
||||
|
||||
|
|
@ -293,11 +293,11 @@ class HFFinetuningSingleDevice:
|
|||
Returns:
|
||||
Configured SFTConfig object
|
||||
"""
|
||||
logger.info("Configuring training arguments")
|
||||
log.info("Configuring training arguments")
|
||||
lr = 2e-5
|
||||
if config.optimizer_config:
|
||||
lr = config.optimizer_config.lr
|
||||
logger.info(f"Using custom learning rate: {lr}")
|
||||
log.info(f"Using custom learning rate: {lr}")
|
||||
|
||||
# Validate data config
|
||||
if not config.data_config:
|
||||
|
|
@ -350,17 +350,17 @@ class HFFinetuningSingleDevice:
|
|||
peft_config: Optional LoRA configuration
|
||||
output_dir_path: Path to save the model
|
||||
"""
|
||||
logger.info("Saving final model")
|
||||
log.info("Saving final model")
|
||||
model_obj.config.use_cache = True
|
||||
|
||||
if peft_config:
|
||||
logger.info("Merging LoRA weights with base model")
|
||||
log.info("Merging LoRA weights with base model")
|
||||
model_obj = trainer.model.merge_and_unload()
|
||||
else:
|
||||
model_obj = trainer.model
|
||||
|
||||
save_path = output_dir_path / "merged_model"
|
||||
logger.info(f"Saving model to {save_path}")
|
||||
log.info(f"Saving model to {save_path}")
|
||||
model_obj.save_pretrained(save_path)
|
||||
|
||||
async def _run_training(
|
||||
|
|
@ -380,13 +380,13 @@ class HFFinetuningSingleDevice:
|
|||
setup_signal_handlers()
|
||||
|
||||
# Convert config dicts back to objects
|
||||
logger.info("Initializing configuration objects")
|
||||
log.info("Initializing configuration objects")
|
||||
provider_config_obj = HuggingFacePostTrainingConfig(**provider_config)
|
||||
config_obj = TrainingConfig(**config)
|
||||
|
||||
# Initialize and validate device
|
||||
device = setup_torch_device(provider_config_obj.device)
|
||||
logger.info(f"Using device '{device}'")
|
||||
log.info(f"Using device '{device}'")
|
||||
|
||||
# Load dataset and tokenizer
|
||||
train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
|
||||
|
|
@ -409,7 +409,7 @@ class HFFinetuningSingleDevice:
|
|||
model_obj = load_model(model, device, provider_config_obj)
|
||||
|
||||
# Initialize trainer
|
||||
logger.info("Initializing SFTTrainer")
|
||||
log.info("Initializing SFTTrainer")
|
||||
trainer = SFTTrainer(
|
||||
model=model_obj,
|
||||
train_dataset=train_dataset,
|
||||
|
|
@ -420,9 +420,9 @@ class HFFinetuningSingleDevice:
|
|||
|
||||
try:
|
||||
# Train
|
||||
logger.info("Starting training")
|
||||
log.info("Starting training")
|
||||
trainer.train()
|
||||
logger.info("Training completed successfully")
|
||||
log.info("Training completed successfully")
|
||||
|
||||
# Save final model if output directory is provided
|
||||
if output_dir_path:
|
||||
|
|
@ -430,12 +430,12 @@ class HFFinetuningSingleDevice:
|
|||
|
||||
finally:
|
||||
# Clean up resources
|
||||
logger.info("Cleaning up resources")
|
||||
log.info("Cleaning up resources")
|
||||
if hasattr(trainer, "model"):
|
||||
evacuate_model_from_device(trainer.model, device.type)
|
||||
del trainer
|
||||
gc.collect()
|
||||
logger.info("Cleanup completed")
|
||||
log.info("Cleanup completed")
|
||||
|
||||
async def train(
|
||||
self,
|
||||
|
|
@ -449,7 +449,7 @@ class HFFinetuningSingleDevice:
|
|||
"""Train a model using HuggingFace's SFTTrainer"""
|
||||
# Initialize and validate device
|
||||
device = setup_torch_device(provider_config.device)
|
||||
logger.info(f"Using device '{device}'")
|
||||
log.info(f"Using device '{device}'")
|
||||
|
||||
output_dir_path = None
|
||||
if output_dir:
|
||||
|
|
@ -479,7 +479,7 @@ class HFFinetuningSingleDevice:
|
|||
raise ValueError("DataConfig is required for training")
|
||||
|
||||
# Train in a separate process
|
||||
logger.info("Starting training in separate process")
|
||||
log.info("Starting training in separate process")
|
||||
try:
|
||||
# Setup multiprocessing for device
|
||||
if device.type in ["cuda", "mps"]:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
|
@ -24,6 +23,7 @@ from llama_stack.apis.post_training import (
|
|||
DPOAlignmentConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
from ..config import HuggingFacePostTrainingConfig
|
||||
|
|
@ -40,7 +40,7 @@ from ..utils import (
|
|||
split_dataset,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
class HFDPOAlignmentSingleDevice:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
|
|
@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM
|
|||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.post_training import Checkpoint, TrainingConfig
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .config import HuggingFacePostTrainingConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def setup_environment():
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
|
@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler
|
|||
from torchtune import modules, training
|
||||
from torchtune import utils as torchtune_utils
|
||||
from torchtune.data import padded_collate_sft
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
from torchtune.modules.loss import CEWithChunkedOutputLoss
|
||||
from torchtune.modules.peft import (
|
||||
get_adapter_params,
|
||||
|
|
@ -45,6 +45,7 @@ from llama_stack.apis.post_training import (
|
|||
)
|
||||
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
from llama_stack.providers.inline.post_training.torchtune.common import utils
|
||||
|
|
@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
|
|||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class LoraFinetuningSingleDevice:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
|
|
@ -15,13 +14,14 @@ from llama_stack.apis.safety import (
|
|||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="safety")
|
||||
|
||||
ALLOWED_CODE_SCANNER_MODEL_IDS = [
|
||||
"CodeScanner",
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
|
@ -19,6 +18,7 @@ from llama_stack.apis.safety import (
|
|||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
|
|
@ -26,10 +26,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import PromptGuardConfig, PromptGuardType
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
|
||||
|
||||
log = get_logger(name=__name__, category="safety")
|
||||
|
||||
|
||||
class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: PromptGuardConfig, _deps) -> None:
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
import collections
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
|
|
@ -20,7 +19,9 @@ import nltk
|
|||
from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai
|
||||
from pythainlp.tokenize import word_tokenize as word_tokenize_thai
|
||||
|
||||
logger = logging.getLogger()
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
WORD_LIST = [
|
||||
"western",
|
||||
|
|
@ -1726,7 +1727,7 @@ def get_langid(text: str, lid_path: str | None = None) -> str:
|
|||
try:
|
||||
line_langs.append(langdetect.detect(line))
|
||||
except langdetect.LangDetectException as e:
|
||||
logger.info("Unable to detect language for text %s due to %s", line, e) # refex: disable=pytotw.037
|
||||
log.info("Unable to detect language for text %s due to %s", line, e) # refex: disable=pytotw.037
|
||||
|
||||
if len(line_langs) == 0:
|
||||
return "en"
|
||||
|
|
@ -1885,7 +1886,7 @@ class ResponseLanguageChecker(Instruction):
|
|||
return langdetect.detect(value) == self._language
|
||||
except langdetect.LangDetectException as e:
|
||||
# Count as instruction is followed.
|
||||
logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
|
||||
log.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
|
||||
return True
|
||||
|
||||
|
||||
|
|
@ -3110,7 +3111,7 @@ class CapitalLettersEnglishChecker(Instruction):
|
|||
return value.isupper() and langdetect.detect(value) == "en"
|
||||
except langdetect.LangDetectException as e:
|
||||
# Count as instruction is followed.
|
||||
logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
|
||||
log.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
|
||||
return True
|
||||
|
||||
|
||||
|
|
@ -3139,7 +3140,7 @@ class LowercaseLettersEnglishChecker(Instruction):
|
|||
return value.islower() and langdetect.detect(value) == "en"
|
||||
except langdetect.LangDetectException as e:
|
||||
# Count as instruction is followed.
|
||||
logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
|
||||
log.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037
|
||||
return True
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import secrets
|
||||
import string
|
||||
from typing import Any
|
||||
|
|
@ -32,6 +31,7 @@ from llama_stack.apis.tools import (
|
|||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
|
|
@ -42,7 +42,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
from .config import RagToolRuntimeConfig
|
||||
from .context_retriever import generate_rag_query
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="tools")
|
||||
|
||||
|
||||
def make_random_string(length: int = 8):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import asyncio
|
|||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import faiss
|
||||
|
|
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
|
|||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
|
|
@ -39,7 +39,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import FaissVectorIOConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
|
||||
|
|
@ -83,7 +83,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False))
|
||||
self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()]
|
||||
except Exception as e:
|
||||
logger.debug(e, exc_info=True)
|
||||
log.debug(e, exc_info=True)
|
||||
raise ValueError(
|
||||
"Error deserializing Faiss index from storage. If you recently upgraded your Llama Stack, Faiss, "
|
||||
"or NumPy versions, you may need to delete the index and re-create it again or downgrade versions.\n"
|
||||
|
|
@ -262,7 +262,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
assert self.kvstore is not None
|
||||
|
||||
if vector_db_id not in self.cache:
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
log.warning(f"Vector DB {vector_db_id} not found")
|
||||
return
|
||||
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import sqlite3
|
||||
import struct
|
||||
|
|
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
|
|||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
|
|
@ -35,7 +35,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
VectorDBWithIndex,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
# Specifying search mode is dependent on the VectorIO provider.
|
||||
VECTOR_SEARCH = "vector"
|
||||
|
|
@ -257,7 +257,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
|
||||
except sqlite3.Error as e:
|
||||
connection.rollback()
|
||||
logger.error(f"Error inserting into {self.vector_table}: {e}")
|
||||
log.error(f"Error inserting into {self.vector_table}: {e}")
|
||||
raise
|
||||
|
||||
finally:
|
||||
|
|
@ -306,7 +306,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
try:
|
||||
chunk = Chunk.model_validate_json(chunk_json)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
||||
log.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
||||
continue
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
|
@ -352,7 +352,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
try:
|
||||
chunk = Chunk.model_validate_json(chunk_json)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
||||
log.error(f"Error parsing chunk JSON for id {_id}: {e}")
|
||||
continue
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
|
@ -447,7 +447,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
connection.commit()
|
||||
except Exception as e:
|
||||
connection.rollback()
|
||||
logger.error(f"Error deleting chunk {chunk_id}: {e}")
|
||||
log.error(f"Error deleting chunk {chunk_id}: {e}")
|
||||
raise
|
||||
finally:
|
||||
cur.close()
|
||||
|
|
@ -530,7 +530,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
log.warning(f"Vector DB {vector_db_id} not found")
|
||||
return
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue