diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4309f289a..8ccedee7c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -161,6 +161,25 @@ repos: pass_filenames: false require_serial: true + - id: check-log-usage + name: Ensure 'llama_stack.log' usage for logging + entry: bash + language: system + types: [python] + pass_filenames: true + args: + - -c + - | + matches=$(grep -EnH '^[^#]*\b(import\s+logging|from\s+logging\b)' "$@" | grep -v -e '#\s*allow-direct-logging' || true) + if [ -n "$matches" ]; then + # GitHub Actions annotation format + while IFS=: read -r file line_num rest; do + echo "::error file=$file,line=$line_num::Do not use 'import logging' or 'from logging import' in $file. Use the custom log instead: from llama_stack.log import get_logger; logger = get_logger(). If direct logging is truly needed, add: # allow-direct-logging" + done <<< "$matches" + exit 1 + fi + exit 0 + ci: autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate diff --git a/llama_stack/core/build.py b/llama_stack/core/build.py index 4b20588fd..fa1fe632b 100644 --- a/llama_stack/core/build.py +++ b/llama_stack/core/build.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import importlib.resources -import logging import sys from pydantic import BaseModel @@ -17,9 +16,10 @@ from llama_stack.core.external import load_external_apis from llama_stack.core.utils.exec import run_command from llama_stack.core.utils.image_types import LlamaStackImageType from llama_stack.distributions.template import DistributionTemplate +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="core") # These are the dependencies needed by the distribution server. # `llama-stack` is automatically installed by the installation script. diff --git a/llama_stack/core/configure.py b/llama_stack/core/configure.py index 9e18b438c..64473c053 100644 --- a/llama_stack/core/configure.py +++ b/llama_stack/core/configure.py @@ -3,7 +3,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import textwrap from typing import Any @@ -21,9 +20,10 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.core.utils.prompt_for_config import prompt_for_config +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, ProviderSpec -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="core") def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider: diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py index a93fe509e..dd1fc8a50 100644 --- a/llama_stack/core/library_client.py +++ b/llama_stack/core/library_client.py @@ -7,7 +7,7 @@ import asyncio import inspect import json -import logging +import logging # allow-direct-logging import os import sys from concurrent.futures import ThreadPoolExecutor @@ -48,6 +48,7 @@ from llama_stack.core.stack import ( from llama_stack.core.utils.config import redact_sensitive_fields from llama_stack.core.utils.context import preserve_contexts_async_generator from llama_stack.core.utils.exec import in_notebook +from llama_stack.log import get_logger from llama_stack.providers.utils.telemetry.tracing import ( CURRENT_TRACE_CONTEXT, end_trace, @@ -55,7 +56,7 @@ from llama_stack.providers.utils.telemetry.tracing import ( start_trace, ) -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="core") T = TypeVar("T") diff --git a/llama_stack/core/request_headers.py b/llama_stack/core/request_headers.py index 35ac72775..f1ce8281f 100644 --- a/llama_stack/core/request_headers.py +++ b/llama_stack/core/request_headers.py @@ -6,15 +6,15 @@ import contextvars import json -import logging from contextlib import AbstractContextManager from typing import Any from llama_stack.core.datatypes import User +from llama_stack.log import get_logger from .utils.dynamic import instantiate_class_type -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="core") # Context variable for request provider data and auth attributes PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None) diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index e9d70fc8d..38212ac66 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -9,7 +9,7 @@ import asyncio import functools import inspect import json -import logging +import logging # allow-direct-logging import os import ssl import sys diff --git a/llama_stack/core/utils/exec.py b/llama_stack/core/utils/exec.py index 1b2b782fe..12fb82d01 100644 --- a/llama_stack/core/utils/exec.py +++ b/llama_stack/core/utils/exec.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging +import importlib import os import signal import subprocess @@ -12,9 +12,9 @@ import sys from termcolor import cprint -log = logging.getLogger(__name__) +from llama_stack.log import get_logger -import importlib +log = get_logger(name=__name__, category="core") def formulate_run_args(image_type: str, image_name: str) -> list: diff --git a/llama_stack/core/utils/prompt_for_config.py b/llama_stack/core/utils/prompt_for_config.py index 26f6920e0..bac0531ed 100644 --- a/llama_stack/core/utils/prompt_for_config.py +++ b/llama_stack/core/utils/prompt_for_config.py @@ -6,7 +6,6 @@ import inspect import json -import logging from enum import Enum from typing import Annotated, Any, Literal, Union, get_args, get_origin @@ -14,7 +13,9 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefinedType -log = logging.getLogger(__name__) +from llama_stack.log import get_logger + +log = get_logger(name=__name__, category="core") def is_list_of_primitives(field_type): diff --git a/llama_stack/log.py b/llama_stack/log.py index 7507aface..8dcdcb0e3 100644 --- a/llama_stack/log.py +++ b/llama_stack/log.py @@ -4,11 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging +import logging # allow-direct-logging import os import re import sys -from logging.config import dictConfig +from logging.config import dictConfig # allow-direct-logging from rich.console import Console from rich.errors import MarkupError diff --git a/llama_stack/models/llama/llama3/multimodal/encoder_utils.py b/llama_stack/models/llama/llama3/multimodal/encoder_utils.py index 5b5969d89..90ced13b2 100644 --- a/llama_stack/models/llama/llama3/multimodal/encoder_utils.py +++ b/llama_stack/models/llama/llama3/multimodal/encoder_utils.py @@ -13,14 +13,15 @@ # Copyright (c) Meta Platforms, Inc. and its affiliates. import math -from logging import getLogger import torch import torch.nn.functional as F +from llama_stack.log import get_logger + from .utils import get_negative_inf_value, to_2tuple -logger = getLogger() +logger = get_logger(name=__name__, category="models::llama") def resize_local_position_embedding(orig_pos_embed, grid_size): diff --git a/llama_stack/models/llama/llama3/multimodal/image_transform.py b/llama_stack/models/llama/llama3/multimodal/image_transform.py index f2761ee47..7b20a31fa 100644 --- a/llama_stack/models/llama/llama3/multimodal/image_transform.py +++ b/llama_stack/models/llama/llama3/multimodal/image_transform.py @@ -13,7 +13,6 @@ import math from collections import defaultdict -from logging import getLogger from typing import Any import torch @@ -21,9 +20,11 @@ import torchvision.transforms as tv from PIL import Image from torchvision.transforms import functional as F +from llama_stack.log import get_logger + IMAGE_RES = 224 -logger = getLogger() +logger = get_logger(name=__name__, category="models::llama") class VariableSizeImageTransform: diff --git a/llama_stack/models/llama/llama3/multimodal/model.py b/llama_stack/models/llama/llama3/multimodal/model.py index 5f1c3605c..096156a5f 100644 --- a/llama_stack/models/llama/llama3/multimodal/model.py +++ b/llama_stack/models/llama/llama3/multimodal/model.py @@ -3,8 +3,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - -import logging import math from collections.abc import Callable from functools import partial @@ -22,6 +20,8 @@ from PIL import Image as PIL_Image from torch import Tensor, nn from torch.distributed import _functional_collectives as funcol +from llama_stack.log import get_logger + from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis from .encoder_utils import ( build_encoder_attention_mask, @@ -34,9 +34,10 @@ from .encoder_utils import ( from .image_transform import VariableSizeImageTransform from .utils import get_negative_inf_value, to_2tuple -logger = logging.getLogger(__name__) MP_SCALE = 8 +logger = get_logger(name=__name__, category="models") + def reduce_from_tensor_model_parallel_region(input_): """All-reduce the input tensor across model parallel group.""" @@ -771,7 +772,7 @@ class TilePositionEmbedding(nn.Module): if embed is not None: # reshape the weights to the correct shape nt_old, nt_old, _, w = embed.shape - logging.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}") + logger.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}") embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles) # assign the weights to the module state_dict[prefix + "embedding"] = embed_new diff --git a/llama_stack/models/llama/llama3/tokenizer.py b/llama_stack/models/llama/llama3/tokenizer.py index e47b579e3..ad7ced1c5 100644 --- a/llama_stack/models/llama/llama3/tokenizer.py +++ b/llama_stack/models/llama/llama3/tokenizer.py @@ -4,8 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + from collections.abc import Collection, Iterator, Sequence, Set -from logging import getLogger from pathlib import Path from typing import ( Literal, @@ -14,11 +14,9 @@ from typing import ( import tiktoken +from llama_stack.log import get_logger from llama_stack.models.llama.tokenizer_utils import load_bpe_file -logger = getLogger(__name__) - - # The tiktoken tokenizer can handle <=400k chars without # pyo3_runtime.PanicException. TIKTOKEN_MAX_ENCODE_CHARS = 400_000 @@ -31,6 +29,8 @@ MAX_NO_WHITESPACES_CHARS = 25_000 _INSTANCE = None +logger = get_logger(name=__name__, category="models::llama") + class Tokenizer: """ diff --git a/llama_stack/models/llama/llama4/quantization/loader.py b/llama_stack/models/llama/llama4/quantization/loader.py index 223744a5f..8220a9040 100644 --- a/llama_stack/models/llama/llama4/quantization/loader.py +++ b/llama_stack/models/llama/llama4/quantization/loader.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import os from collections.abc import Callable @@ -13,11 +12,13 @@ from fairscale.nn.model_parallel.initialize import get_model_parallel_rank from torch import Tensor, nn from torch.nn import functional as F +from llama_stack.log import get_logger + from ...datatypes import QuantizationMode from ..model import Transformer, TransformerBlock from ..moe import MoE -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="models") def swiglu_wrapper_no_reduce( diff --git a/llama_stack/models/llama/llama4/tokenizer.py b/llama_stack/models/llama/llama4/tokenizer.py index e12b2cae0..bfbace8f9 100644 --- a/llama_stack/models/llama/llama4/tokenizer.py +++ b/llama_stack/models/llama/llama4/tokenizer.py @@ -5,7 +5,6 @@ # the root directory of this source tree. from collections.abc import Collection, Iterator, Sequence, Set -from logging import getLogger from pathlib import Path from typing import ( Literal, @@ -14,11 +13,9 @@ from typing import ( import tiktoken +from llama_stack.log import get_logger from llama_stack.models.llama.tokenizer_utils import load_bpe_file -logger = getLogger(__name__) - - # The tiktoken tokenizer can handle <=400k chars without # pyo3_runtime.PanicException. TIKTOKEN_MAX_ENCODE_CHARS = 400_000 @@ -101,6 +98,8 @@ BASIC_SPECIAL_TOKENS = [ "<|fim_suffix|>", ] +logger = get_logger(name=__name__, category="models::llama") + class Tokenizer: """ diff --git a/llama_stack/models/llama/quantize_impls.py b/llama_stack/models/llama/quantize_impls.py index a6400c5c9..7fab2d3a6 100644 --- a/llama_stack/models/llama/quantize_impls.py +++ b/llama_stack/models/llama/quantize_impls.py @@ -6,9 +6,10 @@ # type: ignore import collections -import logging -log = logging.getLogger(__name__) +from llama_stack.log import get_logger + +log = get_logger(name=__name__, category="llama") try: import fbgemm_gpu.experimental.gen_ai # noqa: F401 diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 30196c429..5794ad2c0 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import uuid from collections.abc import AsyncGenerator from datetime import UTC, datetime @@ -42,6 +41,7 @@ from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.core.datatypes import AccessRule +from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.pagination import paginate_records from llama_stack.providers.utils.responses.responses_store import ResponsesStore @@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig from .persistence import AgentInfo from .responses.openai_responses import OpenAIResponsesImpl -logger = logging.getLogger() +logger = get_logger(name=__name__, category="agents") class MetaReferenceAgentsImpl(Agents): diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py new file mode 100644 index 000000000..6850ae97e --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -0,0 +1,1154 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import json +import time +import uuid +from collections.abc import AsyncIterator +from typing import Any + +from openai.types.chat import ChatCompletionToolParam +from pydantic import BaseModel + +from llama_stack.apis.agents import Order +from llama_stack.apis.agents.openai_responses import ( + AllowedToolsFilter, + ListOpenAIResponseInputItem, + ListOpenAIResponseObject, + OpenAIDeleteResponseObject, + OpenAIResponseContentPartOutputText, + OpenAIResponseInput, + OpenAIResponseInputFunctionToolCallOutput, + OpenAIResponseInputMessageContent, + OpenAIResponseInputMessageContentImage, + OpenAIResponseInputMessageContentText, + OpenAIResponseInputTool, + OpenAIResponseInputToolFileSearch, + OpenAIResponseInputToolMCP, + OpenAIResponseMessage, + OpenAIResponseObject, + OpenAIResponseObjectStream, + OpenAIResponseObjectStreamResponseCompleted, + OpenAIResponseObjectStreamResponseContentPartAdded, + OpenAIResponseObjectStreamResponseContentPartDone, + OpenAIResponseObjectStreamResponseCreated, + OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta, + OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone, + OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta, + OpenAIResponseObjectStreamResponseMcpCallArgumentsDone, + OpenAIResponseObjectStreamResponseMcpCallCompleted, + OpenAIResponseObjectStreamResponseMcpCallFailed, + OpenAIResponseObjectStreamResponseMcpCallInProgress, + OpenAIResponseObjectStreamResponseOutputItemAdded, + OpenAIResponseObjectStreamResponseOutputItemDone, + OpenAIResponseObjectStreamResponseOutputTextDelta, + OpenAIResponseObjectStreamResponseWebSearchCallCompleted, + OpenAIResponseObjectStreamResponseWebSearchCallInProgress, + OpenAIResponseObjectStreamResponseWebSearchCallSearching, + OpenAIResponseOutput, + OpenAIResponseOutputMessageContent, + OpenAIResponseOutputMessageContentOutputText, + OpenAIResponseOutputMessageFileSearchToolCall, + OpenAIResponseOutputMessageFileSearchToolCallResults, + OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseOutputMessageMCPListTools, + OpenAIResponseOutputMessageWebSearchToolCall, + OpenAIResponseText, + OpenAIResponseTextFormat, + WebSearchToolTypes, +) +from llama_stack.apis.common.content_types import TextContentItem +from llama_stack.apis.inference import ( + Inference, + OpenAIAssistantMessageParam, + OpenAIChatCompletion, + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartParam, + OpenAIChatCompletionContentPartTextParam, + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + OpenAIChoice, + OpenAIDeveloperMessageParam, + OpenAIImageURL, + OpenAIJSONSchema, + OpenAIMessageParam, + OpenAIResponseFormatJSONObject, + OpenAIResponseFormatJSONSchema, + OpenAIResponseFormatParam, + OpenAIResponseFormatText, + OpenAISystemMessageParam, + OpenAIToolMessageParam, + OpenAIUserMessageParam, +) +from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime +from llama_stack.apis.vector_io import VectorIO +from llama_stack.log import get_logger +from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition +from llama_stack.providers.utils.inference.openai_compat import ( + convert_tooldef_to_openai_tool, +) +from llama_stack.providers.utils.responses.responses_store import ResponsesStore + +logger = get_logger(name=__name__, category="agents") + +OPENAI_RESPONSES_PREFIX = "openai_responses:" + + +class ToolExecutionResult(BaseModel): + """Result of streaming tool execution.""" + + stream_event: OpenAIResponseObjectStream | None = None + sequence_number: int + final_output_message: OpenAIResponseOutput | None = None + final_input_message: OpenAIMessageParam | None = None + + +async def _convert_response_content_to_chat_content( + content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]), +) -> str | list[OpenAIChatCompletionContentPartParam]: + """ + Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts. + + The content schemas of each API look similar, but are not exactly the same. + """ + if isinstance(content, str): + return content + + converted_parts = [] + for content_part in content: + if isinstance(content_part, OpenAIResponseInputMessageContentText): + converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text)) + elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText): + converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text)) + elif isinstance(content_part, OpenAIResponseInputMessageContentImage): + if content_part.image_url: + image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail) + converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url)) + elif isinstance(content_part, str): + converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part)) + else: + raise ValueError( + f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context" + ) + return converted_parts + + +async def _convert_response_input_to_chat_messages( + input: str | list[OpenAIResponseInput], +) -> list[OpenAIMessageParam]: + """ + Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages. + """ + messages: list[OpenAIMessageParam] = [] + if isinstance(input, list): + for input_item in input: + if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput): + messages.append( + OpenAIToolMessageParam( + content=input_item.output, + tool_call_id=input_item.call_id, + ) + ) + elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall): + tool_call = OpenAIChatCompletionToolCall( + index=0, + id=input_item.call_id, + function=OpenAIChatCompletionToolCallFunction( + name=input_item.name, + arguments=input_item.arguments, + ), + ) + messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) + else: + content = await _convert_response_content_to_chat_content(input_item.content) + message_type = await _get_message_type_by_role(input_item.role) + if message_type is None: + raise ValueError( + f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context" + ) + messages.append(message_type(content=content)) + else: + messages.append(OpenAIUserMessageParam(content=input)) + return messages + + +async def _convert_chat_choice_to_response_message( + choice: OpenAIChoice, +) -> OpenAIResponseMessage: + """ + Convert an OpenAI Chat Completion choice into an OpenAI Response output message. + """ + output_content = "" + if isinstance(choice.message.content, str): + output_content = choice.message.content + elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam): + output_content = choice.message.content.text + else: + raise ValueError( + f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}" + ) + + return OpenAIResponseMessage( + id=f"msg_{uuid.uuid4()}", + content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)], + status="completed", + role="assistant", + ) + + +async def _convert_response_text_to_chat_response_format( + text: OpenAIResponseText, +) -> OpenAIResponseFormatParam: + """ + Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format. + """ + if not text.format or text.format["type"] == "text": + return OpenAIResponseFormatText(type="text") + if text.format["type"] == "json_object": + return OpenAIResponseFormatJSONObject() + if text.format["type"] == "json_schema": + return OpenAIResponseFormatJSONSchema( + json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"]) + ) + raise ValueError(f"Unsupported text format: {text.format}") + + +async def _get_message_type_by_role(role: str): + role_to_type = { + "user": OpenAIUserMessageParam, + "system": OpenAISystemMessageParam, + "assistant": OpenAIAssistantMessageParam, + "developer": OpenAIDeveloperMessageParam, + } + return role_to_type.get(role) + + +class OpenAIResponsePreviousResponseWithInputItems(BaseModel): + input_items: ListOpenAIResponseInputItem + response: OpenAIResponseObject + + +class ChatCompletionContext(BaseModel): + model: str + messages: list[OpenAIMessageParam] + response_tools: list[OpenAIResponseInputTool] | None = None + chat_tools: list[ChatCompletionToolParam] | None = None + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] + temperature: float | None + response_format: OpenAIResponseFormatParam + + +class OpenAIResponsesImpl: + def __init__( + self, + inference_api: Inference, + tool_groups_api: ToolGroups, + tool_runtime_api: ToolRuntime, + responses_store: ResponsesStore, + vector_io_api: VectorIO, # VectorIO + ): + self.inference_api = inference_api + self.tool_groups_api = tool_groups_api + self.tool_runtime_api = tool_runtime_api + self.responses_store = responses_store + self.vector_io_api = vector_io_api + + async def _prepend_previous_response( + self, + input: str | list[OpenAIResponseInput], + previous_response_id: str | None = None, + ): + if previous_response_id: + previous_response_with_input = await self.responses_store.get_response_object(previous_response_id) + + # previous response input items + new_input_items = previous_response_with_input.input + + # previous response output items + new_input_items.extend(previous_response_with_input.output) + + # new input items from the current request + if isinstance(input, str): + new_input_items.append(OpenAIResponseMessage(content=input, role="user")) + else: + new_input_items.extend(input) + + input = new_input_items + + return input + + async def _prepend_instructions(self, messages, instructions): + if instructions: + messages.insert(0, OpenAISystemMessageParam(content=instructions)) + + async def get_openai_response( + self, + response_id: str, + ) -> OpenAIResponseObject: + response_with_input = await self.responses_store.get_response_object(response_id) + return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"}) + + async def list_openai_responses( + self, + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseObject: + return await self.responses_store.list_responses(after, limit, model, order) + + async def list_openai_response_input_items( + self, + response_id: str, + after: str | None = None, + before: str | None = None, + include: list[str] | None = None, + limit: int | None = 20, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseInputItem: + """List input items for a given OpenAI response. + + :param response_id: The ID of the response to retrieve input items for. + :param after: An item ID to list items after, used for pagination. + :param before: An item ID to list items before, used for pagination. + :param include: Additional fields to include in the response. + :param limit: A limit on the number of objects to be returned. + :param order: The order to return the input items in. + :returns: An ListOpenAIResponseInputItem. + """ + return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order) + + async def _store_response( + self, + response: OpenAIResponseObject, + input: str | list[OpenAIResponseInput], + ) -> None: + new_input_id = f"msg_{uuid.uuid4()}" + if isinstance(input, str): + # synthesize a message from the input string + input_content = OpenAIResponseInputMessageContentText(text=input) + input_content_item = OpenAIResponseMessage( + role="user", + content=[input_content], + id=new_input_id, + ) + input_items_data = [input_content_item] + else: + # we already have a list of messages + input_items_data = [] + for input_item in input: + if isinstance(input_item, OpenAIResponseMessage): + # These may or may not already have an id, so dump to dict, check for id, and add if missing + input_item_dict = input_item.model_dump() + if "id" not in input_item_dict: + input_item_dict["id"] = new_input_id + input_items_data.append(OpenAIResponseMessage(**input_item_dict)) + else: + input_items_data.append(input_item) + + await self.responses_store.store_response_object( + response_object=response, + input=input_items_data, + ) + + async def create_openai_response( + self, + input: str | list[OpenAIResponseInput], + model: str, + instructions: str | None = None, + previous_response_id: str | None = None, + store: bool | None = True, + stream: bool | None = False, + temperature: float | None = None, + text: OpenAIResponseText | None = None, + tools: list[OpenAIResponseInputTool] | None = None, + include: list[str] | None = None, + max_infer_iters: int | None = 10, + ): + stream = bool(stream) + text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text + + stream_gen = self._create_streaming_response( + input=input, + model=model, + instructions=instructions, + previous_response_id=previous_response_id, + store=store, + temperature=temperature, + text=text, + tools=tools, + max_infer_iters=max_infer_iters, + ) + + if stream: + return stream_gen + else: + response = None + async for stream_chunk in stream_gen: + if stream_chunk.type == "response.completed": + if response is not None: + raise ValueError("The response stream completed multiple times! Earlier response: {response}") + response = stream_chunk.response + # don't leave the generator half complete! + + if response is None: + raise ValueError("The response stream never completed") + return response + + async def _create_streaming_response( + self, + input: str | list[OpenAIResponseInput], + model: str, + instructions: str | None = None, + previous_response_id: str | None = None, + store: bool | None = True, + temperature: float | None = None, + text: OpenAIResponseText | None = None, + tools: list[OpenAIResponseInputTool] | None = None, + max_infer_iters: int | None = 10, + ) -> AsyncIterator[OpenAIResponseObjectStream]: + output_messages: list[OpenAIResponseOutput] = [] + + # Input preprocessing + input = await self._prepend_previous_response(input, previous_response_id) + messages = await _convert_response_input_to_chat_messages(input) + await self._prepend_instructions(messages, instructions) + + # Structured outputs + response_format = await _convert_response_text_to_chat_response_format(text) + + # Tool setup, TODO: refactor this slightly since this can also yield events + chat_tools, mcp_tool_to_server, mcp_list_message = ( + await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None) + ) + if mcp_list_message: + output_messages.append(mcp_list_message) + + ctx = ChatCompletionContext( + model=model, + messages=messages, + response_tools=tools, + chat_tools=chat_tools, + mcp_tool_to_server=mcp_tool_to_server, + temperature=temperature, + response_format=response_format, + ) + + # Create initial response and emit response.created immediately + response_id = f"resp-{uuid.uuid4()}" + created_at = int(time.time()) + + initial_response = OpenAIResponseObject( + created_at=created_at, + id=response_id, + model=model, + object="response", + status="in_progress", + output=output_messages.copy(), + text=text, + ) + + yield OpenAIResponseObjectStreamResponseCreated(response=initial_response) + + n_iter = 0 + messages = ctx.messages.copy() + + while True: + completion_result = await self.inference_api.openai_chat_completion( + model=ctx.model, + messages=messages, + tools=ctx.chat_tools, + stream=True, + temperature=ctx.temperature, + response_format=ctx.response_format, + ) + + # Process streaming chunks and build complete response + chat_response_id = "" + chat_response_content = [] + chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {} + chunk_created = 0 + chunk_model = "" + chunk_finish_reason = "" + sequence_number = 0 + + # Create a placeholder message item for delta events + message_item_id = f"msg_{uuid.uuid4()}" + # Track tool call items for streaming events + tool_call_item_ids: dict[int, str] = {} + # Track content parts for streaming events + content_part_emitted = False + + async for chunk in completion_result: + chat_response_id = chunk.id + chunk_created = chunk.created + chunk_model = chunk.model + for chunk_choice in chunk.choices: + # Emit incremental text content as delta events + if chunk_choice.delta.content: + # Emit content_part.added event for first text chunk + if not content_part_emitted: + content_part_emitted = True + sequence_number += 1 + yield OpenAIResponseObjectStreamResponseContentPartAdded( + response_id=response_id, + item_id=message_item_id, + part=OpenAIResponseContentPartOutputText( + text="", # Will be filled incrementally via text deltas + ), + sequence_number=sequence_number, + ) + sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputTextDelta( + content_index=0, + delta=chunk_choice.delta.content, + item_id=message_item_id, + output_index=0, + sequence_number=sequence_number, + ) + + # Collect content for final response + chat_response_content.append(chunk_choice.delta.content or "") + if chunk_choice.finish_reason: + chunk_finish_reason = chunk_choice.finish_reason + + # Aggregate tool call arguments across chunks + if chunk_choice.delta.tool_calls: + for tool_call in chunk_choice.delta.tool_calls: + response_tool_call = chat_response_tool_calls.get(tool_call.index, None) + # Create new tool call entry if this is the first chunk for this index + is_new_tool_call = response_tool_call is None + if is_new_tool_call: + tool_call_dict: dict[str, Any] = tool_call.model_dump() + tool_call_dict.pop("type", None) + response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict) + chat_response_tool_calls[tool_call.index] = response_tool_call + + # Create item ID for this tool call for streaming events + tool_call_item_id = f"fc_{uuid.uuid4()}" + tool_call_item_ids[tool_call.index] = tool_call_item_id + + # Emit output_item.added event for the new function call + sequence_number += 1 + function_call_item = OpenAIResponseOutputMessageFunctionToolCall( + arguments="", # Will be filled incrementally via delta events + call_id=tool_call.id or "", + name=tool_call.function.name if tool_call.function else "", + id=tool_call_item_id, + status="in_progress", + ) + yield OpenAIResponseObjectStreamResponseOutputItemAdded( + response_id=response_id, + item=function_call_item, + output_index=len(output_messages), + sequence_number=sequence_number, + ) + + # Stream tool call arguments as they arrive (differentiate between MCP and function calls) + if tool_call.function and tool_call.function.arguments: + tool_call_item_id = tool_call_item_ids[tool_call.index] + sequence_number += 1 + + # Check if this is an MCP tool call + is_mcp_tool = ( + ctx.mcp_tool_to_server + and tool_call.function.name + and tool_call.function.name in ctx.mcp_tool_to_server + ) + if is_mcp_tool: + # Emit MCP-specific argument delta event + yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta( + delta=tool_call.function.arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=sequence_number, + ) + else: + # Emit function call argument delta event + yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta( + delta=tool_call.function.arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=sequence_number, + ) + + # Accumulate arguments for final response (only for subsequent chunks) + if not is_new_tool_call: + response_tool_call.function.arguments = ( + response_tool_call.function.arguments or "" + ) + tool_call.function.arguments + + # Emit arguments.done events for completed tool calls (differentiate between MCP and function calls) + for tool_call_index in sorted(chat_response_tool_calls.keys()): + tool_call_item_id = tool_call_item_ids[tool_call_index] + final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or "" + tool_call_name = chat_response_tool_calls[tool_call_index].function.name + + # Check if this is an MCP tool call + is_mcp_tool = ctx.mcp_tool_to_server and tool_call_name and tool_call_name in ctx.mcp_tool_to_server + sequence_number += 1 + if is_mcp_tool: + # Emit MCP-specific argument done event + yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDone( + arguments=final_arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=sequence_number, + ) + else: + # Emit function call argument done event + yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone( + arguments=final_arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=sequence_number, + ) + + # Convert collected chunks to complete response + if chat_response_tool_calls: + tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())] + else: + tool_calls = None + + # Emit content_part.done event if text content was streamed (before content gets cleared) + if content_part_emitted: + final_text = "".join(chat_response_content) + sequence_number += 1 + yield OpenAIResponseObjectStreamResponseContentPartDone( + response_id=response_id, + item_id=message_item_id, + part=OpenAIResponseContentPartOutputText( + text=final_text, + ), + sequence_number=sequence_number, + ) + + # Clear content when there are tool calls (OpenAI spec behavior) + if chat_response_tool_calls: + chat_response_content = [] + + assistant_message = OpenAIAssistantMessageParam( + content="".join(chat_response_content), + tool_calls=tool_calls, + ) + current_response = OpenAIChatCompletion( + id=chat_response_id, + choices=[ + OpenAIChoice( + message=assistant_message, + finish_reason=chunk_finish_reason, + index=0, + ) + ], + created=chunk_created, + model=chunk_model, + ) + + function_tool_calls = [] + non_function_tool_calls = [] + + next_turn_messages = messages.copy() + for choice in current_response.choices: + next_turn_messages.append(choice.message) + + if choice.message.tool_calls and tools: + for tool_call in choice.message.tool_calls: + if _is_function_tool_call(tool_call, tools): + function_tool_calls.append(tool_call) + else: + non_function_tool_calls.append(tool_call) + else: + output_messages.append(await _convert_chat_choice_to_response_message(choice)) + + # execute non-function tool calls + for tool_call in non_function_tool_calls: + # Find the item_id for this tool call + matching_item_id = None + for index, item_id in tool_call_item_ids.items(): + response_tool_call = chat_response_tool_calls.get(index) + if response_tool_call and response_tool_call.id == tool_call.id: + matching_item_id = item_id + break + + # Use a fallback item_id if not found + if not matching_item_id: + matching_item_id = f"tc_{uuid.uuid4()}" + + # Execute tool call with streaming + tool_call_log = None + tool_response_message = None + async for result in self._execute_tool_call( + tool_call, ctx, sequence_number, response_id, len(output_messages), matching_item_id + ): + if result.stream_event: + # Forward streaming events + sequence_number = result.sequence_number + yield result.stream_event + + if result.final_output_message is not None: + tool_call_log = result.final_output_message + tool_response_message = result.final_input_message + sequence_number = result.sequence_number + + if tool_call_log: + output_messages.append(tool_call_log) + + # Emit output_item.done event for completed non-function tool call + if matching_item_id: + sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputItemDone( + response_id=response_id, + item=tool_call_log, + output_index=len(output_messages) - 1, + sequence_number=sequence_number, + ) + + if tool_response_message: + next_turn_messages.append(tool_response_message) + + for tool_call in function_tool_calls: + # Find the item_id for this tool call from our tracking dictionary + matching_item_id = None + for index, item_id in tool_call_item_ids.items(): + response_tool_call = chat_response_tool_calls.get(index) + if response_tool_call and response_tool_call.id == tool_call.id: + matching_item_id = item_id + break + + # Use existing item_id or create new one if not found + final_item_id = matching_item_id or f"fc_{uuid.uuid4()}" + + function_call_item = OpenAIResponseOutputMessageFunctionToolCall( + arguments=tool_call.function.arguments or "", + call_id=tool_call.id, + name=tool_call.function.name or "", + id=final_item_id, + status="completed", + ) + output_messages.append(function_call_item) + + # Emit output_item.done event for completed function call + sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputItemDone( + response_id=response_id, + item=function_call_item, + output_index=len(output_messages) - 1, + sequence_number=sequence_number, + ) + + if not function_tool_calls and not non_function_tool_calls: + break + + if function_tool_calls: + logger.info("Exiting inference loop since there is a function (client-side) tool call") + break + + n_iter += 1 + if n_iter >= max_infer_iters: + logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}") + break + + messages = next_turn_messages + + # Create final response + final_response = OpenAIResponseObject( + created_at=created_at, + id=response_id, + model=model, + object="response", + status="completed", + text=text, + output=output_messages, + ) + + # Emit response.completed + yield OpenAIResponseObjectStreamResponseCompleted(response=final_response) + + if store: + await self._store_response( + response=final_response, + input=input, + ) + + async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject: + return await self.responses_store.delete_response_object(response_id) + + async def _convert_response_tools_to_chat_tools( + self, tools: list[OpenAIResponseInputTool] + ) -> tuple[ + list[ChatCompletionToolParam], + dict[str, OpenAIResponseInputToolMCP], + OpenAIResponseOutput | None, + ]: + from llama_stack.apis.agents.openai_responses import ( + MCPListToolsTool, + ) + from llama_stack.apis.tools import Tool + + mcp_tool_to_server = {} + + def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam: + tool_def = ToolDefinition( + tool_name=tool_name, + description=tool.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + for param in tool.parameters + }, + ) + return convert_tooldef_to_openai_tool(tool_def) + + mcp_list_message = None + chat_tools: list[ChatCompletionToolParam] = [] + for input_tool in tools: + # TODO: Handle other tool types + if input_tool.type == "function": + chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump())) + elif input_tool.type in WebSearchToolTypes: + tool_name = "web_search" + tool = await self.tool_groups_api.get_tool(tool_name) + if not tool: + raise ValueError(f"Tool {tool_name} not found") + chat_tools.append(make_openai_tool(tool_name, tool)) + elif input_tool.type == "file_search": + tool_name = "knowledge_search" + tool = await self.tool_groups_api.get_tool(tool_name) + if not tool: + raise ValueError(f"Tool {tool_name} not found") + chat_tools.append(make_openai_tool(tool_name, tool)) + elif input_tool.type == "mcp": + from llama_stack.providers.utils.tools.mcp import list_mcp_tools + + always_allowed = None + never_allowed = None + if input_tool.allowed_tools: + if isinstance(input_tool.allowed_tools, list): + always_allowed = input_tool.allowed_tools + elif isinstance(input_tool.allowed_tools, AllowedToolsFilter): + always_allowed = input_tool.allowed_tools.always + never_allowed = input_tool.allowed_tools.never + + tool_defs = await list_mcp_tools( + endpoint=input_tool.server_url, + headers=input_tool.headers or {}, + ) + + mcp_list_message = OpenAIResponseOutputMessageMCPListTools( + id=f"mcp_list_{uuid.uuid4()}", + status="completed", + server_label=input_tool.server_label, + tools=[], + ) + for t in tool_defs.data: + if never_allowed and t.name in never_allowed: + continue + if not always_allowed or t.name in always_allowed: + chat_tools.append(make_openai_tool(t.name, t)) + if t.name in mcp_tool_to_server: + raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}") + mcp_tool_to_server[t.name] = input_tool + mcp_list_message.tools.append( + MCPListToolsTool( + name=t.name, + description=t.description, + input_schema={ + "type": "object", + "properties": { + p.name: { + "type": p.parameter_type, + "description": p.description, + } + for p in t.parameters + }, + "required": [p.name for p in t.parameters if p.required], + }, + ) + ) + else: + raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}") + return chat_tools, mcp_tool_to_server, mcp_list_message + + async def _execute_knowledge_search_via_vector_store( + self, + query: str, + response_file_search_tool: OpenAIResponseInputToolFileSearch, + ) -> ToolInvocationResult: + """Execute knowledge search using vector_stores.search API with filters support.""" + search_results = [] + + # Create search tasks for all vector stores + async def search_single_store(vector_store_id): + try: + search_response = await self.vector_io_api.openai_search_vector_store( + vector_store_id=vector_store_id, + query=query, + filters=response_file_search_tool.filters, + max_num_results=response_file_search_tool.max_num_results, + ranking_options=response_file_search_tool.ranking_options, + rewrite_query=False, + ) + return search_response.data + except Exception as e: + logger.warning(f"Failed to search vector store {vector_store_id}: {e}") + return [] + + # Run all searches in parallel using gather + search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids] + all_results = await asyncio.gather(*search_tasks) + + # Flatten results + for results in all_results: + search_results.extend(results) + + # Convert search results to tool result format matching memory.py + # Format the results as interleaved content similar to memory.py + content_items = [] + content_items.append( + TextContentItem( + text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n" + ) + ) + + for i, result_item in enumerate(search_results): + chunk_text = result_item.content[0].text if result_item.content else "" + metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}" + if result_item.attributes: + metadata_text += f", attributes: {result_item.attributes}" + text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n" + content_items.append(TextContentItem(text=text_content)) + + content_items.append(TextContentItem(text="END of knowledge_search tool results.\n")) + content_items.append( + TextContentItem( + text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n', + ) + ) + + return ToolInvocationResult( + content=content_items, + metadata={ + "document_ids": [r.file_id for r in search_results], + "chunks": [r.content[0].text if r.content else "" for r in search_results], + "scores": [r.score for r in search_results], + }, + ) + + async def _execute_tool_call( + self, + tool_call: OpenAIChatCompletionToolCall, + ctx: ChatCompletionContext, + sequence_number: int, + response_id: str, + output_index: int, + item_id: str, + ) -> AsyncIterator[ToolExecutionResult]: + from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, + ) + + tool_call_id = tool_call.id + function = tool_call.function + tool_kwargs = json.loads(function.arguments) if function.arguments else {} + + if not function or not tool_call_id or not function.name: + yield ToolExecutionResult(sequence_number=sequence_number) + return + + # Emit in_progress event based on tool type (only for tools with specific streaming events) + progress_event = None + if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server: + sequence_number += 1 + progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + elif function.name == "web_search": + sequence_number += 1 + progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + # Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec + + if progress_event: + yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number) + + # For web search, emit searching event + if function.name == "web_search": + sequence_number += 1 + searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number) + + # Execute the actual tool call + error_exc = None + result = None + try: + if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server: + from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool + + mcp_tool = ctx.mcp_tool_to_server[function.name] + result = await invoke_mcp_tool( + endpoint=mcp_tool.server_url, + headers=mcp_tool.headers or {}, + tool_name=function.name, + kwargs=tool_kwargs, + ) + elif function.name == "knowledge_search": + response_file_search_tool = next( + (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), + None, + ) + if response_file_search_tool: + # Use vector_stores.search API instead of knowledge_search tool + # to support filters and ranking_options + query = tool_kwargs.get("query", "") + result = await self._execute_knowledge_search_via_vector_store( + query=query, + response_file_search_tool=response_file_search_tool, + ) + else: + result = await self.tool_runtime_api.invoke_tool( + tool_name=function.name, + kwargs=tool_kwargs, + ) + except Exception as e: + error_exc = e + + # Emit completion or failure event based on result (only for tools with specific streaming events) + has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message)) + completion_event = None + + if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server: + sequence_number += 1 + if has_error: + completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed( + sequence_number=sequence_number, + ) + else: + completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted( + sequence_number=sequence_number, + ) + elif function.name == "web_search": + sequence_number += 1 + completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + # Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec + + if completion_event: + yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number) + + # Build the result message and input message + if function.name in ctx.mcp_tool_to_server: + from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseOutputMessageMCPCall, + ) + + message = OpenAIResponseOutputMessageMCPCall( + id=tool_call_id, + arguments=function.arguments, + name=function.name, + server_label=ctx.mcp_tool_to_server[function.name].server_label, + ) + if error_exc: + message.error = str(error_exc) + elif (result and result.error_code and result.error_code > 0) or (result and result.error_message): + message.error = f"Error (code {result.error_code}): {result.error_message}" + elif result and result.content: + message.output = interleaved_content_as_str(result.content) + else: + if function.name == "web_search": + message = OpenAIResponseOutputMessageWebSearchToolCall( + id=tool_call_id, + status="completed", + ) + if has_error: + message.status = "failed" + elif function.name == "knowledge_search": + message = OpenAIResponseOutputMessageFileSearchToolCall( + id=tool_call_id, + queries=[tool_kwargs.get("query", "")], + status="completed", + ) + if result and "document_ids" in result.metadata: + message.results = [] + for i, doc_id in enumerate(result.metadata["document_ids"]): + text = result.metadata["chunks"][i] if "chunks" in result.metadata else None + score = result.metadata["scores"][i] if "scores" in result.metadata else None + message.results.append( + OpenAIResponseOutputMessageFileSearchToolCallResults( + file_id=doc_id, + filename=doc_id, + text=text, + score=score, + attributes={}, + ) + ) + if has_error: + message.status = "failed" + else: + raise ValueError(f"Unknown tool {function.name} called") + + input_message = None + if result and result.content: + if isinstance(result.content, str): + content = result.content + elif isinstance(result.content, list): + from llama_stack.apis.common.content_types import ( + ImageContentItem, + TextContentItem, + ) + + content = [] + for item in result.content: + if isinstance(item, TextContentItem): + part = OpenAIChatCompletionContentPartTextParam(text=item.text) + elif isinstance(item, ImageContentItem): + if item.image.data: + url = f"data:image;base64,{item.image.data}" + else: + url = item.image.url + part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url)) + else: + raise ValueError(f"Unknown result content type: {type(item)}") + content.append(part) + else: + raise ValueError(f"Unknown result content type: {type(result.content)}") + input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id) + else: + text = str(error_exc) if error_exc else "Tool execution failed" + input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id) + + # Yield the final result + yield ToolExecutionResult( + sequence_number=sequence_number, final_output_message=message, final_input_message=input_message + ) + + +def _is_function_tool_call( + tool_call: OpenAIChatCompletionToolCall, + tools: list[OpenAIResponseInputTool], +) -> bool: + if not tool_call.function: + return False + for t in tools: + if t.type == "function" and t.name == tool_call.function.name: + return True + return False diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 0b234d96c..c19051f86 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import json -import logging import uuid from datetime import UTC, datetime @@ -15,9 +14,10 @@ from llama_stack.core.access_control.access_control import AccessDeniedError, is from llama_stack.core.access_control.datatypes import AccessRule from llama_stack.core.datatypes import User from llama_stack.core.request_headers import get_authenticated_user +from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="agents") class AgentSessionInfo(Session): diff --git a/llama_stack/providers/inline/agents/meta_reference/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py index 605f387b7..b8a5d8a95 100644 --- a/llama_stack/providers/inline/agents/meta_reference/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -5,13 +5,13 @@ # the root directory of this source tree. import asyncio -import logging from llama_stack.apis.inference import Message from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel +from llama_stack.log import get_logger from llama_stack.providers.utils.telemetry import tracing -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="agents") class SafetyException(Exception): # noqa: N818 diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 7ade75032..bb6a1bd03 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -12,7 +12,6 @@ import copy import json -import logging import multiprocessing import os import tempfile @@ -32,13 +31,14 @@ from fairscale.nn.model_parallel.initialize import ( from pydantic import BaseModel, Field from torch.distributed.launcher.api import LaunchConfig, elastic_launch +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import GenerationResult from llama_stack.providers.utils.inference.prompt_adapter import ( ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent, ) -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="inference") class ProcessingMessageName(str, Enum): diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index fea8a8189..600a5bd37 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging from collections.abc import AsyncGenerator from llama_stack.apis.inference import ( @@ -21,6 +20,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import ModelType +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, @@ -32,7 +32,7 @@ from llama_stack.providers.utils.inference.openai_compat import ( from .config import SentenceTransformersInferenceConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="inference") class SentenceTransformersInferenceImpl( diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py index 2574b995b..d9ee3d2a8 100644 --- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py +++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py @@ -6,7 +6,6 @@ import gc import json -import logging import multiprocessing from pathlib import Path from typing import Any @@ -28,6 +27,7 @@ from llama_stack.apis.post_training import ( LoraFinetuningConfig, TrainingConfig, ) +from llama_stack.log import get_logger from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from ..config import HuggingFacePostTrainingConfig @@ -44,7 +44,7 @@ from ..utils import ( split_dataset, ) -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="post_training") class HFFinetuningSingleDevice: diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py index a7c19faac..b39a24c66 100644 --- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py +++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import gc -import logging import multiprocessing from pathlib import Path from typing import Any @@ -24,6 +23,7 @@ from llama_stack.apis.post_training import ( DPOAlignmentConfig, TrainingConfig, ) +from llama_stack.log import get_logger from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from ..config import HuggingFacePostTrainingConfig @@ -40,7 +40,7 @@ from ..utils import ( split_dataset, ) -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="post_training") class HFDPOAlignmentSingleDevice: diff --git a/llama_stack/providers/inline/post_training/huggingface/utils.py b/llama_stack/providers/inline/post_training/huggingface/utils.py index 3147c19ab..f229c87dd 100644 --- a/llama_stack/providers/inline/post_training/huggingface/utils.py +++ b/llama_stack/providers/inline/post_training/huggingface/utils.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import os import signal import sys @@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.post_training import Checkpoint, TrainingConfig +from llama_stack.log import get_logger from .config import HuggingFacePostTrainingConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="post_training") def setup_environment(): diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 49e1c95b8..8b1462862 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import os import time from datetime import UTC, datetime @@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler from torchtune import modules, training from torchtune import utils as torchtune_utils from torchtune.data import padded_collate_sft +from torchtune.models.llama3._tokenizer import Llama3Tokenizer from torchtune.modules.loss import CEWithChunkedOutputLoss from torchtune.modules.peft import ( get_adapter_params, @@ -45,6 +45,7 @@ from llama_stack.apis.post_training import ( ) from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR from llama_stack.core.utils.model_utils import model_local_dir +from llama_stack.log import get_logger from llama_stack.models.llama.sku_list import resolve_model from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from llama_stack.providers.inline.post_training.torchtune.common import utils @@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import ( ) from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset -log = logging.getLogger(__name__) - -from torchtune.models.llama3._tokenizer import Llama3Tokenizer +log = get_logger(name=__name__, category="post_training") class LoraFinetuningSingleDevice: diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py index be05ee436..1b9397a4d 100644 --- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging from typing import Any from llama_stack.apis.inference import Message @@ -15,13 +14,14 @@ from llama_stack.apis.safety import ( ViolationLevel, ) from llama_stack.apis.shields import Shield +from llama_stack.log import get_logger from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) from .config import CodeScannerConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="safety") ALLOWED_CODE_SCANNER_MODEL_IDS = [ "CodeScanner", diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index bae744010..9ba72798d 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import re import uuid from string import Template @@ -25,6 +24,7 @@ from llama_stack.apis.safety import ( from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults from llama_stack.apis.shields import Shield from llama_stack.core.datatypes import Api +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import Role from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.datatypes import ShieldsProtocolPrivate @@ -137,6 +137,8 @@ Provide your safety assessment for ONLY THE LAST $agent_type message in the abov PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}") +logger = get_logger(name=__name__, category="safety") + class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): def __init__(self, config: LlamaGuardConfig, deps) -> None: @@ -412,7 +414,7 @@ class LlamaGuardShield: unsafe_code_list = [code.strip() for code in unsafe_code.split(",")] invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP] if invalid_codes: - logging.warning(f"Invalid safety codes returned: {invalid_codes}") + logger.warning(f"Invalid safety codes returned: {invalid_codes}") # just returning safe object, as we don't know what the invalid codes can map to return ModerationObject( id=f"modr-{uuid.uuid4()}", diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index c760f0fd1..6fb6c4407 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging from typing import Any import torch @@ -21,6 +20,7 @@ from llama_stack.apis.safety import ( from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.shields import Shield from llama_stack.core.utils.model_utils import model_local_dir +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, @@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import PromptGuardConfig, PromptGuardType -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="safety") PROMPT_GUARD_MODEL = "Prompt-Guard-86M" diff --git a/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py b/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py index b74c3826e..c9358101d 100644 --- a/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py +++ b/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py @@ -7,7 +7,6 @@ import collections import functools import json -import logging import random import re import string @@ -20,7 +19,9 @@ import nltk from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai from pythainlp.tokenize import word_tokenize as word_tokenize_thai -logger = logging.getLogger() +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="scoring") WORD_LIST = [ "western", diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index d99255c79..30710ec2a 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -4,13 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import threading from typing import Any from opentelemetry import metrics, trace - -logger = logging.getLogger(__name__) from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.metrics import MeterProvider @@ -40,6 +37,7 @@ from llama_stack.apis.telemetry import ( UnstructuredLogEvent, ) from llama_stack.core.datatypes import Api +from llama_stack.log import get_logger from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import ( ConsoleSpanProcessor, ) @@ -61,6 +59,8 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = { _global_lock = threading.Lock() _TRACER_PROVIDER = None +logger = get_logger(name=__name__, category="telemetry") + def is_tracing_enabled(tracer): with tracer.start_as_current_span("check_tracing") as span: diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 6a7c7885c..a1543457b 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio -import logging import secrets import string from typing import Any @@ -32,6 +31,7 @@ from llama_stack.apis.tools import ( ToolRuntime, ) from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.memory.vector_store import ( @@ -42,7 +42,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import RagToolRuntimeConfig from .context_retriever import generate_rag_query -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="tool_runtime") def make_random_string(length: int = 8): diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index af61da59b..258c6e7aa 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -8,7 +8,6 @@ import asyncio import base64 import io import json -import logging from typing import Any import faiss @@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ( HealthResponse, HealthStatus, @@ -40,7 +40,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import FaissVectorIOConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="vector_io") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::" diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index cc1982f3b..7cf163960 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio -import logging import re import sqlite3 import struct @@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore @@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import ( VectorDBWithIndex, ) -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="vector_io") # Specifying search mode is dependent on the VectorIO provider. VECTOR_SEARCH = "vector" diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py index 4857c6723..cfcfcbf90 100644 --- a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py +++ b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py @@ -3,15 +3,14 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging - +from llama_stack.log import get_logger from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .models import MODEL_ENTRIES -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="inference") class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 7bc3fd0c9..297fb5762 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import warnings from collections.abc import AsyncIterator @@ -33,6 +32,7 @@ from llama_stack.apis.inference import ( ToolChoice, ToolConfig, ) +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, @@ -54,7 +54,7 @@ from .openai_utils import ( ) from .utils import _is_nvidia_hosted -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="inference") class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): diff --git a/llama_stack/providers/remote/inference/nvidia/utils.py b/llama_stack/providers/remote/inference/nvidia/utils.py index 74019999e..790bbafd1 100644 --- a/llama_stack/providers/remote/inference/nvidia/utils.py +++ b/llama_stack/providers/remote/inference/nvidia/utils.py @@ -4,13 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging - import httpx +from llama_stack.log import get_logger + from . import NVIDIAConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="inference") def _is_nvidia_hosted(config: NVIDIAConfig) -> bool: diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index 865258559..1c72fa0bc 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -4,15 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging - +from llama_stack.log import get_logger from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import OpenAIConfig from .models import MODEL_ENTRIES -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="inference") # diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 323831845..9da961438 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -5,7 +5,6 @@ # the root directory of this source tree. -import logging from collections.abc import AsyncGenerator from huggingface_hub import AsyncInferenceClient, HfApi @@ -34,6 +33,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model +from llama_stack.log import get_logger from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( @@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="inference") def build_hf_repo_model_entries(): diff --git a/llama_stack/providers/remote/post_training/nvidia/utils.py b/llama_stack/providers/remote/post_training/nvidia/utils.py index d6e1016b2..9a6c3b53c 100644 --- a/llama_stack/providers/remote/post_training/nvidia/utils.py +++ b/llama_stack/providers/remote/post_training/nvidia/utils.py @@ -4,18 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import warnings from typing import Any from pydantic import BaseModel from llama_stack.apis.post_training import TrainingConfig +from llama_stack.log import get_logger from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig from .config import NvidiaPostTrainingConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="integration") def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None: diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index 1895e7507..1ca87ae3d 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import json -import logging from typing import Any from llama_stack.apis.inference import Message @@ -16,12 +15,13 @@ from llama_stack.apis.safety import ( ViolationLevel, ) from llama_stack.apis.shields import Shield +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.bedrock.client import create_bedrock_client from .config import BedrockSafetyConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="safety") class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py index 7f17b1cb6..0d8d8ba7a 100644 --- a/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging from typing import Any import requests @@ -12,12 +11,13 @@ import requests from llama_stack.apis.inference import Message from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel from llama_stack.apis.shields import Shield +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new from .config import NVIDIASafetyConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="safety") class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py index 6c7190afe..676ee7185 100644 --- a/llama_stack/providers/remote/safety/sambanova/sambanova.py +++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import json -import logging from typing import Any import litellm @@ -20,12 +19,13 @@ from llama_stack.apis.safety import ( ) from llama_stack.apis.shields import Shield from llama_stack.core.request_headers import NeedsRequestProviderData +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new from .config import SambaNovaSafetyConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="safety") CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 8f252711b..0047e6055 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio import json -import logging from typing import Any from urllib.parse import urlparse @@ -20,6 +19,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig from llama_stack.providers.utils.kvstore import kvstore_impl @@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="vector_io") ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 0eaae81b3..af918d0eb 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio -import logging import os from typing import Any @@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig from llama_stack.providers.utils.kvstore import kvstore_impl @@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="vector_io") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::" diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index d2a5d910b..e829c9e72 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging from typing import Any import psycopg2 @@ -22,6 +21,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore @@ -34,7 +34,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import PGVectorVectorIOConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="vector_io") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::" diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 018015780..8499ff997 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio -import logging import uuid from typing import Any @@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import ( VectorStoreChunkingStrategy, VectorStoreFileObject, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl @@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="vector_io") CHUNK_ID_KEY = "_chunk_id" # KV store prefixes for vector databases diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 966724848..ddf95317b 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import json -import logging from typing import Any import weaviate @@ -19,6 +18,7 @@ from llama_stack.apis.files.files import Files from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.core.request_headers import NeedsRequestProviderData +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore @@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti from .config import WeaviateVectorIOConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="vector_io") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::" diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 32e89f987..05886cdc8 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -5,10 +5,11 @@ # the root directory of this source tree. import base64 -import logging import struct from typing import TYPE_CHECKING +from llama_stack.log import get_logger + if TYPE_CHECKING: from sentence_transformers import SentenceTransformer @@ -27,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con EMBEDDING_MODELS = {} -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="inference") class SentenceTransformerEmbeddingMixin: diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 6297cc2ed..255ff56f3 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import base64 import json -import logging import struct import time import uuid @@ -116,6 +115,7 @@ from llama_stack.apis.inference import ( from llama_stack.apis.inference import ( OpenAIChoice as OpenAIChatCompletionChoice, ) +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( BuiltinTool, StopReason, @@ -128,7 +128,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( decode_assistant_message, ) -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="inference") class OpenAICompatCompletionChoiceDelta(BaseModel): diff --git a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py index 3842773d9..af52f3708 100644 --- a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py +++ b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py @@ -4,16 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging from datetime import datetime from pymongo import AsyncMongoClient +from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore from ..config import MongoDBKVStoreConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="kvstore") class MongoDBKVStoreImpl(KVStore): diff --git a/llama_stack/providers/utils/kvstore/postgres/postgres.py b/llama_stack/providers/utils/kvstore/postgres/postgres.py index bd35decfc..a83257175 100644 --- a/llama_stack/providers/utils/kvstore/postgres/postgres.py +++ b/llama_stack/providers/utils/kvstore/postgres/postgres.py @@ -4,16 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging from datetime import datetime import psycopg2 from psycopg2.extras import DictCursor +from llama_stack.log import get_logger + from ..api import KVStore from ..config import PostgresKVStoreConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="kvstore") class PostgresKVStoreImpl(KVStore): diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 120d0d4fc..0775b31d1 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -44,7 +44,7 @@ from llama_stack.providers.utils.memory.vector_store import ( make_overlapped_chunks, ) -logger = get_logger(__name__, category="vector_io") +logger = get_logger(name=__name__, category="memory") # Constants for OpenAI vector stores CHUNK_MULTIPLIER = 5 diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 6ae5bb521..b5d82432d 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import base64 import io -import logging import re import time from abc import ABC, abstractmethod @@ -26,6 +25,7 @@ from llama_stack.apis.common.content_types import ( from llama_stack.apis.tools import RAGDocument from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse +from llama_stack.log import get_logger from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import Api from llama_stack.providers.utils.inference.prompt_adapter import ( @@ -33,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( ) from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="memory") class ChunkForDeletion(BaseModel): diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 7080e774a..7694003b5 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -6,7 +6,7 @@ import asyncio import contextvars -import logging +import logging # allow-direct-logging import queue import random import sys diff --git a/tests/integration/post_training/test_post_training.py b/tests/integration/post_training/test_post_training.py index f9c797593..b5be71c7c 100644 --- a/tests/integration/post_training/test_post_training.py +++ b/tests/integration/post_training/test_post_training.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import sys import time import uuid @@ -19,10 +18,10 @@ from llama_stack.apis.post_training import ( LoraFinetuningConfig, TrainingConfig, ) +from llama_stack.log import get_logger # Configure logging -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True) -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="post_training") skip_because_resource_intensive = pytest.mark.skip( diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 7ccca9077..dff36182c 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import time from io import BytesIO @@ -14,8 +13,9 @@ from openai import BadRequestError as OpenAIBadRequestError from llama_stack.apis.vector_io import Chunk from llama_stack.core.library_client import LlamaStackAsLibraryClient +from llama_stack.log import get_logger -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="vector_io") def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models): diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py index 5c2ad03ab..ce0e930b1 100644 --- a/tests/unit/providers/inference/test_remote_vllm.py +++ b/tests/unit/providers/inference/test_remote_vllm.py @@ -6,7 +6,7 @@ import asyncio import json -import logging +import logging # allow-direct-logging import threading import time from http.server import BaseHTTPRequestHandler, HTTPServer