From 43adc23ef641d84d183a47afa8a8653fc092f9f7 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Mon, 10 Nov 2025 18:29:24 -0500 Subject: [PATCH 1/2] refactor: remove dead inference API code and clean up imports (#4093) # What does this PR do? Delete ~2,000 lines of dead code from the old bespoke inference API that was replaced by the OpenAI-only API. This includes removing unused type conversion functions, dead provider methods, and event_logger.py. Clean up imports across the codebase to remove references to deleted types. This eliminates unnecessary code and dependencies, helping isolate the API package as a self-contained module. This was the last interdependency between the .api package and "exterior" packages, meaning that every other package in Llama Stack now imports the API, not the other way around. ## Test Plan This is a structural change; no tests are needed. --------- Signed-off-by: Charlie Doern --- .../apis/inference/event_logger.py | 43 - src/llama_stack/apis/inference/inference.py | 182 +--- src/llama_stack/core/routers/safety.py | 4 +- .../models/llama/llama3/generation.py | 4 +- .../models/llama/llama3/interface.py | 7 +- .../llama3/prompt_templates/system_prompts.py | 2 +- .../models/llama/llama3/tool_utils.py | 3 +- .../llama4/prompt_templates/system_prompts.py | 2 +- .../inference/meta_reference/generators.py | 98 +- .../inference/meta_reference/inference.py | 395 ++++++- .../meta_reference/model_parallel.py | 33 +- .../meta_reference/parallel_utils.py | 14 +- .../sentence_transformers.py | 4 - .../utils/inference/litellm_openai_mixin.py | 51 - .../utils/inference/openai_compat.py | 971 +----------------- .../utils/inference/prompt_adapter.py | 287 +----- tests/unit/models/test_prompt_adapter.py | 303 ------ .../providers/inline/inference/__init__.py | 5 + .../inline/inference/test_meta_reference.py | 44 + tests/unit/providers/nvidia/test_safety.py | 27 +- .../utils/inference/test_openai_compat.py | 220 ---- .../utils/inference/test_prompt_adapter.py | 35 + 22 files changed, 593 insertions(+), 2141 deletions(-) delete mode 100644 src/llama_stack/apis/inference/event_logger.py delete mode 100644 tests/unit/models/test_prompt_adapter.py create mode 100644 tests/unit/providers/inline/inference/__init__.py create mode 100644 tests/unit/providers/inline/inference/test_meta_reference.py delete mode 100644 tests/unit/providers/utils/inference/test_openai_compat.py create mode 100644 tests/unit/providers/utils/inference/test_prompt_adapter.py diff --git a/src/llama_stack/apis/inference/event_logger.py b/src/llama_stack/apis/inference/event_logger.py deleted file mode 100644 index d97ece6d4..000000000 --- a/src/llama_stack/apis/inference/event_logger.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree.
- -from termcolor import cprint - -from llama_stack.apis.inference import ( - ChatCompletionResponseEventType, - ChatCompletionResponseStreamChunk, -) - - -class LogEvent: - def __init__( - self, - content: str = "", - end: str = "\n", - color="white", - ): - self.content = content - self.color = color - self.end = "\n" if end is None else end - - def print(self, flush=True): - cprint(f"{self.content}", color=self.color, end=self.end, flush=flush) - - -class EventLogger: - async def log(self, event_generator): - async for chunk in event_generator: - if isinstance(chunk, ChatCompletionResponseStreamChunk): - event = chunk.event - if event.event_type == ChatCompletionResponseEventType.start: - yield LogEvent("Assistant> ", color="cyan", end="") - elif event.event_type == ChatCompletionResponseEventType.progress: - yield LogEvent(event.delta, color="yellow", end="") - elif event.event_type == ChatCompletionResponseEventType.complete: - yield LogEvent("") - else: - yield LogEvent("Assistant> ", color="cyan", end="") - yield LogEvent(chunk.completion_message.content, color="yellow") diff --git a/src/llama_stack/apis/inference/inference.py b/src/llama_stack/apis/inference/inference.py index 1a865ce5f..9f04917c9 100644 --- a/src/llama_stack/apis/inference/inference.py +++ b/src/llama_stack/apis/inference/inference.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from collections.abc import AsyncIterator -from enum import Enum +from enum import Enum, StrEnum from typing import ( Annotated, Any, @@ -15,28 +15,18 @@ from typing import ( ) from fastapi import Body -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from typing_extensions import TypedDict -from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent -from llama_stack.apis.common.responses import MetricResponseMixin, Order +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.responses import ( + Order, +) from llama_stack.apis.common.tracing import telemetry_traceable from llama_stack.apis.models import Model from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA -from llama_stack.models.llama.datatypes import ( - BuiltinTool, - StopReason, - ToolCall, - ToolDefinition, - ToolPromptFormat, -) from llama_stack.schema_utils import json_schema_type, register_schema, webmethod -register_schema(ToolCall) -register_schema(ToolDefinition) - -from enum import StrEnum - @json_schema_type class GreedySamplingStrategy(BaseModel): @@ -201,58 +191,6 @@ class ToolResponseMessage(BaseModel): content: InterleavedContent -@json_schema_type -class CompletionMessage(BaseModel): - """A message containing the model's (assistant) response in a chat conversation. - - :param role: Must be "assistant" to identify this as the model's response - :param content: The content of the model's response - :param stop_reason: Reason why the model stopped generating. Options are: - - `StopReason.end_of_turn`: The model finished generating the entire response. - - `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response. - - `StopReason.out_of_tokens`: The model ran out of token budget. - :param tool_calls: List of tool calls. Each tool call is a ToolCall object. 
- """ - - role: Literal["assistant"] = "assistant" - content: InterleavedContent - stop_reason: StopReason - tool_calls: list[ToolCall] | None = Field(default_factory=lambda: []) - - -Message = Annotated[ - UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage, - Field(discriminator="role"), -] -register_schema(Message, name="Message") - - -@json_schema_type -class ToolResponse(BaseModel): - """Response from a tool invocation. - - :param call_id: Unique identifier for the tool call this response is for - :param tool_name: Name of the tool that was invoked - :param content: The response content from the tool - :param metadata: (Optional) Additional metadata about the tool response - """ - - call_id: str - tool_name: BuiltinTool | str - content: InterleavedContent - metadata: dict[str, Any] | None = None - - @field_validator("tool_name", mode="before") - @classmethod - def validate_field(cls, v): - if isinstance(v, str): - try: - return BuiltinTool(v) - except ValueError: - return v - return v - - class ToolChoice(Enum): """Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model. @@ -289,22 +227,6 @@ class ChatCompletionResponseEventType(Enum): progress = "progress" -@json_schema_type -class ChatCompletionResponseEvent(BaseModel): - """An event during chat completion generation. - - :param event_type: Type of the event - :param delta: Content generated since last event. This can be one or more tokens, or a tool call. - :param logprobs: Optional log probabilities for generated tokens - :param stop_reason: Optional reason why generation stopped, if complete - """ - - event_type: ChatCompletionResponseEventType - delta: ContentDelta - logprobs: list[TokenLogProbs] | None = None - stop_reason: StopReason | None = None - - class ResponseFormatType(StrEnum): """Types of formats for structured (guided) decoding. @@ -357,34 +279,6 @@ class CompletionRequest(BaseModel): logprobs: LogProbConfig | None = None -@json_schema_type -class CompletionResponse(MetricResponseMixin): - """Response from a completion request. - - :param content: The generated completion text - :param stop_reason: Reason why generation stopped - :param logprobs: Optional log probabilities for generated tokens - """ - - content: str - stop_reason: StopReason - logprobs: list[TokenLogProbs] | None = None - - -@json_schema_type -class CompletionResponseStreamChunk(MetricResponseMixin): - """A chunk of a streamed completion response. - - :param delta: New content generated since last chunk. This can be one or more tokens. - :param stop_reason: Optional reason why generation stopped, if complete - :param logprobs: Optional log probabilities for generated tokens - """ - - delta: str - stop_reason: StopReason | None = None - logprobs: list[TokenLogProbs] | None = None - - class SystemMessageBehavior(Enum): """Config for how to override the default system prompt. @@ -398,70 +292,6 @@ class SystemMessageBehavior(Enum): replace = "replace" -@json_schema_type -class ToolConfig(BaseModel): - """Configuration for tool use. - - :param tool_choice: (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto. - :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. 
- - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls. - :param system_message_behavior: (Optional) Config for how to override the default system prompt. - - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string - '{{function_definitions}}' to indicate where the function definitions should be inserted. - """ - - tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto) - tool_prompt_format: ToolPromptFormat | None = Field(default=None) - system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append) - - def model_post_init(self, __context: Any) -> None: - if isinstance(self.tool_choice, str): - try: - self.tool_choice = ToolChoice[self.tool_choice] - except KeyError: - pass - - -# This is an internally used class -@json_schema_type -class ChatCompletionRequest(BaseModel): - model: str - messages: list[Message] - sampling_params: SamplingParams | None = Field(default_factory=SamplingParams) - - tools: list[ToolDefinition] | None = Field(default_factory=lambda: []) - tool_config: ToolConfig | None = Field(default_factory=ToolConfig) - - response_format: ResponseFormat | None = None - stream: bool | None = False - logprobs: LogProbConfig | None = None - - -@json_schema_type -class ChatCompletionResponseStreamChunk(MetricResponseMixin): - """A chunk of a streamed chat completion response. - - :param event: The event containing the new content - """ - - event: ChatCompletionResponseEvent - - -@json_schema_type -class ChatCompletionResponse(MetricResponseMixin): - """Response from a chat completion request. - - :param completion_message: The complete response message - :param logprobs: Optional log probabilities for generated tokens - """ - - completion_message: CompletionMessage - logprobs: list[TokenLogProbs] | None = None - - @json_schema_type class EmbeddingsResponse(BaseModel): """Response containing generated embeddings. 
diff --git a/src/llama_stack/core/routers/safety.py b/src/llama_stack/core/routers/safety.py index 79eac8b46..e5ff2ada9 100644 --- a/src/llama_stack/core/routers/safety.py +++ b/src/llama_stack/core/routers/safety.py @@ -6,7 +6,7 @@ from typing import Any -from llama_stack.apis.inference import Message +from llama_stack.apis.inference import OpenAIMessageParam from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.shields import Shield @@ -52,7 +52,7 @@ class SafetyRouter(Safety): async def run_shield( self, shield_id: str, - messages: list[Message], + messages: list[OpenAIMessageParam], params: dict[str, Any] = None, ) -> RunShieldResponse: logger.debug(f"SafetyRouter.run_shield: {shield_id}") diff --git a/src/llama_stack/models/llama/llama3/generation.py b/src/llama_stack/models/llama/llama3/generation.py index fe7be5ea9..9ac215c3b 100644 --- a/src/llama_stack/models/llama/llama3/generation.py +++ b/src/llama_stack/models/llama/llama3/generation.py @@ -26,8 +26,10 @@ from fairscale.nn.model_parallel.initialize import ( ) from termcolor import cprint +from llama_stack.models.llama.datatypes import ToolPromptFormat + from ..checkpoint import maybe_reshard_state_dict -from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat +from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage from .args import ModelArgs from .chat_format import ChatFormat, LLMInput from .model import Transformer diff --git a/src/llama_stack/models/llama/llama3/interface.py b/src/llama_stack/models/llama/llama3/interface.py index b63ba4847..89be31a55 100644 --- a/src/llama_stack/models/llama/llama3/interface.py +++ b/src/llama_stack/models/llama/llama3/interface.py @@ -15,13 +15,10 @@ from pathlib import Path from termcolor import colored +from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall, ToolDefinition, ToolPromptFormat + from ..datatypes import ( - BuiltinTool, RawMessage, - StopReason, - ToolCall, - ToolDefinition, - ToolPromptFormat, ) from . 
import template_data from .chat_format import ChatFormat diff --git a/src/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py b/src/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py index 11a5993e9..3fbaa103e 100644 --- a/src/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py +++ b/src/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py @@ -15,7 +15,7 @@ import textwrap from datetime import datetime from typing import Any -from llama_stack.apis.inference import ( +from llama_stack.models.llama.datatypes import ( BuiltinTool, ToolDefinition, ) diff --git a/src/llama_stack/models/llama/llama3/tool_utils.py b/src/llama_stack/models/llama/llama3/tool_utils.py index 8c12fe680..6f919e1fa 100644 --- a/src/llama_stack/models/llama/llama3/tool_utils.py +++ b/src/llama_stack/models/llama/llama3/tool_utils.py @@ -8,8 +8,9 @@ import json import re from llama_stack.log import get_logger +from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolPromptFormat -from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat +from ..datatypes import RecursiveType logger = get_logger(name=__name__, category="models::llama") diff --git a/src/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py b/src/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py index 1ee570933..feded9f8c 100644 --- a/src/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py +++ b/src/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py @@ -13,7 +13,7 @@ import textwrap -from llama_stack.apis.inference import ToolDefinition +from llama_stack.models.llama.datatypes import ToolDefinition from llama_stack.models.llama.llama3.prompt_templates.base import ( PromptTemplate, PromptTemplateGeneratorBase, diff --git a/src/llama_stack/providers/inline/inference/meta_reference/generators.py b/src/llama_stack/providers/inline/inference/meta_reference/generators.py index cb926f529..51a2ddfad 100644 --- a/src/llama_stack/providers/inline/inference/meta_reference/generators.py +++ b/src/llama_stack/providers/inline/inference/meta_reference/generators.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
import math -from collections.abc import Generator from typing import Optional import torch @@ -14,21 +13,19 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken from llama_stack.apis.inference import ( GreedySamplingStrategy, JsonSchemaResponseFormat, + OpenAIChatCompletionRequestWithExtraBody, + OpenAIResponseFormatJSONSchema, ResponseFormat, + ResponseFormatType, SamplingParams, TopPSamplingStrategy, ) -from llama_stack.models.llama.datatypes import QuantizationMode +from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat from llama_stack.models.llama.llama3.generation import Llama3 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama4.generation import Llama4 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer from llama_stack.models.llama.sku_types import Model, ModelFamily -from llama_stack.providers.utils.inference.prompt_adapter import ( - ChatCompletionRequestWithRawContent, - CompletionRequestWithRawContent, - get_default_tool_prompt_format, -) from .common import model_checkpoint_dir from .config import MetaReferenceInferenceConfig @@ -106,14 +103,6 @@ def _infer_sampling_params(sampling_params: SamplingParams): return temperature, top_p -def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent): - tool_config = request.tool_config - if tool_config is not None and tool_config.tool_prompt_format is not None: - return tool_config.tool_prompt_format - else: - return get_default_tool_prompt_format(request.model) - - class LlamaGenerator: def __init__( self, @@ -157,55 +146,56 @@ class LlamaGenerator: self.args = self.inner_generator.args self.formatter = self.inner_generator.formatter - def completion( - self, - request_batch: list[CompletionRequestWithRawContent], - ) -> Generator: - first_request = request_batch[0] - sampling_params = first_request.sampling_params or SamplingParams() - max_gen_len = sampling_params.max_tokens - if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: - max_gen_len = self.args.max_seq_len - 1 - - temperature, top_p = _infer_sampling_params(sampling_params) - yield from self.inner_generator.generate( - llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch], - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - logprobs=bool(first_request.logprobs), - echo=False, - logits_processor=get_logits_processor( - self.tokenizer, - self.args.vocab_size, - first_request.response_format, - ), - ) - def chat_completion( self, - request_batch: list[ChatCompletionRequestWithRawContent], - ) -> Generator: - first_request = request_batch[0] - sampling_params = first_request.sampling_params or SamplingParams() + request: OpenAIChatCompletionRequestWithExtraBody, + raw_messages: list, + ): + """Generate chat completion using OpenAI request format. 
+ + Args: + request: OpenAI chat completion request + raw_messages: Pre-converted list of RawMessage objects + """ + + # Determine tool prompt format + tool_prompt_format = ToolPromptFormat.json if request.tools else ToolPromptFormat.json + + # Prepare sampling params + sampling_params = SamplingParams() + if request.temperature is not None or request.top_p is not None: + sampling_params.strategy = TopPSamplingStrategy( + temperature=request.temperature if request.temperature is not None else 1.0, + top_p=request.top_p if request.top_p is not None else 1.0, + ) + if request.max_tokens: + sampling_params.max_tokens = request.max_tokens + max_gen_len = sampling_params.max_tokens if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: max_gen_len = self.args.max_seq_len - 1 temperature, top_p = _infer_sampling_params(sampling_params) + + # Get logits processor for response format + logits_processor = None + if request.response_format: + if isinstance(request.response_format, OpenAIResponseFormatJSONSchema): + # Extract the actual schema from OpenAIJSONSchema TypedDict + schema_dict = request.response_format.json_schema.get("schema") or {} + json_schema_format = JsonSchemaResponseFormat( + type=ResponseFormatType.json_schema, + json_schema=schema_dict, + ) + logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format) + + # Generate yield from self.inner_generator.generate( - llm_inputs=[ - self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)) - for request in request_batch - ], + llm_inputs=[self.formatter.encode_dialog_prompt(raw_messages, tool_prompt_format)], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, - logprobs=bool(first_request.logprobs), + logprobs=False, echo=False, - logits_processor=get_logits_processor( - self.tokenizer, - self.args.vocab_size, - first_request.response_format, - ), + logits_processor=logits_processor, ) diff --git a/src/llama_stack/providers/inline/inference/meta_reference/inference.py b/src/llama_stack/providers/inline/inference/meta_reference/inference.py index 76d3fdd50..ef21132a0 100644 --- a/src/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/src/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -5,12 +5,19 @@ # the root directory of this source tree. 
import asyncio +import time +import uuid from collections.abc import AsyncIterator from llama_stack.apis.inference import ( InferenceProvider, + OpenAIAssistantMessageParam, OpenAIChatCompletionRequestWithExtraBody, + OpenAIChatCompletionUsage, + OpenAIChoice, OpenAICompletionRequestWithExtraBody, + OpenAIUserMessageParam, + ToolChoice, ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, @@ -19,12 +26,20 @@ from llama_stack.apis.inference.inference import ( ) from llama_stack.apis.models import Model, ModelType from llama_stack.log import get_logger +from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat +from llama_stack.models.llama.llama3.prompt_templates import ( + JsonCustomToolGenerator, + SystemDefaultGenerator, +) from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat +from llama_stack.models.llama.llama4.prompt_templates.system_prompts import ( + PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4, +) from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer from llama_stack.models.llama.sku_list import resolve_model -from llama_stack.models.llama.sku_types import ModelFamily +from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, @@ -44,6 +59,170 @@ log = get_logger(__name__, category="inference") SEMAPHORE = asyncio.Semaphore(1) +def _convert_openai_tool_to_tool_definition(tool) -> ToolDefinition: + """Convert OpenAI tool format to ToolDefinition format.""" + # OpenAI tools have function.name and function.parameters + return ToolDefinition( + tool_name=tool.function.name, + description=tool.function.description or "", + parameters=tool.function.parameters or {}, + ) + + +def _get_tool_choice_prompt(tool_choice, tools) -> str: + """Generate prompt text for tool_choice behavior.""" + if not tool_choice or tool_choice == ToolChoice.auto or tool_choice == "auto": + return "" + elif tool_choice == ToolChoice.required or tool_choice == "required": + return "You MUST use one of the provided functions/tools to answer the user query." + elif tool_choice == ToolChoice.none or tool_choice == "none": + return "" + else: + # Specific tool specified + return f"You MUST use the tool `{tool_choice}` to answer the user query." 
+ + +def _raw_content_as_str(content) -> str: + """Convert RawContent to string for system messages.""" + if isinstance(content, str): + return content + elif isinstance(content, RawTextItem): + return content.text + elif isinstance(content, list): + return "\n".join(_raw_content_as_str(c) for c in content) + else: + return "" + + +def _augment_raw_messages_for_tools_llama_3_1( + raw_messages: list[RawMessage], + tools: list, + tool_choice, +) -> list[RawMessage]: + """Augment raw messages with tool definitions for Llama 3.1 style models.""" + messages = raw_messages.copy() + existing_system_message = None + if messages and messages[0].role == "system": + existing_system_message = messages.pop(0) + + sys_content = "" + + # Add tool definitions first (if present) + if tools: + # Convert OpenAI tools to ToolDefinitions + tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools] + + # For OpenAI format, all tools are custom (have string names) + tool_gen = JsonCustomToolGenerator() + tool_template = tool_gen.gen(tool_definitions) + sys_content += tool_template.render() + sys_content += "\n" + + # Add default system prompt + default_gen = SystemDefaultGenerator() + default_template = default_gen.gen() + sys_content += default_template.render() + + # Add existing system message if present + if existing_system_message: + sys_content += "\n" + _raw_content_as_str(existing_system_message.content) + + # Add tool choice prompt if needed + if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools): + sys_content += "\n" + tool_choice_prompt + + # Create new system message + new_system_message = RawMessage( + role="system", + content=[RawTextItem(text=sys_content.strip())], + ) + + return [new_system_message] + messages + + +def _augment_raw_messages_for_tools_llama_4( + raw_messages: list[RawMessage], + tools: list, + tool_choice, +) -> list[RawMessage]: + """Augment raw messages with tool definitions for Llama 4/3.2/3.3 style models.""" + messages = raw_messages.copy() + existing_system_message = None + if messages and messages[0].role == "system": + existing_system_message = messages.pop(0) + + sys_content = "" + + # Add tool definitions if present + if tools: + # Convert OpenAI tools to ToolDefinitions + tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools] + + # Use python_list format for Llama 4 + tool_gen = PythonListCustomToolGeneratorLlama4() + system_prompt = None + if existing_system_message: + system_prompt = _raw_content_as_str(existing_system_message.content) + + tool_template = tool_gen.gen(tool_definitions, system_prompt) + sys_content = tool_template.render() + elif existing_system_message: + # No tools, just use existing system message + sys_content = _raw_content_as_str(existing_system_message.content) + + # Add tool choice prompt if needed + if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools): + sys_content += "\n" + tool_choice_prompt + + if sys_content: + new_system_message = RawMessage( + role="system", + content=[RawTextItem(text=sys_content.strip())], + ) + return [new_system_message] + messages + + return messages + + +def augment_raw_messages_for_tools( + raw_messages: list[RawMessage], + params: OpenAIChatCompletionRequestWithExtraBody, + llama_model, +) -> list[RawMessage]: + """Augment raw messages with tool definitions based on model family.""" + if not params.tools: + return raw_messages + + # Determine augmentation strategy based on model family + if llama_model.model_family == 
ModelFamily.llama3_1 or ( + llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id) + ): + # Llama 3.1 and Llama 3.2 multimodal use JSON format + return _augment_raw_messages_for_tools_llama_3_1( + raw_messages, + params.tools, + params.tool_choice, + ) + elif llama_model.model_family in ( + ModelFamily.llama3_2, + ModelFamily.llama3_3, + ModelFamily.llama4, + ): + # Llama 3.2/3.3/4 use python_list format + return _augment_raw_messages_for_tools_llama_4( + raw_messages, + params.tools, + params.tool_choice, + ) + else: + # Default to Llama 3.1 style + return _augment_raw_messages_for_tools_llama_3_1( + raw_messages, + params.tools, + params.tool_choice, + ) + + def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator: return LlamaGenerator(config, model_id, llama_model) @@ -136,10 +315,13 @@ class MetaReferenceInferenceImpl( self.llama_model = llama_model log.info("Warming up...") + await self.openai_chat_completion( - model=model_id, - messages=[{"role": "user", "content": "Hi how are you?"}], - max_tokens=20, + params=OpenAIChatCompletionRequestWithExtraBody( + model=model_id, + messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")], + max_tokens=20, + ) ) log.info("Warmed up!") @@ -155,4 +337,207 @@ class MetaReferenceInferenceImpl( self, params: OpenAIChatCompletionRequestWithExtraBody, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider") + self.check_model(params) + + # Convert OpenAI messages to RawMessages + from llama_stack.models.llama.datatypes import StopReason + from llama_stack.providers.utils.inference.prompt_adapter import ( + convert_openai_message_to_raw_message, + decode_assistant_message, + ) + + raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages] + + # Augment messages with tool definitions if tools are present + raw_messages = augment_raw_messages_for_tools(raw_messages, params, self.llama_model) + + # Call generator's chat_completion method (works for both single-GPU and model-parallel) + if isinstance(self.generator, LlamaGenerator): + generator = self.generator.chat_completion(params, raw_messages) + else: + # Model parallel: submit task to process group + generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages])) + + # Check if streaming is requested + if params.stream: + return self._stream_chat_completion(generator, params) + + # Non-streaming: collect all generated text + generated_text = "" + for result_batch in generator: + for result in result_batch: + if not result.ignore_token and result.source == "output": + generated_text += result.text + + # Decode assistant message to extract tool calls and determine stop_reason + # Default to end_of_turn if generation completed normally + decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn) + + # Convert tool calls to OpenAI format + openai_tool_calls = None + if decoded_message.tool_calls: + from llama_stack.apis.inference import ( + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + ) + + openai_tool_calls = [ + OpenAIChatCompletionToolCall( + # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative. 
+ id=f"call_{uuid.uuid4().hex[:24]}", + type="function", + function=OpenAIChatCompletionToolCallFunction( + name=str(tc.tool_name), + arguments=tc.arguments, + ), + ) + for tc in decoded_message.tool_calls + ] + + # Determine finish_reason based on whether tool calls are present + finish_reason = "tool_calls" if openai_tool_calls else "stop" + + # Extract content from decoded message + content = "" + if isinstance(decoded_message.content, str): + content = decoded_message.content + elif isinstance(decoded_message.content, list): + for item in decoded_message.content: + if isinstance(item, RawTextItem): + content += item.text + + # Create OpenAI response + # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative. + response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" + created = int(time.time()) + + return OpenAIChatCompletion( + id=response_id, + object="chat.completion", + created=created, + model=params.model, + choices=[ + OpenAIChoice( + index=0, + message=OpenAIAssistantMessageParam( + role="assistant", + content=content, + tool_calls=openai_tool_calls, + ), + finish_reason=finish_reason, + logprobs=None, + ) + ], + usage=OpenAIChatCompletionUsage( + prompt_tokens=0, # TODO: calculate properly + completion_tokens=0, # TODO: calculate properly + total_tokens=0, # TODO: calculate properly + ), + ) + + async def _stream_chat_completion( + self, + generator, + params: OpenAIChatCompletionRequestWithExtraBody, + ) -> AsyncIterator[OpenAIChatCompletionChunk]: + """Stream chat completion chunks as they're generated.""" + from llama_stack.apis.inference import ( + OpenAIChatCompletionChunk, + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + OpenAIChoiceDelta, + OpenAIChunkChoice, + ) + from llama_stack.models.llama.datatypes import StopReason + from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message + + response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" + created = int(time.time()) + generated_text = "" + + # Yield chunks as tokens are generated + for result_batch in generator: + for result in result_batch: + if result.ignore_token or result.source != "output": + continue + + generated_text += result.text + + # Yield delta chunk with the new text + chunk = OpenAIChatCompletionChunk( + id=response_id, + object="chat.completion.chunk", + created=created, + model=params.model, + choices=[ + OpenAIChunkChoice( + index=0, + delta=OpenAIChoiceDelta( + role="assistant", + content=result.text, + ), + finish_reason="", + logprobs=None, + ) + ], + ) + yield chunk + + # After generation completes, decode the full message to extract tool calls + decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn) + + # If tool calls are present, yield a final chunk with tool_calls + if decoded_message.tool_calls: + openai_tool_calls = [ + OpenAIChatCompletionToolCall( + # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative. 
+ id=f"call_{uuid.uuid4().hex[:24]}", + type="function", + function=OpenAIChatCompletionToolCallFunction( + name=str(tc.tool_name), + arguments=tc.arguments, + ), + ) + for tc in decoded_message.tool_calls + ] + + # Yield chunk with tool_calls + chunk = OpenAIChatCompletionChunk( + id=response_id, + object="chat.completion.chunk", + created=created, + model=params.model, + choices=[ + OpenAIChunkChoice( + index=0, + delta=OpenAIChoiceDelta( + role="assistant", + tool_calls=openai_tool_calls, + ), + finish_reason="", + logprobs=None, + ) + ], + ) + yield chunk + + finish_reason = "tool_calls" + else: + finish_reason = "stop" + + # Yield final chunk with finish_reason + final_chunk = OpenAIChatCompletionChunk( + id=response_id, + object="chat.completion.chunk", + created=created, + model=params.model, + choices=[ + OpenAIChunkChoice( + index=0, + delta=OpenAIChoiceDelta(), + finish_reason=finish_reason, + logprobs=None, + ) + ], + ) + yield final_chunk diff --git a/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py index 9d0295d65..f50b41f34 100644 --- a/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py +++ b/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py @@ -4,17 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from collections.abc import Callable, Generator -from copy import deepcopy +from collections.abc import Callable from functools import partial from typing import Any from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat -from llama_stack.providers.utils.inference.prompt_adapter import ( - ChatCompletionRequestWithRawContent, - CompletionRequestWithRawContent, -) from .parallel_utils import ModelParallelProcessGroup @@ -23,12 +18,14 @@ class ModelRunner: def __init__(self, llama): self.llama = llama - # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()` def __call__(self, task: Any): - if task[0] == "chat_completion": - return self.llama.chat_completion(task[1]) + task_type = task[0] + if task_type == "chat_completion": + # task[1] is [params, raw_messages] + params, raw_messages = task[1] + return self.llama.chat_completion(params, raw_messages) else: - raise ValueError(f"Unexpected task type {task[0]}") + raise ValueError(f"Unexpected task type {task_type}") def init_model_cb( @@ -78,19 +75,3 @@ class LlamaModelParallelGenerator: def __exit__(self, exc_type, exc_value, exc_traceback): self.group.stop() - - def completion( - self, - request_batch: list[CompletionRequestWithRawContent], - ) -> Generator: - req_obj = deepcopy(request_batch) - gen = self.group.run_inference(("completion", req_obj)) - yield from gen - - def chat_completion( - self, - request_batch: list[ChatCompletionRequestWithRawContent], - ) -> Generator: - req_obj = deepcopy(request_batch) - gen = self.group.run_inference(("chat_completion", req_obj)) - yield from gen diff --git a/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index bb6a1bd03..663e4793b 100644 --- a/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ 
b/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -33,10 +33,6 @@ from torch.distributed.launcher.api import LaunchConfig, elastic_launch from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import GenerationResult -from llama_stack.providers.utils.inference.prompt_adapter import ( - ChatCompletionRequestWithRawContent, - CompletionRequestWithRawContent, -) log = get_logger(name=__name__, category="inference") @@ -69,10 +65,7 @@ class CancelSentinel(BaseModel): class TaskRequest(BaseModel): type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request - task: tuple[ - str, - list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent], - ] + task: tuple[str, list] class TaskResponse(BaseModel): @@ -328,10 +321,7 @@ class ModelParallelProcessGroup: def run_inference( self, - req: tuple[ - str, - list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent], - ], + req: tuple[str, list], ) -> Generator: assert not self.running, "inference already running" diff --git a/src/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/src/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index cb72aa13a..e6dcf3ae7 100644 --- a/src/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/src/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -22,9 +22,6 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, ) -from llama_stack.providers.utils.inference.openai_compat import ( - OpenAIChatCompletionToLlamaStackMixin, -) from .config import SentenceTransformersInferenceConfig @@ -32,7 +29,6 @@ log = get_logger(name=__name__, category="inference") class SentenceTransformersInferenceImpl( - OpenAIChatCompletionToLlamaStackMixin, SentenceTransformerEmbeddingMixin, InferenceProvider, ModelsProtocolPrivate, diff --git a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 223497fb8..a793c499e 100644 --- a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -11,9 +11,7 @@ from collections.abc import AsyncIterator import litellm from llama_stack.apis.inference import ( - ChatCompletionRequest, InferenceProvider, - JsonSchemaResponseFormat, OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAIChatCompletionRequestWithExtraBody, @@ -23,15 +21,11 @@ from llama_stack.apis.inference import ( OpenAIEmbeddingsRequestWithExtraBody, OpenAIEmbeddingsResponse, OpenAIEmbeddingUsage, - ToolChoice, ) from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry from llama_stack.providers.utils.inference.openai_compat import ( - convert_message_to_openai_dict_new, - convert_tooldef_to_openai_tool, - get_sampling_options, prepare_openai_completion_params, ) @@ -127,51 +121,6 @@ class LiteLLMOpenAIMixin( return schema - async def _get_params(self, request: ChatCompletionRequest) -> dict: - from typing import Any - - input_dict: dict[str, Any] = {} - - input_dict["messages"] = [ - await convert_message_to_openai_dict_new(m, 
download_images=self.download_images) for m in request.messages - ] - if fmt := request.response_format: - if not isinstance(fmt, JsonSchemaResponseFormat): - raise ValueError( - f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported." - ) - - # Convert to dict for manipulation - fmt_dict = dict(fmt.json_schema) - name = fmt_dict["title"] - del fmt_dict["title"] - fmt_dict["additionalProperties"] = False - - # Apply additionalProperties: False recursively to all objects - fmt_dict = self._add_additional_properties_recursive(fmt_dict) - - input_dict["response_format"] = { - "type": "json_schema", - "json_schema": { - "name": name, - "schema": fmt_dict, - "strict": self.json_schema_strict, - }, - } - if request.tools: - input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools] - if request.tool_config and (tool_choice := request.tool_config.tool_choice): - input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice - - return { - "model": request.model, - "api_key": self.get_api_key(), - "api_base": self.api_base, - **input_dict, - "stream": request.stream, - **get_sampling_options(request.sampling_params), - } - def get_api_key(self) -> str: provider_data = self.get_request_provider_data() key_field = self.provider_data_api_key_field diff --git a/src/llama_stack/providers/utils/inference/openai_compat.py b/src/llama_stack/providers/utils/inference/openai_compat.py index aabcb50f8..c2e6829e0 100644 --- a/src/llama_stack/providers/utils/inference/openai_compat.py +++ b/src/llama_stack/providers/utils/inference/openai_compat.py @@ -3,31 +3,14 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import json -import time -import uuid -import warnings -from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Iterable +from collections.abc import Iterable from typing import ( Any, ) -from openai import AsyncStream -from openai.types.chat import ( - ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, -) -from openai.types.chat import ( - ChatCompletionChunk as OpenAIChatCompletionChunk, -) -from openai.types.chat import ( - ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, -) from openai.types.chat import ( ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, ) -from openai.types.chat import ( - ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, -) try: from openai.types.chat import ( @@ -37,84 +20,24 @@ except ImportError: from openai.types.chat.chat_completion_message_tool_call import ( ChatCompletionMessageToolCall as OpenAIChatCompletionMessageFunctionToolCall, ) -from openai.types.chat import ( - ChatCompletionMessageParam as OpenAIChatCompletionMessage, -) from openai.types.chat import ( ChatCompletionMessageToolCall, ) -from openai.types.chat import ( - ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, -) -from openai.types.chat import ( - ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, -) -from openai.types.chat import ( - ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, -) -from openai.types.chat.chat_completion import ( - Choice as OpenAIChoice, -) -from openai.types.chat.chat_completion import ( - ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs -) -from openai.types.chat.chat_completion_chunk import ( - Choice as OpenAIChatCompletionChunkChoice, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDelta as OpenAIChoiceDelta, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction, -) -from openai.types.chat.chat_completion_content_part_image_param import ( - ImageURL as OpenAIImageURL, -) -from openai.types.chat.chat_completion_message_tool_call import ( - Function as OpenAIFunction, -) from pydantic import BaseModel from llama_stack.apis.common.content_types import ( URL, ImageContentItem, - InterleavedContent, TextContentItem, - TextDelta, - ToolCallDelta, - ToolCallParseStatus, _URLOrData, ) from llama_stack.apis.inference import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseEvent, - ChatCompletionResponseEventType, - ChatCompletionResponseStreamChunk, - CompletionMessage, - CompletionResponse, - CompletionResponseStreamChunk, GreedySamplingStrategy, JsonSchemaResponseFormat, - Message, - OpenAIChatCompletion, - OpenAIMessageParam, OpenAIResponseFormatParam, SamplingParams, - SystemMessage, - TokenLogProbs, - ToolChoice, - ToolConfig, - ToolResponseMessage, TopKSamplingStrategy, TopPSamplingStrategy, - UserMessage, -) -from llama_stack.apis.inference import ( - OpenAIChoice as OpenAIChatCompletionChoice, ) from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( @@ -123,10 +46,6 @@ from llama_stack.models.llama.datatypes import ( ToolCall, ToolDefinition, ) -from llama_stack.providers.utils.inference.prompt_adapter import ( - convert_image_content_to_url, - decode_assistant_message, -) logger = 
get_logger(name=__name__, category="providers::utils") @@ -213,345 +132,6 @@ def get_stop_reason(finish_reason: str) -> StopReason: return StopReason.out_of_tokens -def convert_openai_completion_logprobs( - logprobs: OpenAICompatLogprobs | None, -) -> list[TokenLogProbs] | None: - if not logprobs: - return None - if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs: - return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs] - - # Together supports logprobs with top_k=1 only. This means for each token position, - # they return only the logprobs for the selected token (vs. the top n most likely tokens). - # Here we construct the response by matching the selected token with the logprobs. - if logprobs.tokens and logprobs.token_logprobs: - return [ - TokenLogProbs(logprobs_by_token={token: token_lp}) - for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False) - ] - return None - - -def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenAICompatLogprobs | None): - if logprobs is None: - return None - if isinstance(logprobs, float): - # Adapt response from Together CompletionChoicesChunk - return [TokenLogProbs(logprobs_by_token={text: logprobs})] - if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs: - return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs] - return None - - -def process_completion_response( - response: OpenAICompatCompletionResponse, -) -> CompletionResponse: - choice = response.choices[0] - text = choice.text or "" - # drop suffix if present and return stop reason as end of turn - if text.endswith("<|eot_id|>"): - return CompletionResponse( - stop_reason=StopReason.end_of_turn, - content=text[: -len("<|eot_id|>")], - logprobs=convert_openai_completion_logprobs(choice.logprobs), - ) - # drop suffix if present and return stop reason as end of message - if text.endswith("<|eom_id|>"): - return CompletionResponse( - stop_reason=StopReason.end_of_message, - content=text[: -len("<|eom_id|>")], - logprobs=convert_openai_completion_logprobs(choice.logprobs), - ) - return CompletionResponse( - stop_reason=get_stop_reason(choice.finish_reason or "stop"), - content=text, - logprobs=convert_openai_completion_logprobs(choice.logprobs), - ) - - -def process_chat_completion_response( - response: OpenAICompatCompletionResponse, - request: ChatCompletionRequest, -) -> ChatCompletionResponse: - choice = response.choices[0] - if choice.finish_reason == "tool_calls": - if not hasattr(choice, "message") or not choice.message or not choice.message.tool_calls: # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed - raise ValueError("Tool calls are not present in the response") - - tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed - if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls): - # If we couldn't parse a tool call, jsonify the tool calls and return them - return ChatCompletionResponse( - completion_message=CompletionMessage( - stop_reason=StopReason.end_of_turn, - content=json.dumps(tool_calls, default=lambda x: x.model_dump()), - ), - logprobs=None, - ) - else: - # Otherwise, return tool calls as normal - # Filter to only valid ToolCall objects - valid_tool_calls = [tc for tc in tool_calls if isinstance(tc, ToolCall)] - return ChatCompletionResponse( - completion_message=CompletionMessage( - 
tool_calls=valid_tool_calls, - stop_reason=StopReason.end_of_turn, - # Content is not optional - content="", - ), - logprobs=None, - ) - - # TODO: This does not work well with tool calls for vLLM remote provider - # Ref: https://github.com/meta-llama/llama-stack/issues/1058 - raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason or "stop")) - - # NOTE: If we do not set tools in chat-completion request, we should not - # expect the ToolCall in the response. Instead, we should return the raw - # response from the model. - if raw_message.tool_calls: - if not request.tools: - raw_message.tool_calls = [] - raw_message.content = text_from_choice(choice) - else: - # only return tool_calls if provided in the request - new_tool_calls = [] - request_tools = {t.tool_name: t for t in request.tools} - for t in raw_message.tool_calls: - if t.tool_name in request_tools: - new_tool_calls.append(t) - else: - logger.warning(f"Tool {t.tool_name} not found in request tools") - - if len(new_tool_calls) < len(raw_message.tool_calls): - raw_message.tool_calls = new_tool_calls - raw_message.content = text_from_choice(choice) - - return ChatCompletionResponse( - completion_message=CompletionMessage( - content=raw_message.content, # type: ignore[arg-type] # decode_assistant_message returns Union[str, InterleavedContent] - stop_reason=raw_message.stop_reason or StopReason.end_of_turn, - tool_calls=raw_message.tool_calls, - ), - logprobs=None, - ) - - -async def process_completion_stream_response( - stream: AsyncGenerator[OpenAICompatCompletionResponse, None], -) -> AsyncGenerator[CompletionResponseStreamChunk, None]: - stop_reason = None - - async for chunk in stream: - choice = chunk.choices[0] - finish_reason = choice.finish_reason - - text = text_from_choice(choice) - if text == "<|eot_id|>": - stop_reason = StopReason.end_of_turn - text = "" - continue - elif text == "<|eom_id|>": - stop_reason = StopReason.end_of_message - text = "" - continue - yield CompletionResponseStreamChunk( - delta=text, - stop_reason=stop_reason, - logprobs=convert_openai_completion_logprobs_stream(text, choice.logprobs), - ) - if finish_reason: - if finish_reason in ["stop", "eos", "eos_token"]: - stop_reason = StopReason.end_of_turn - elif finish_reason == "length": - stop_reason = StopReason.out_of_tokens - break - - yield CompletionResponseStreamChunk( - delta="", - stop_reason=stop_reason, - ) - - -async def process_chat_completion_stream_response( - stream: AsyncGenerator[OpenAICompatCompletionResponse, None], - request: ChatCompletionRequest, -) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta=TextDelta(text=""), - ) - ) - - buffer = "" - ipython = False - stop_reason = None - - async for chunk in stream: - choice = chunk.choices[0] - finish_reason = choice.finish_reason - - if finish_reason: - if stop_reason is None and finish_reason in ["stop", "eos", "eos_token"]: - stop_reason = StopReason.end_of_turn - elif stop_reason is None and finish_reason == "length": - stop_reason = StopReason.out_of_tokens - break - - text = text_from_choice(choice) - if not text: - # Sometimes you get empty chunks from providers - continue - - # check if its a tool call ( aka starts with <|python_tag|> ) - if not ipython and text.startswith("<|python_tag|>"): - ipython = True - yield ChatCompletionResponseStreamChunk( - 
event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call="", - parse_status=ToolCallParseStatus.started, - ), - ) - ) - buffer += text - continue - - if text == "<|eot_id|>": - stop_reason = StopReason.end_of_turn - text = "" - continue - elif text == "<|eom_id|>": - stop_reason = StopReason.end_of_message - text = "" - continue - - if ipython: - buffer += text - delta = ToolCallDelta( - tool_call=text, - parse_status=ToolCallParseStatus.in_progress, - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=delta, - stop_reason=stop_reason, - ) - ) - else: - buffer += text - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=TextDelta(text=text), - stop_reason=stop_reason, - ) - ) - - # parse tool calls and report errors - message = decode_assistant_message(buffer, stop_reason or StopReason.end_of_turn) - - parsed_tool_calls = len(message.tool_calls) > 0 - if ipython and not parsed_tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call="", - parse_status=ToolCallParseStatus.failed, - ), - stop_reason=stop_reason, - ) - ) - - request_tools = {t.tool_name: t for t in (request.tools or [])} - for tool_call in message.tool_calls: - if tool_call.tool_name in request_tools: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call=tool_call, - parse_status=ToolCallParseStatus.succeeded, - ), - stop_reason=stop_reason, - ) - ) - else: - logger.warning(f"Tool {tool_call.tool_name} not found in request tools") - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - # Parsing tool call failed due to tool call not being found in request tools, - # We still add the raw message text inside tool_call for responding back to the user - tool_call=buffer, - parse_status=ToolCallParseStatus.failed, - ), - stop_reason=stop_reason, - ) - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta=TextDelta(text=""), - stop_reason=stop_reason, - ) - ) - - -async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict: - async def _convert_content(content) -> dict: - if isinstance(content, ImageContentItem): - return { - "type": "image_url", - "image_url": { - "url": await convert_image_content_to_url(content, download=download), - }, - } - else: - text = content.text if isinstance(content, TextContentItem) else content - assert isinstance(text, str) - return {"type": "text", "text": text} - - if isinstance(message.content, list): - content = [await _convert_content(c) for c in message.content] - else: - content = [await _convert_content(message.content)] - - result = { - "role": message.role, - "content": content, - } - - if hasattr(message, "tool_calls") and message.tool_calls: - tool_calls_list = [] - for tc in message.tool_calls: - # The tool.tool_name can be a str or a BuiltinTool enum. If - # it's the latter, convert to a string. 
- tool_name = tc.tool_name - if isinstance(tool_name, BuiltinTool): - tool_name = tool_name.value - - tool_calls_list.append( - { - "id": tc.call_id, - "type": "function", - "function": { - "name": tool_name, - "arguments": tc.arguments, - }, - } - ) - result["tool_calls"] = tool_calls_list # type: ignore[assignment] # dict allows Any value, stricter type expected - return result - - class UnparseableToolCall(BaseModel): """ A ToolCall with arguments that are not valid JSON. @@ -563,112 +143,6 @@ class UnparseableToolCall(BaseModel): arguments: str = "" -async def convert_message_to_openai_dict_new( - message: Message | dict, - download_images: bool = False, -) -> OpenAIChatCompletionMessage: - """ - Convert a Message to an OpenAI API-compatible dictionary. - """ - # users can supply a dict instead of a Message object, we'll - # convert it to a Message object and proceed with some type safety. - if isinstance(message, dict): - if "role" not in message: - raise ValueError("role is required in message") - if message["role"] == "user": - message = UserMessage(**message) - elif message["role"] == "assistant": - message = CompletionMessage(**message) - elif message["role"] == "tool": - message = ToolResponseMessage(**message) - elif message["role"] == "system": - message = SystemMessage(**message) - else: - raise ValueError(f"Unsupported message role: {message['role']}") - - # Map Llama Stack spec to OpenAI spec - - # str -> str - # {"type": "text", "text": ...} -> {"type": "text", "text": ...} - # {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}} - # {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}} - # List[...] -> List[...] - async def _convert_message_content( - content: InterleavedContent, - ) -> str | Iterable[OpenAIChatCompletionContentPartParam]: - async def impl( - content_: InterleavedContent, - ) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]: - # Llama Stack and OpenAI spec match for str and text input - if isinstance(content_, str): - return content_ - elif isinstance(content_, TextContentItem): - return OpenAIChatCompletionContentPartTextParam( - type="text", - text=content_.text, - ) - elif isinstance(content_, ImageContentItem): - return OpenAIChatCompletionContentPartImageParam( - type="image_url", - image_url=OpenAIImageURL( - url=await convert_image_content_to_url(content_, download=download_images) - ), - ) - elif isinstance(content_, list): - return [await impl(item) for item in content_] # type: ignore[misc] # recursive list comprehension confuses mypy's type narrowing - else: - raise ValueError(f"Unsupported content type: {type(content_)}") - - ret = await impl(content) - - # OpenAI*Message expects a str or list - if isinstance(ret, str) or isinstance(ret, list): - return ret - else: - return [ret] - - out: OpenAIChatCompletionMessage - if isinstance(message, UserMessage): - out = OpenAIChatCompletionUserMessage( - role="user", - content=await _convert_message_content(message.content), - ) - elif isinstance(message, CompletionMessage): - tool_calls = [ - OpenAIChatCompletionMessageFunctionToolCall( - id=tool.call_id, - function=OpenAIFunction( - name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), - arguments=tool.arguments, # Already a JSON string, don't double-encode - ), - type="function", - ) - for tool in (message.tool_calls or []) - ] - params = {} - if tool_calls: 
- params["tool_calls"] = tool_calls - out = OpenAIChatCompletionAssistantMessage( - role="assistant", - content=await _convert_message_content(message.content), - **params, # type: ignore[typeddict-item] # tool_calls dict expansion conflicts with TypedDict optional field - ) - elif isinstance(message, ToolResponseMessage): - out = OpenAIChatCompletionToolMessage( - role="tool", - tool_call_id=message.call_id, - content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement - ) - elif isinstance(message, SystemMessage): - out = OpenAIChatCompletionSystemMessage( - role="system", - content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement - ) - else: - raise ValueError(f"Unsupported message type: {type(message)}") - - return out - - def convert_tool_call( tool_call: ChatCompletionMessageToolCall, ) -> ToolCall | UnparseableToolCall: @@ -817,17 +291,6 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason: }.get(finish_reason, StopReason.end_of_turn) -def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None = None) -> ToolConfig: - tool_config = ToolConfig() - if tool_choice: - try: - tool_choice = ToolChoice(tool_choice) # type: ignore[assignment] # reassigning to enum narrows union but mypy can't track after exception - except ValueError: - pass - tool_config.tool_choice = tool_choice # type: ignore[assignment] # ToolConfig.tool_choice accepts Union[ToolChoice, dict] but mypy tracks narrower type - return tool_config - - def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]: lls_tools: list[ToolDefinition] = [] if not tools: @@ -898,40 +361,6 @@ def _convert_openai_tool_calls( ] -def _convert_openai_logprobs( - logprobs: OpenAIChoiceLogprobs, -) -> list[TokenLogProbs] | None: - """ - Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs. - - OpenAI ChoiceLogprobs: - content: Optional[List[ChatCompletionTokenLogprob]] - - OpenAI ChatCompletionTokenLogprob: - token: str - logprob: float - top_logprobs: List[TopLogprob] - - OpenAI TopLogprob: - token: str - logprob: float - - -> - - TokenLogProbs: - logprobs_by_token: Dict[str, float] - - token, logprob - - """ - if not logprobs or not logprobs.content: - return None - - return [ - TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs}) - for content in logprobs.content - ] - - def _convert_openai_sampling_params( max_tokens: int | None = None, temperature: float | None = None, @@ -956,37 +385,6 @@ def _convert_openai_sampling_params( return sampling_params -def openai_messages_to_messages( - messages: list[OpenAIMessageParam], -) -> list[Message]: - """ - Convert a list of OpenAIChatCompletionMessage into a list of Message. 
- """ - converted_messages: list[Message] = [] - for message in messages: - converted_message: Message - if message.role == "system": - converted_message = SystemMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types - elif message.role == "user": - converted_message = UserMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types - elif message.role == "assistant": - converted_message = CompletionMessage( - content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types - tool_calls=_convert_openai_tool_calls(message.tool_calls) if message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls type incompatible with conversion function - stop_reason=StopReason.end_of_turn, - ) - elif message.role == "tool": - converted_message = ToolResponseMessage( - role="tool", - call_id=message.tool_call_id, - content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types - ) - else: - raise ValueError(f"Unknown role {message.role}") - converted_messages.append(converted_message) - return converted_messages - - def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionContentPartParam] | None): if content is None: return "" @@ -1005,216 +403,6 @@ def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionConten raise ValueError(f"Unknown content type: {content}") -def convert_openai_chat_completion_choice( - choice: OpenAIChoice, -) -> ChatCompletionResponse: - """ - Convert an OpenAI Choice into a ChatCompletionResponse. 
- - OpenAI Choice: - message: ChatCompletionMessage - finish_reason: str - logprobs: Optional[ChoiceLogprobs] - - OpenAI ChatCompletionMessage: - role: Literal["assistant"] - content: Optional[str] - tool_calls: Optional[List[ChatCompletionMessageToolCall]] - - -> - - ChatCompletionResponse: - completion_message: CompletionMessage - logprobs: Optional[List[TokenLogProbs]] - - CompletionMessage: - role: Literal["assistant"] - content: str | ImageMedia | List[str | ImageMedia] - stop_reason: StopReason - tool_calls: List[ToolCall] - - class StopReason(Enum): - end_of_turn = "end_of_turn" - end_of_message = "end_of_message" - out_of_tokens = "out_of_tokens" - """ - assert hasattr(choice, "message") and choice.message, "error in server response: message not found" - assert hasattr(choice, "finish_reason") and choice.finish_reason, ( - "error in server response: finish_reason not found" - ) - - return ChatCompletionResponse( - completion_message=CompletionMessage( - content=choice.message.content or "", # CompletionMessage content is not optional - stop_reason=_convert_openai_finish_reason(choice.finish_reason), - tool_calls=_convert_openai_tool_calls(choice.message.tool_calls) if choice.message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls Optional type broadens union - ), - logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)), # type: ignore[arg-type] # getattr returns Any, can't narrow without inspection - ) - - -async def convert_openai_chat_completion_stream( - stream: AsyncStream[OpenAIChatCompletionChunk], - enable_incremental_tool_calls: bool, -) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: - """ - Convert a stream of OpenAI chat completion chunks into a stream - of ChatCompletionResponseStreamChunk. - """ - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta=TextDelta(text=""), - ) - ) - event_type = ChatCompletionResponseEventType.progress - - stop_reason = None - tool_call_idx_to_buffer = {} - - async for chunk in stream: - choice = chunk.choices[0] # assuming only one choice per chunk - - # we assume there's only one finish_reason in the stream - stop_reason = _convert_openai_finish_reason(choice.finish_reason) if choice.finish_reason else stop_reason - logprobs = getattr(choice, "logprobs", None) - - # if there's a tool call, emit an event for each tool in the list - # if tool call and content, emit both separately - if choice.delta.tool_calls: - # the call may have content and a tool call. 
ChatCompletionResponseEvent - # does not support both, so we emit the content first - if choice.delta.content: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=TextDelta(text=choice.delta.content), - logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result - ) - ) - - # it is possible to have parallel tool calls in stream, but - # ChatCompletionResponseEvent only supports one per stream - if len(choice.delta.tool_calls) > 1: - warnings.warn( - "multiple tool calls found in a single delta, using the first, ignoring the rest", - stacklevel=2, - ) - - if not enable_incremental_tool_calls: - for tool_call in choice.delta.tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=ToolCallDelta( - tool_call=_convert_openai_tool_calls([tool_call])[0], # type: ignore[arg-type, list-item] # delta tool_call type differs from complete tool_call - parse_status=ToolCallParseStatus.succeeded, - ), - logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result - ) - ) - else: - for tool_call in choice.delta.tool_calls: - idx = tool_call.index if hasattr(tool_call, "index") else 0 - - if idx not in tool_call_idx_to_buffer: - tool_call_idx_to_buffer[idx] = { - "call_id": tool_call.id, - "name": None, - "arguments": "", - "content": "", - } - - buffer = tool_call_idx_to_buffer[idx] - - if tool_call.function: - if tool_call.function.name: - buffer["name"] = tool_call.function.name - delta = f"{buffer['name']}(" - if buffer["content"] is not None: - buffer["content"] += delta - - if tool_call.function.arguments: - delta = tool_call.function.arguments - if buffer["arguments"] is not None and delta: - buffer["arguments"] += delta - if buffer["content"] is not None and delta: - buffer["content"] += delta - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=ToolCallDelta( - tool_call=delta, - parse_status=ToolCallParseStatus.in_progress, - ), - logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result - ) - ) - elif choice.delta.content: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=TextDelta(text=choice.delta.content or ""), - logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result - ) - ) - - for idx, buffer in tool_call_idx_to_buffer.items(): - logger.debug(f"toolcall_buffer[{idx}]: {buffer}") - if buffer["name"]: - delta = ")" - if buffer["content"] is not None: - buffer["content"] += delta - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=ToolCallDelta( - tool_call=delta, - parse_status=ToolCallParseStatus.in_progress, - ), - logprobs=None, - ) - ) - - try: - parsed_tool_call = ToolCall( - call_id=buffer["call_id"] or "", - tool_name=buffer["name"] or "", - arguments=buffer["arguments"] or "", - ) - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call=parsed_tool_call, # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall] - parse_status=ToolCallParseStatus.succeeded, - ), - stop_reason=stop_reason, - ) - ) - except 
json.JSONDecodeError as e: - print(f"Failed to parse arguments: {e}") - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call=buffer["content"], # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall] - parse_status=ToolCallParseStatus.failed, - ), - stop_reason=stop_reason, - ) - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta=TextDelta(text=""), - stop_reason=stop_reason, - ) - ) - - async def prepare_openai_completion_params(**params): async def _prepare_value(value: Any) -> Any: new_value = value @@ -1233,163 +421,6 @@ async def prepare_openai_completion_params(**params): return completion_params -class OpenAIChatCompletionToLlamaStackMixin: - async def openai_chat_completion( - self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, - ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - messages = openai_messages_to_messages(messages) # type: ignore[assignment] # converted from OpenAI to LlamaStack message format - response_format = _convert_openai_request_response_format(response_format) - sampling_params = _convert_openai_sampling_params( - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - ) - tool_config = _convert_openai_request_tool_config(tool_choice) - - tools = _convert_openai_request_tools(tools) # type: ignore[assignment] # converted from OpenAI to LlamaStack tool format - if tool_config.tool_choice == ToolChoice.none: - tools = [] # type: ignore[assignment] # empty list narrows return type but mypy tracks broader type - - outstanding_responses = [] - # "n" is the number of completions to generate per prompt - n = n or 1 - for _i in range(0, n): - response = self.chat_completion( # type: ignore[attr-defined] # mixin expects class to implement chat_completion - model_id=model, - messages=messages, - sampling_params=sampling_params, - response_format=response_format, - stream=stream, - tool_config=tool_config, - tools=tools, - ) - outstanding_responses.append(response) - - if stream: - return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses) # type: ignore[no-any-return] # mixin async generator return type too complex for mypy - - return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response( - self, model, outstanding_responses - ) - - async def _process_stream_response( - self, - model: str, - outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]], - ): - id = f"chatcmpl-{uuid.uuid4()}" 
- for i, outstanding_response in enumerate(outstanding_responses): - response = await outstanding_response - async for chunk in response: - event = chunk.event - finish_reason = ( - _convert_stop_reason_to_openai_finish_reason(event.stop_reason) if event.stop_reason else None - ) - - if isinstance(event.delta, TextDelta): - text_delta = event.delta.text - delta = OpenAIChoiceDelta(content=text_delta) - yield OpenAIChatCompletionChunk( - id=id, - choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)], # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union - created=int(time.time()), - model=model, - object="chat.completion.chunk", - ) - elif isinstance(event.delta, ToolCallDelta): - if event.delta.parse_status == ToolCallParseStatus.succeeded: - tool_call = event.delta.tool_call - if isinstance(tool_call, str): - continue - - # First chunk includes full structure - openai_tool_call = OpenAIChoiceDeltaToolCall( - index=0, - id=tool_call.call_id, - function=OpenAIChoiceDeltaToolCallFunction( - name=tool_call.tool_name - if isinstance(tool_call.tool_name, str) - else tool_call.tool_name.value, # type: ignore[arg-type] # enum .value extraction on Union confuses mypy - arguments="", - ), - ) - delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call]) - yield OpenAIChatCompletionChunk( - id=id, - choices=[ - OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union - ], - created=int(time.time()), - model=model, - object="chat.completion.chunk", - ) - # arguments - openai_tool_call = OpenAIChoiceDeltaToolCall( - index=0, - function=OpenAIChoiceDeltaToolCallFunction( - arguments=tool_call.arguments, - ), - ) - delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call]) - yield OpenAIChatCompletionChunk( - id=id, - choices=[ - OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union - ], - created=int(time.time()), - model=model, - object="chat.completion.chunk", - ) - - async def _process_non_stream_response( - self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]] - ) -> OpenAIChatCompletion: - choices: list[OpenAIChatCompletionChoice] = [] - for outstanding_response in outstanding_responses: - response = await outstanding_response - completion_message = response.completion_message - message = await convert_message_to_openai_dict_new(completion_message) - finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason) - - choice = OpenAIChatCompletionChoice( - index=len(choices), - message=message, # type: ignore[arg-type] # OpenAIChatCompletionMessage union incompatible with narrower Message type - finish_reason=finish_reason, - ) - choices.append(choice) # type: ignore[arg-type] # OpenAIChatCompletionChoice type annotation mismatch - - return OpenAIChatCompletion( - id=f"chatcmpl-{uuid.uuid4()}", - choices=choices, # type: ignore[arg-type] # list[OpenAIChatCompletionChoice] union incompatible - created=int(time.time()), - model=model, - object="chat.completion", - ) - - def prepare_openai_embeddings_params( model: str, input: str | list[str], diff --git a/src/llama_stack/providers/utils/inference/prompt_adapter.py b/src/llama_stack/providers/utils/inference/prompt_adapter.py index d06b7454d..35a7b3484 100644 --- 
a/src/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/src/llama_stack/providers/utils/inference/prompt_adapter.py @@ -21,19 +21,18 @@ from llama_stack.apis.common.content_types import ( TextContentItem, ) from llama_stack.apis.inference import ( - ChatCompletionRequest, CompletionRequest, - Message, + OpenAIAssistantMessageParam, OpenAIChatCompletionContentPartImageParam, OpenAIChatCompletionContentPartTextParam, OpenAIFile, + OpenAIMessageParam, + OpenAISystemMessageParam, + OpenAIToolMessageParam, + OpenAIUserMessageParam, ResponseFormat, ResponseFormatType, - SystemMessage, - SystemMessageBehavior, ToolChoice, - ToolDefinition, - UserMessage, ) from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( @@ -42,33 +41,19 @@ from llama_stack.models.llama.datatypes import ( RawMediaItem, RawMessage, RawTextItem, - Role, StopReason, + ToolCall, + ToolDefinition, ToolPromptFormat, ) from llama_stack.models.llama.llama3.chat_format import ChatFormat -from llama_stack.models.llama.llama3.prompt_templates import ( - BuiltinToolGenerator, - FunctionTagCustomToolGenerator, - JsonCustomToolGenerator, - PythonListCustomToolGenerator, - SystemDefaultGenerator, -) from llama_stack.models.llama.llama3.tokenizer import Tokenizer -from llama_stack.models.llama.llama4.prompt_templates.system_prompts import ( - PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4, -) from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal -from llama_stack.providers.utils.inference import supported_inference_models log = get_logger(name=__name__, category="providers::utils") -class ChatCompletionRequestWithRawContent(ChatCompletionRequest): - messages: list[RawMessage] - - class CompletionRequestWithRawContent(CompletionRequest): content: RawContent @@ -103,28 +88,6 @@ def interleaved_content_as_str( return _process(content) -async def convert_request_to_raw( - request: ChatCompletionRequest | CompletionRequest, -) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent: - if isinstance(request, ChatCompletionRequest): - messages = [] - for m in request.messages: - content = await interleaved_content_convert_to_raw(m.content) - d = m.model_dump() - d["content"] = content - messages.append(RawMessage(**d)) - - d = request.model_dump() - d["messages"] = messages - request = ChatCompletionRequestWithRawContent(**d) - else: - d = request.model_dump() - d["content"] = await interleaved_content_convert_to_raw(request.content) - request = CompletionRequestWithRawContent(**d) - - return request - - async def interleaved_content_convert_to_raw( content: InterleavedContent, ) -> RawContent: @@ -171,6 +134,36 @@ async def interleaved_content_convert_to_raw( return await _localize_single(content) +async def convert_openai_message_to_raw_message(message: OpenAIMessageParam) -> RawMessage: + """Convert OpenAI message format to RawMessage format used by Llama formatters.""" + if isinstance(message, OpenAIUserMessageParam): + content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type] + return RawMessage(role="user", content=content) + elif isinstance(message, OpenAISystemMessageParam): + content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type] + return RawMessage(role="system", content=content) + elif isinstance(message, OpenAIAssistantMessageParam): + content = await interleaved_content_convert_to_raw(message.content 
or "") # type: ignore[arg-type] + tool_calls = [] + if message.tool_calls: + for tc in message.tool_calls: + if tc.function: + tool_calls.append( + ToolCall( + call_id=tc.id or "", + tool_name=tc.function.name or "", + arguments=tc.function.arguments or "{}", + ) + ) + return RawMessage(role="assistant", content=content, tool_calls=tool_calls) + elif isinstance(message, OpenAIToolMessageParam): + content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type] + return RawMessage(role="tool", content=content) + else: + # Handle OpenAIDeveloperMessageParam if needed + raise ValueError(f"Unsupported message type: {type(message)}") + + def content_has_media(content: InterleavedContent): def _has_media_content(c): return isinstance(c, ImageContentItem) @@ -181,17 +174,6 @@ def content_has_media(content: InterleavedContent): return _has_media_content(content) -def messages_have_media(messages: list[Message]): - return any(content_has_media(m.content) for m in messages) - - -def request_has_media(request: ChatCompletionRequest | CompletionRequest): - if isinstance(request, ChatCompletionRequest): - return messages_have_media(request.messages) - else: - return content_has_media(request.content) - - async def localize_image_content(uri: str) -> tuple[bytes, str] | None: if uri.startswith("http"): async with httpx.AsyncClient() as client: @@ -253,79 +235,6 @@ def augment_content_with_response_format_prompt(response_format, content): return content -async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str: - messages = chat_completion_request_to_messages(request, llama_model) - request.messages = messages - request = await convert_request_to_raw(request) - - formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) - model_input = formatter.encode_dialog_prompt( - request.messages, - tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model), - ) - return formatter.tokenizer.decode(model_input.tokens) - - -async def chat_completion_request_to_model_input_info( - request: ChatCompletionRequest, llama_model: str -) -> tuple[str, int]: - messages = chat_completion_request_to_messages(request, llama_model) - request.messages = messages - request = await convert_request_to_raw(request) - - formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) - model_input = formatter.encode_dialog_prompt( - request.messages, - tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model), - ) - return ( - formatter.tokenizer.decode(model_input.tokens), - len(model_input.tokens), - ) - - -def chat_completion_request_to_messages( - request: ChatCompletionRequest, - llama_model: str, -) -> list[Message]: - """Reads chat completion request and augments the messages to handle tools. - For eg. for llama_3_1, add system message with the appropriate tools or - add user messsage for custom tools, etc. - """ - assert llama_model is not None, "llama_model is required" - model = resolve_model(llama_model) - if model is None: - log.error(f"Could not resolve model {llama_model}") - return request.messages - - allowed_models = supported_inference_models() - descriptors = [m.descriptor() for m in allowed_models] - if model.descriptor() not in descriptors: - log.error(f"Unsupported inference model? 
{model.descriptor()}") - return request.messages - - if model.model_family == ModelFamily.llama3_1 or ( - model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id) - ): - # llama3.1 and llama3.2 multimodal models follow the same tool prompt format - messages = augment_messages_for_tools_llama_3_1(request) - elif model.model_family in ( - ModelFamily.llama3_2, - ModelFamily.llama3_3, - ): - # llama3.2, llama3.3 follow the same tool prompt format - messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator) - elif model.model_family == ModelFamily.llama4: - messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4) - else: - messages = request.messages - - if fmt_prompt := response_format_prompt(request.response_format): - messages.append(UserMessage(content=fmt_prompt)) - - return messages - - def response_format_prompt(fmt: ResponseFormat | None): if not fmt: return None @@ -338,128 +247,6 @@ def response_format_prompt(fmt: ResponseFormat | None): raise ValueError(f"Unknown response format {fmt.type}") -def augment_messages_for_tools_llama_3_1( - request: ChatCompletionRequest, -) -> list[Message]: - existing_messages = request.messages - existing_system_message = None - if existing_messages[0].role == Role.system.value: - existing_system_message = existing_messages.pop(0) - - assert existing_messages[0].role != Role.system.value, "Should only have 1 system message" - - messages = [] - - default_gen = SystemDefaultGenerator() - default_template = default_gen.gen() - - sys_content = "" - - tool_template = None - if request.tools: - tool_gen = BuiltinToolGenerator() - tool_template = tool_gen.gen(request.tools) - - sys_content += tool_template.render() - sys_content += "\n" - - sys_content += default_template.render() - - if existing_system_message: - # TODO: this fn is needed in many places - def _process(c): - if isinstance(c, str): - return c - else: - return "" - - sys_content += "\n" - - if isinstance(existing_system_message.content, str): - sys_content += _process(existing_system_message.content) - elif isinstance(existing_system_message.content, list): - sys_content += "\n".join([_process(c) for c in existing_system_message.content]) - - tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools) - if tool_choice_prompt: - sys_content += "\n" + tool_choice_prompt - - messages.append(SystemMessage(content=sys_content)) - - has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools) - if has_custom_tools: - fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json - if fmt == ToolPromptFormat.json: - tool_gen = JsonCustomToolGenerator() - elif fmt == ToolPromptFormat.function_tag: - tool_gen = FunctionTagCustomToolGenerator() - else: - raise ValueError(f"Non supported ToolPromptFormat {fmt}") - - custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)] - custom_template = tool_gen.gen(custom_tools) - messages.append(UserMessage(content=custom_template.render())) - - # Add back existing messages from the request - messages += existing_messages - - return messages - - -def augment_messages_for_tools_llama( - request: ChatCompletionRequest, - custom_tool_prompt_generator, -) -> list[Message]: - existing_messages = request.messages - existing_system_message = None - if existing_messages[0].role == Role.system.value: - existing_system_message = existing_messages.pop(0) - - assert 
existing_messages[0].role != Role.system.value, "Should only have 1 system message" - - sys_content = "" - custom_tools, builtin_tools = [], [] - for t in request.tools: - if isinstance(t.tool_name, str): - custom_tools.append(t) - else: - builtin_tools.append(t) - - if builtin_tools: - tool_gen = BuiltinToolGenerator() - tool_template = tool_gen.gen(builtin_tools) - - sys_content += tool_template.render() - sys_content += "\n" - - custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)] - if custom_tools: - fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list - if fmt != ToolPromptFormat.python_list: - raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}") - - system_prompt = None - if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace: - system_prompt = existing_system_message.content - - tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt) - - sys_content += tool_template.render() - sys_content += "\n" - - if existing_system_message and ( - request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools - ): - sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n") - - tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools) - if tool_choice_prompt: - sys_content += "\n" + tool_choice_prompt - - messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages] - return messages - - def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str: if tool_choice == ToolChoice.auto: return "" diff --git a/tests/unit/models/test_prompt_adapter.py b/tests/unit/models/test_prompt_adapter.py deleted file mode 100644 index d31426135..000000000 --- a/tests/unit/models/test_prompt_adapter.py +++ /dev/null @@ -1,303 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - - -from llama_stack.apis.inference import ( - ChatCompletionRequest, - CompletionMessage, - StopReason, - SystemMessage, - SystemMessageBehavior, - ToolCall, - ToolConfig, - UserMessage, -) -from llama_stack.models.llama.datatypes import ( - BuiltinTool, - ToolDefinition, - ToolPromptFormat, -) -from llama_stack.providers.utils.inference.prompt_adapter import ( - chat_completion_request_to_messages, - chat_completion_request_to_prompt, - interleaved_content_as_str, -) - -MODEL = "Llama3.1-8B-Instruct" -MODEL3_2 = "Llama3.2-3B-Instruct" - - -async def test_system_default(): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 2 - assert messages[-1].content == content - assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content) - - -async def test_system_builtin_only(): - content = "Hello !" 
- request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 2 - assert messages[-1].content == content - assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content) - assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content) - - -async def test_system_custom_only(): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ) - ], - tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json), - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 3 - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - - assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content) - assert messages[-1].content == content - - -async def test_system_custom_and_builtin(): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 3 - - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content) - - assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content) - assert messages[-1].content == content - - -async def test_completion_message_encoding(): - request = ChatCompletionRequest( - model=MODEL3_2, - messages=[ - UserMessage(content="hello"), - CompletionMessage( - content="", - stop_reason=StopReason.end_of_turn, - tool_calls=[ - ToolCall( - tool_name="custom1", - arguments='{"param1": "value1"}', # arguments must be a JSON string - call_id="123", - ) - ], - ), - ], - tools=[ - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ), - ], - tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list), - ) - prompt = await chat_completion_request_to_prompt(request, request.model) - assert '[custom1(param1="value1")]' in prompt - - request.model = MODEL - request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json) - prompt = await chat_completion_request_to_prompt(request, request.model) - assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt - - -async def test_user_provided_system_message(): - content = "Hello !" 
- system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 2 - assert interleaved_content_as_str(messages[0].content).endswith(system_prompt) - - assert messages[-1].content == content - - -async def test_replace_system_message_behavior_builtin_tools(): - content = "Hello !" - system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format=ToolPromptFormat.python_list, - system_message_behavior=SystemMessageBehavior.replace, - ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - assert len(messages) == 2 - assert interleaved_content_as_str(messages[0].content).endswith(system_prompt) - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - assert messages[-1].content == content - - -async def test_replace_system_message_behavior_custom_tools(): - content = "Hello !" - system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format=ToolPromptFormat.python_list, - system_message_behavior=SystemMessageBehavior.replace, - ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - - assert len(messages) == 2 - assert interleaved_content_as_str(messages[0].content).endswith(system_prompt) - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - assert messages[-1].content == content - - -async def test_replace_system_message_behavior_custom_tools_with_template(): - content = "Hello !" 
- system_prompt = "You are a pirate {{ function_description }}" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format=ToolPromptFormat.python_list, - system_message_behavior=SystemMessageBehavior.replace, - ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - - assert len(messages) == 2 - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - assert "You are a pirate" in interleaved_content_as_str(messages[0].content) - # function description is present in the system prompt - assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content) - assert messages[-1].content == content diff --git a/tests/unit/providers/inline/inference/__init__.py b/tests/unit/providers/inline/inference/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/unit/providers/inline/inference/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/unit/providers/inline/inference/test_meta_reference.py b/tests/unit/providers/inline/inference/test_meta_reference.py new file mode 100644 index 000000000..381836397 --- /dev/null +++ b/tests/unit/providers/inline/inference/test_meta_reference.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from unittest.mock import Mock + +import pytest + +from llama_stack.providers.inline.inference.meta_reference.model_parallel import ( + ModelRunner, +) + + +class TestModelRunner: + """Test ModelRunner task dispatching for model-parallel inference.""" + + def test_chat_completion_task_dispatch(self): + """Verify ModelRunner correctly dispatches chat_completion tasks.""" + # Create a mock generator + mock_generator = Mock() + mock_generator.chat_completion = Mock(return_value=iter([])) + + runner = ModelRunner(mock_generator) + + # Create a chat_completion task + fake_params = {"model": "test"} + fake_messages = [{"role": "user", "content": "test"}] + task = ("chat_completion", [fake_params, fake_messages]) + + # Execute task + runner(task) + + # Verify chat_completion was called with correct arguments + mock_generator.chat_completion.assert_called_once_with(fake_params, fake_messages) + + def test_invalid_task_type_raises_error(self): + """Verify ModelRunner rejects invalid task types.""" + mock_generator = Mock() + runner = ModelRunner(mock_generator) + + with pytest.raises(ValueError, match="Unexpected task type"): + runner(("invalid_task", [])) diff --git a/tests/unit/providers/nvidia/test_safety.py b/tests/unit/providers/nvidia/test_safety.py index 922d7f61f..622302630 100644 --- a/tests/unit/providers/nvidia/test_safety.py +++ b/tests/unit/providers/nvidia/test_safety.py @@ -10,11 +10,13 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from llama_stack.apis.inference import CompletionMessage, UserMessage +from llama_stack.apis.inference import ( + OpenAIAssistantMessageParam, + OpenAIUserMessageParam, +) from llama_stack.apis.resource import ResourceType from llama_stack.apis.safety import RunShieldResponse, ViolationLevel from llama_stack.apis.shields import Shield -from llama_stack.models.llama.datatypes import StopReason from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter @@ -136,11 +138,9 @@ async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post): # Run the shield messages = [ - UserMessage(role="user", content="Hello, how are you?"), - CompletionMessage( - role="assistant", + OpenAIUserMessageParam(content="Hello, how are you?"), + OpenAIAssistantMessageParam( content="I'm doing well, thank you for asking!", - stop_reason=StopReason.end_of_message, tool_calls=[], ), ] @@ -191,13 +191,10 @@ async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post): # Mock Guardrails API response mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}} - # Run the shield messages = [ - UserMessage(role="user", content="Hello, how are you?"), - CompletionMessage( - role="assistant", + OpenAIUserMessageParam(content="Hello, how are you?"), + OpenAIAssistantMessageParam( content="I'm doing well, thank you for asking!", - stop_reason=StopReason.end_of_message, tool_calls=[], ), ] @@ -243,7 +240,7 @@ async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post): adapter.shield_store.get_shield.return_value = None messages = [ - UserMessage(role="user", content="Hello, how are you?"), + OpenAIUserMessageParam(content="Hello, how are you?"), ] with pytest.raises(ValueError): @@ -274,11 +271,9 @@ async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post): # Running the shield should raise an exception messages = [ - UserMessage(role="user", content="Hello, how are 
you?"), - CompletionMessage( - role="assistant", + OpenAIUserMessageParam(content="Hello, how are you?"), + OpenAIAssistantMessageParam( content="I'm doing well, thank you for asking!", - stop_reason=StopReason.end_of_message, tool_calls=[], ), ] diff --git a/tests/unit/providers/utils/inference/test_openai_compat.py b/tests/unit/providers/utils/inference/test_openai_compat.py deleted file mode 100644 index c200c4395..000000000 --- a/tests/unit/providers/utils/inference/test_openai_compat.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import pytest -from pydantic import ValidationError - -from llama_stack.apis.common.content_types import TextContentItem -from llama_stack.apis.inference import ( - CompletionMessage, - OpenAIAssistantMessageParam, - OpenAIChatCompletionContentPartImageParam, - OpenAIChatCompletionContentPartTextParam, - OpenAIDeveloperMessageParam, - OpenAIImageURL, - OpenAISystemMessageParam, - OpenAIToolMessageParam, - OpenAIUserMessageParam, - SystemMessage, - UserMessage, -) -from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall -from llama_stack.providers.utils.inference.openai_compat import ( - convert_message_to_openai_dict, - convert_message_to_openai_dict_new, - openai_messages_to_messages, -) - - -async def test_convert_message_to_openai_dict(): - message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user") - assert await convert_message_to_openai_dict(message) == { - "role": "user", - "content": [{"type": "text", "text": "Hello, world!"}], - } - - -# Test convert_message_to_openai_dict with a tool call -async def test_convert_message_to_openai_dict_with_tool_call(): - message = CompletionMessage( - content="", - tool_calls=[ToolCall(call_id="123", tool_name="test_tool", arguments='{"foo": "bar"}')], - stop_reason=StopReason.end_of_turn, - ) - - openai_dict = await convert_message_to_openai_dict(message) - - assert openai_dict == { - "role": "assistant", - "content": [{"type": "text", "text": ""}], - "tool_calls": [ - {"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}} - ], - } - - -async def test_convert_message_to_openai_dict_with_builtin_tool_call(): - message = CompletionMessage( - content="", - tool_calls=[ - ToolCall( - call_id="123", - tool_name=BuiltinTool.brave_search, - arguments='{"foo": "bar"}', - ) - ], - stop_reason=StopReason.end_of_turn, - ) - - openai_dict = await convert_message_to_openai_dict(message) - - assert openai_dict == { - "role": "assistant", - "content": [{"type": "text", "text": ""}], - "tool_calls": [ - {"id": "123", "type": "function", "function": {"name": "brave_search", "arguments": '{"foo": "bar"}'}} - ], - } - - -async def test_openai_messages_to_messages_with_content_str(): - openai_messages = [ - OpenAISystemMessageParam(content="system message"), - OpenAIUserMessageParam(content="user message"), - OpenAIAssistantMessageParam(content="assistant message"), - ] - - llama_messages = openai_messages_to_messages(openai_messages) - assert len(llama_messages) == 3 - assert isinstance(llama_messages[0], SystemMessage) - assert isinstance(llama_messages[1], UserMessage) - assert isinstance(llama_messages[2], CompletionMessage) - assert llama_messages[0].content == "system message" - assert llama_messages[1].content == "user message" - assert 
llama_messages[2].content == "assistant message" - - -async def test_openai_messages_to_messages_with_content_list(): - openai_messages = [ - OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]), - OpenAIUserMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="user message")]), - OpenAIAssistantMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="assistant message")]), - ] - - llama_messages = openai_messages_to_messages(openai_messages) - assert len(llama_messages) == 3 - assert isinstance(llama_messages[0], SystemMessage) - assert isinstance(llama_messages[1], UserMessage) - assert isinstance(llama_messages[2], CompletionMessage) - assert llama_messages[0].content[0].text == "system message" - assert llama_messages[1].content[0].text == "user message" - assert llama_messages[2].content[0].text == "assistant message" - - -@pytest.mark.parametrize( - "message_class,kwargs", - [ - (OpenAISystemMessageParam, {}), - (OpenAIAssistantMessageParam, {}), - (OpenAIDeveloperMessageParam, {}), - (OpenAIUserMessageParam, {}), - (OpenAIToolMessageParam, {"tool_call_id": "call_123"}), - ], -) -def test_message_accepts_text_string(message_class, kwargs): - """Test that messages accept string text content.""" - msg = message_class(content="Test message", **kwargs) - assert msg.content == "Test message" - - -@pytest.mark.parametrize( - "message_class,kwargs", - [ - (OpenAISystemMessageParam, {}), - (OpenAIAssistantMessageParam, {}), - (OpenAIDeveloperMessageParam, {}), - (OpenAIUserMessageParam, {}), - (OpenAIToolMessageParam, {"tool_call_id": "call_123"}), - ], -) -def test_message_accepts_text_list(message_class, kwargs): - """Test that messages accept list of text content parts.""" - content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")] - msg = message_class(content=content_list, **kwargs) - assert len(msg.content) == 1 - assert msg.content[0].text == "Test message" - - -@pytest.mark.parametrize( - "message_class,kwargs", - [ - (OpenAISystemMessageParam, {}), - (OpenAIAssistantMessageParam, {}), - (OpenAIDeveloperMessageParam, {}), - (OpenAIToolMessageParam, {"tool_call_id": "call_123"}), - ], -) -def test_message_rejects_images(message_class, kwargs): - """Test that system, assistant, developer, and tool messages reject image content.""" - with pytest.raises(ValidationError): - message_class( - content=[ - OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")) - ], - **kwargs, - ) - - -def test_user_message_accepts_images(): - """Test that user messages accept image content (unlike other message types).""" - # List with images should work - msg = OpenAIUserMessageParam( - content=[ - OpenAIChatCompletionContentPartTextParam(text="Describe this image:"), - OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")), - ] - ) - assert len(msg.content) == 2 - assert msg.content[0].text == "Describe this image:" - assert msg.content[1].image_url.url == "http://example.com/image.jpg" - - -async def test_convert_message_to_openai_dict_new_user_message(): - """Test convert_message_to_openai_dict_new with UserMessage.""" - message = UserMessage(content="Hello, world!", role="user") - result = await convert_message_to_openai_dict_new(message) - - assert result["role"] == "user" - assert result["content"] == "Hello, world!" 
-
-
-async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
-    """Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
-    message = CompletionMessage(
-        content="I'll help you find the weather.",
-        tool_calls=[
-            ToolCall(
-                call_id="call_123",
-                tool_name="get_weather",
-                arguments='{"city": "Sligo"}',
-            )
-        ],
-        stop_reason=StopReason.end_of_turn,
-    )
-    result = await convert_message_to_openai_dict_new(message)
-
-    # This would have failed with "Cannot instantiate typing.Union" before the fix
-    assert result["role"] == "assistant"
-    assert result["content"] == "I'll help you find the weather."
-    assert "tool_calls" in result
-    assert result["tool_calls"] is not None
-    assert len(result["tool_calls"]) == 1
-
-    tool_call = result["tool_calls"][0]
-    assert tool_call.id == "call_123"
-    assert tool_call.type == "function"
-    assert tool_call.function.name == "get_weather"
-    assert tool_call.function.arguments == '{"city": "Sligo"}'
diff --git a/tests/unit/providers/utils/inference/test_prompt_adapter.py b/tests/unit/providers/utils/inference/test_prompt_adapter.py
new file mode 100644
index 000000000..62c8db74d
--- /dev/null
+++ b/tests/unit/providers/utils/inference/test_prompt_adapter.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.inference import (
+    OpenAIAssistantMessageParam,
+    OpenAIUserMessageParam,
+)
+from llama_stack.models.llama.datatypes import RawTextItem
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_openai_message_to_raw_message,
+)
+
+
+class TestConvertOpenAIMessageToRawMessage:
+    """Test conversion of OpenAI message types to RawMessage format."""
+
+    async def test_user_message_conversion(self):
+        msg = OpenAIUserMessageParam(role="user", content="Hello world")
+        raw_msg = await convert_openai_message_to_raw_message(msg)
+
+        assert raw_msg.role == "user"
+        assert isinstance(raw_msg.content, RawTextItem)
+        assert raw_msg.content.text == "Hello world"
+
+    async def test_assistant_message_conversion(self):
+        msg = OpenAIAssistantMessageParam(role="assistant", content="Hi there!")
+        raw_msg = await convert_openai_message_to_raw_message(msg)
+
+        assert raw_msg.role == "assistant"
+        assert isinstance(raw_msg.content, RawTextItem)
+        assert raw_msg.content.text == "Hi there!"
+        assert raw_msg.tool_calls == []

From 97ccfb5e626919956aee0bf0e890a7196e8af6a2 Mon Sep 17 00:00:00 2001
From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com>
Date: Mon, 10 Nov 2025 18:57:17 -0500
Subject: [PATCH 2/2] refactor: inspect routes now shows all non-deprecated APIs (#4116)

# What does this PR do?
The inspect API lacked any mechanism to list all non-deprecated APIs (v1, v1alpha, v1beta).
This change makes that the default behavior; the 'v1' filter can still be used by users who want only the stable APIs.

## Test Plan
1. pull the PR
2. launch a LLS server
3. run `curl http://beanlab3.bss.redhat.com:8321/v1/inspect/routes`
4.
note there are APIs for `v1`, `v1alpha`, and `v1beta` but no deprecated APIs Signed-off-by: Nathan Weinberg --- client-sdks/stainless/openapi.yml | 2 +- docs/static/llama-stack-spec.yaml | 2 +- docs/static/stainless-llama-stack-spec.yaml | 2 +- src/llama_stack/apis/inspect/inspect.py | 2 +- src/llama_stack/core/inspect.py | 5 ++--- 5 files changed, 6 insertions(+), 7 deletions(-) diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index 58ebaa8ae..9f3ef15b5 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -963,7 +963,7 @@ paths: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, - returns only non-deprecated v1 routes. + returns all non-deprecated routes. required: false schema: type: string diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index 135ae910f..ce8708b68 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -960,7 +960,7 @@ paths: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, - returns only non-deprecated v1 routes. + returns all non-deprecated routes. required: false schema: type: string diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 58ebaa8ae..9f3ef15b5 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -963,7 +963,7 @@ paths: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, - returns only non-deprecated v1 routes. + returns all non-deprecated routes. required: false schema: type: string diff --git a/src/llama_stack/apis/inspect/inspect.py b/src/llama_stack/apis/inspect/inspect.py index 4e0e2548b..235abb124 100644 --- a/src/llama_stack/apis/inspect/inspect.py +++ b/src/llama_stack/apis/inspect/inspect.py @@ -76,7 +76,7 @@ class Inspect(Protocol): List all available API routes with their methods and implementing providers. - :param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes. + :param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns all non-deprecated routes. :returns: Response containing information about all available routes. """ ... 
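
To make the new default concrete, here is a minimal client-side sketch of how the filter documented above could be exercised over HTTP. The `httpx` client, the local base URL, and the assumption that the `api_filter` method parameter is exposed as a query parameter of the same name are all illustrative additions, not taken from this patch.

```python
import httpx  # any HTTP client works; httpx is only an example

BASE_URL = "http://localhost:8321"  # assumed local Llama Stack server

# Default (no filter): all non-deprecated routes across v1, v1alpha, and v1beta
all_routes = httpx.get(f"{BASE_URL}/v1/inspect/routes").json()

# Stable (v1) routes only
stable_routes = httpx.get(f"{BASE_URL}/v1/inspect/routes", params={"api_filter": "v1"}).json()

# Deprecated routes across all API levels
deprecated_routes = httpx.get(f"{BASE_URL}/v1/inspect/routes", params={"api_filter": "deprecated"}).json()
```
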
diff --git a/src/llama_stack/core/inspect.py b/src/llama_stack/core/inspect.py index 6352af00f..07b51128f 100644 --- a/src/llama_stack/core/inspect.py +++ b/src/llama_stack/core/inspect.py @@ -15,7 +15,6 @@ from llama_stack.apis.inspect import ( RouteInfo, VersionInfo, ) -from llama_stack.apis.version import LLAMA_STACK_API_V1 from llama_stack.core.datatypes import StackRunConfig from llama_stack.core.external import load_external_apis from llama_stack.core.server.routes import get_all_api_routes @@ -46,8 +45,8 @@ class DistributionInspectImpl(Inspect): # Helper function to determine if a route should be included based on api_filter def should_include_route(webmethod) -> bool: if api_filter is None: - # Default: only non-deprecated v1 APIs - return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1 + # Default: only non-deprecated APIs + return not webmethod.deprecated elif api_filter == "deprecated": # Special filter: show deprecated routes regardless of their actual level return bool(webmethod.deprecated)
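
Putting the hunks above together, the route filtering reduces to a small predicate. The sketch below restates it with the API-level branch spelled out; unlike the real code, `api_filter` is passed explicitly and `webmethod` is a stand-in object with `deprecated` and `level` attributes, so treat this as an illustration of the intended behavior rather than the literal implementation.

```python
def should_include_route(webmethod, api_filter: str | None) -> bool:
    """Decide whether a route appears in the /v1/inspect/routes listing."""
    if api_filter is None:
        # New default: every non-deprecated route, regardless of API level
        return not webmethod.deprecated
    if api_filter == "deprecated":
        # Special filter: deprecated routes across all levels
        return bool(webmethod.deprecated)
    # Otherwise the filter is an API level ('v1', 'v1alpha', 'v1beta'):
    # show only non-deprecated routes at that level
    return not webmethod.deprecated and webmethod.level == api_filter
```
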