diff --git a/src/llama_stack/apis/inference/event_logger.py b/src/llama_stack/apis/inference/event_logger.py deleted file mode 100644 index d97ece6d4..000000000 --- a/src/llama_stack/apis/inference/event_logger.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from termcolor import cprint - -from llama_stack.apis.inference import ( - ChatCompletionResponseEventType, - ChatCompletionResponseStreamChunk, -) - - -class LogEvent: - def __init__( - self, - content: str = "", - end: str = "\n", - color="white", - ): - self.content = content - self.color = color - self.end = "\n" if end is None else end - - def print(self, flush=True): - cprint(f"{self.content}", color=self.color, end=self.end, flush=flush) - - -class EventLogger: - async def log(self, event_generator): - async for chunk in event_generator: - if isinstance(chunk, ChatCompletionResponseStreamChunk): - event = chunk.event - if event.event_type == ChatCompletionResponseEventType.start: - yield LogEvent("Assistant> ", color="cyan", end="") - elif event.event_type == ChatCompletionResponseEventType.progress: - yield LogEvent(event.delta, color="yellow", end="") - elif event.event_type == ChatCompletionResponseEventType.complete: - yield LogEvent("") - else: - yield LogEvent("Assistant> ", color="cyan", end="") - yield LogEvent(chunk.completion_message.content, color="yellow") diff --git a/src/llama_stack/apis/inference/inference.py b/src/llama_stack/apis/inference/inference.py index 1a865ce5f..9f04917c9 100644 --- a/src/llama_stack/apis/inference/inference.py +++ b/src/llama_stack/apis/inference/inference.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from collections.abc import AsyncIterator -from enum import Enum +from enum import Enum, StrEnum from typing import ( Annotated, Any, @@ -15,28 +15,18 @@ from typing import ( ) from fastapi import Body -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from typing_extensions import TypedDict -from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent -from llama_stack.apis.common.responses import MetricResponseMixin, Order +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.responses import ( + Order, +) from llama_stack.apis.common.tracing import telemetry_traceable from llama_stack.apis.models import Model from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA -from llama_stack.models.llama.datatypes import ( - BuiltinTool, - StopReason, - ToolCall, - ToolDefinition, - ToolPromptFormat, -) from llama_stack.schema_utils import json_schema_type, register_schema, webmethod -register_schema(ToolCall) -register_schema(ToolDefinition) - -from enum import StrEnum - @json_schema_type class GreedySamplingStrategy(BaseModel): @@ -201,58 +191,6 @@ class ToolResponseMessage(BaseModel): content: InterleavedContent -@json_schema_type -class CompletionMessage(BaseModel): - """A message containing the model's (assistant) response in a chat conversation. - - :param role: Must be "assistant" to identify this as the model's response - :param content: The content of the model's response - :param stop_reason: Reason why the model stopped generating. Options are: - - `StopReason.end_of_turn`: The model finished generating the entire response. 
- - `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response. - - `StopReason.out_of_tokens`: The model ran out of token budget. - :param tool_calls: List of tool calls. Each tool call is a ToolCall object. - """ - - role: Literal["assistant"] = "assistant" - content: InterleavedContent - stop_reason: StopReason - tool_calls: list[ToolCall] | None = Field(default_factory=lambda: []) - - -Message = Annotated[ - UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage, - Field(discriminator="role"), -] -register_schema(Message, name="Message") - - -@json_schema_type -class ToolResponse(BaseModel): - """Response from a tool invocation. - - :param call_id: Unique identifier for the tool call this response is for - :param tool_name: Name of the tool that was invoked - :param content: The response content from the tool - :param metadata: (Optional) Additional metadata about the tool response - """ - - call_id: str - tool_name: BuiltinTool | str - content: InterleavedContent - metadata: dict[str, Any] | None = None - - @field_validator("tool_name", mode="before") - @classmethod - def validate_field(cls, v): - if isinstance(v, str): - try: - return BuiltinTool(v) - except ValueError: - return v - return v - - class ToolChoice(Enum): """Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model. @@ -289,22 +227,6 @@ class ChatCompletionResponseEventType(Enum): progress = "progress" -@json_schema_type -class ChatCompletionResponseEvent(BaseModel): - """An event during chat completion generation. - - :param event_type: Type of the event - :param delta: Content generated since last event. This can be one or more tokens, or a tool call. - :param logprobs: Optional log probabilities for generated tokens - :param stop_reason: Optional reason why generation stopped, if complete - """ - - event_type: ChatCompletionResponseEventType - delta: ContentDelta - logprobs: list[TokenLogProbs] | None = None - stop_reason: StopReason | None = None - - class ResponseFormatType(StrEnum): """Types of formats for structured (guided) decoding. @@ -357,34 +279,6 @@ class CompletionRequest(BaseModel): logprobs: LogProbConfig | None = None -@json_schema_type -class CompletionResponse(MetricResponseMixin): - """Response from a completion request. - - :param content: The generated completion text - :param stop_reason: Reason why generation stopped - :param logprobs: Optional log probabilities for generated tokens - """ - - content: str - stop_reason: StopReason - logprobs: list[TokenLogProbs] | None = None - - -@json_schema_type -class CompletionResponseStreamChunk(MetricResponseMixin): - """A chunk of a streamed completion response. - - :param delta: New content generated since last chunk. This can be one or more tokens. - :param stop_reason: Optional reason why generation stopped, if complete - :param logprobs: Optional log probabilities for generated tokens - """ - - delta: str - stop_reason: StopReason | None = None - logprobs: list[TokenLogProbs] | None = None - - class SystemMessageBehavior(Enum): """Config for how to override the default system prompt. @@ -398,70 +292,6 @@ class SystemMessageBehavior(Enum): replace = "replace" -@json_schema_type -class ToolConfig(BaseModel): - """Configuration for tool use. 
- - :param tool_choice: (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto. - :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls. - :param system_message_behavior: (Optional) Config for how to override the default system prompt. - - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string - '{{function_definitions}}' to indicate where the function definitions should be inserted. - """ - - tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto) - tool_prompt_format: ToolPromptFormat | None = Field(default=None) - system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append) - - def model_post_init(self, __context: Any) -> None: - if isinstance(self.tool_choice, str): - try: - self.tool_choice = ToolChoice[self.tool_choice] - except KeyError: - pass - - -# This is an internally used class -@json_schema_type -class ChatCompletionRequest(BaseModel): - model: str - messages: list[Message] - sampling_params: SamplingParams | None = Field(default_factory=SamplingParams) - - tools: list[ToolDefinition] | None = Field(default_factory=lambda: []) - tool_config: ToolConfig | None = Field(default_factory=ToolConfig) - - response_format: ResponseFormat | None = None - stream: bool | None = False - logprobs: LogProbConfig | None = None - - -@json_schema_type -class ChatCompletionResponseStreamChunk(MetricResponseMixin): - """A chunk of a streamed chat completion response. - - :param event: The event containing the new content - """ - - event: ChatCompletionResponseEvent - - -@json_schema_type -class ChatCompletionResponse(MetricResponseMixin): - """Response from a chat completion request. - - :param completion_message: The complete response message - :param logprobs: Optional log probabilities for generated tokens - """ - - completion_message: CompletionMessage - logprobs: list[TokenLogProbs] | None = None - - @json_schema_type class EmbeddingsResponse(BaseModel): """Response containing generated embeddings. 
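The bespoke chat-completion surface deleted above (`ChatCompletionRequest`, `CompletionMessage`, `ToolConfig`, and the event/stream chunk types) is superseded by the OpenAI-compatible models the rest of this diff adopts. A minimal sketch of how a caller now builds a request, mirroring the warm-up call added later in `meta_reference/inference.py`; the model id is a placeholder and `provider` is a hypothetical `InferenceProvider` instance, not something defined in this diff:

```python
from llama_stack.apis.inference import (
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAIUserMessageParam,
)

# OpenAI-compatible request that replaces the removed ChatCompletionRequest.
params = OpenAIChatCompletionRequestWithExtraBody(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")],
    max_tokens=20,
)

# `provider` is a hypothetical InferenceProvider; the single entry point is now
# openai_chat_completion(params), returning an OpenAIChatCompletion (or an
# async iterator of chunks when params.stream is set).
# response = await provider.openai_chat_completion(params)
# print(response.choices[0].message.content)
```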
diff --git a/src/llama_stack/core/routers/safety.py b/src/llama_stack/core/routers/safety.py index 79eac8b46..e5ff2ada9 100644 --- a/src/llama_stack/core/routers/safety.py +++ b/src/llama_stack/core/routers/safety.py @@ -6,7 +6,7 @@ from typing import Any -from llama_stack.apis.inference import Message +from llama_stack.apis.inference import OpenAIMessageParam from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.shields import Shield @@ -52,7 +52,7 @@ class SafetyRouter(Safety): async def run_shield( self, shield_id: str, - messages: list[Message], + messages: list[OpenAIMessageParam], params: dict[str, Any] = None, ) -> RunShieldResponse: logger.debug(f"SafetyRouter.run_shield: {shield_id}") diff --git a/src/llama_stack/models/llama/llama3/generation.py b/src/llama_stack/models/llama/llama3/generation.py index fe7be5ea9..9ac215c3b 100644 --- a/src/llama_stack/models/llama/llama3/generation.py +++ b/src/llama_stack/models/llama/llama3/generation.py @@ -26,8 +26,10 @@ from fairscale.nn.model_parallel.initialize import ( ) from termcolor import cprint +from llama_stack.models.llama.datatypes import ToolPromptFormat + from ..checkpoint import maybe_reshard_state_dict -from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat +from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage from .args import ModelArgs from .chat_format import ChatFormat, LLMInput from .model import Transformer diff --git a/src/llama_stack/models/llama/llama3/interface.py b/src/llama_stack/models/llama/llama3/interface.py index b63ba4847..89be31a55 100644 --- a/src/llama_stack/models/llama/llama3/interface.py +++ b/src/llama_stack/models/llama/llama3/interface.py @@ -15,13 +15,10 @@ from pathlib import Path from termcolor import colored +from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall, ToolDefinition, ToolPromptFormat + from ..datatypes import ( - BuiltinTool, RawMessage, - StopReason, - ToolCall, - ToolDefinition, - ToolPromptFormat, ) from . 
import template_data from .chat_format import ChatFormat diff --git a/src/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py b/src/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py index 11a5993e9..3fbaa103e 100644 --- a/src/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py +++ b/src/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py @@ -15,7 +15,7 @@ import textwrap from datetime import datetime from typing import Any -from llama_stack.apis.inference import ( +from llama_stack.models.llama.datatypes import ( BuiltinTool, ToolDefinition, ) diff --git a/src/llama_stack/models/llama/llama3/tool_utils.py b/src/llama_stack/models/llama/llama3/tool_utils.py index 8c12fe680..6f919e1fa 100644 --- a/src/llama_stack/models/llama/llama3/tool_utils.py +++ b/src/llama_stack/models/llama/llama3/tool_utils.py @@ -8,8 +8,9 @@ import json import re from llama_stack.log import get_logger +from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolPromptFormat -from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat +from ..datatypes import RecursiveType logger = get_logger(name=__name__, category="models::llama") diff --git a/src/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py b/src/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py index 1ee570933..feded9f8c 100644 --- a/src/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py +++ b/src/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py @@ -13,7 +13,7 @@ import textwrap -from llama_stack.apis.inference import ToolDefinition +from llama_stack.models.llama.datatypes import ToolDefinition from llama_stack.models.llama.llama3.prompt_templates.base import ( PromptTemplate, PromptTemplateGeneratorBase, diff --git a/src/llama_stack/providers/inline/inference/meta_reference/generators.py b/src/llama_stack/providers/inline/inference/meta_reference/generators.py index cb926f529..51a2ddfad 100644 --- a/src/llama_stack/providers/inline/inference/meta_reference/generators.py +++ b/src/llama_stack/providers/inline/inference/meta_reference/generators.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
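The hunks above move `BuiltinTool`, `ToolCall`, `ToolDefinition`, and `ToolPromptFormat` from `llama_stack.apis.inference` to `llama_stack.models.llama.datatypes`. A small sketch of the new import path, building a `ToolDefinition` the same way `_convert_openai_tool_to_tool_definition` does later in this diff; the weather tool itself is purely illustrative:

```python
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat

# Tool datatypes are now imported from models.llama.datatypes, not apis.inference.
get_weather = ToolDefinition(
    tool_name="get_weather",  # illustrative tool, not part of this diff
    description="Look up the current weather for a city.",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
)

# Prompt formats live alongside the tool types.
fmt = ToolPromptFormat.json
```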
import math -from collections.abc import Generator from typing import Optional import torch @@ -14,21 +13,19 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken from llama_stack.apis.inference import ( GreedySamplingStrategy, JsonSchemaResponseFormat, + OpenAIChatCompletionRequestWithExtraBody, + OpenAIResponseFormatJSONSchema, ResponseFormat, + ResponseFormatType, SamplingParams, TopPSamplingStrategy, ) -from llama_stack.models.llama.datatypes import QuantizationMode +from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat from llama_stack.models.llama.llama3.generation import Llama3 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama4.generation import Llama4 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer from llama_stack.models.llama.sku_types import Model, ModelFamily -from llama_stack.providers.utils.inference.prompt_adapter import ( - ChatCompletionRequestWithRawContent, - CompletionRequestWithRawContent, - get_default_tool_prompt_format, -) from .common import model_checkpoint_dir from .config import MetaReferenceInferenceConfig @@ -106,14 +103,6 @@ def _infer_sampling_params(sampling_params: SamplingParams): return temperature, top_p -def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent): - tool_config = request.tool_config - if tool_config is not None and tool_config.tool_prompt_format is not None: - return tool_config.tool_prompt_format - else: - return get_default_tool_prompt_format(request.model) - - class LlamaGenerator: def __init__( self, @@ -157,55 +146,56 @@ class LlamaGenerator: self.args = self.inner_generator.args self.formatter = self.inner_generator.formatter - def completion( - self, - request_batch: list[CompletionRequestWithRawContent], - ) -> Generator: - first_request = request_batch[0] - sampling_params = first_request.sampling_params or SamplingParams() - max_gen_len = sampling_params.max_tokens - if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: - max_gen_len = self.args.max_seq_len - 1 - - temperature, top_p = _infer_sampling_params(sampling_params) - yield from self.inner_generator.generate( - llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch], - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - logprobs=bool(first_request.logprobs), - echo=False, - logits_processor=get_logits_processor( - self.tokenizer, - self.args.vocab_size, - first_request.response_format, - ), - ) - def chat_completion( self, - request_batch: list[ChatCompletionRequestWithRawContent], - ) -> Generator: - first_request = request_batch[0] - sampling_params = first_request.sampling_params or SamplingParams() + request: OpenAIChatCompletionRequestWithExtraBody, + raw_messages: list, + ): + """Generate chat completion using OpenAI request format. 
+ + Args: + request: OpenAI chat completion request + raw_messages: Pre-converted list of RawMessage objects + """ + + # Determine tool prompt format + tool_prompt_format = ToolPromptFormat.json if request.tools else ToolPromptFormat.json + + # Prepare sampling params + sampling_params = SamplingParams() + if request.temperature is not None or request.top_p is not None: + sampling_params.strategy = TopPSamplingStrategy( + temperature=request.temperature if request.temperature is not None else 1.0, + top_p=request.top_p if request.top_p is not None else 1.0, + ) + if request.max_tokens: + sampling_params.max_tokens = request.max_tokens + max_gen_len = sampling_params.max_tokens if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: max_gen_len = self.args.max_seq_len - 1 temperature, top_p = _infer_sampling_params(sampling_params) + + # Get logits processor for response format + logits_processor = None + if request.response_format: + if isinstance(request.response_format, OpenAIResponseFormatJSONSchema): + # Extract the actual schema from OpenAIJSONSchema TypedDict + schema_dict = request.response_format.json_schema.get("schema") or {} + json_schema_format = JsonSchemaResponseFormat( + type=ResponseFormatType.json_schema, + json_schema=schema_dict, + ) + logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format) + + # Generate yield from self.inner_generator.generate( - llm_inputs=[ - self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)) - for request in request_batch - ], + llm_inputs=[self.formatter.encode_dialog_prompt(raw_messages, tool_prompt_format)], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, - logprobs=bool(first_request.logprobs), + logprobs=False, echo=False, - logits_processor=get_logits_processor( - self.tokenizer, - self.args.vocab_size, - first_request.response_format, - ), + logits_processor=logits_processor, ) diff --git a/src/llama_stack/providers/inline/inference/meta_reference/inference.py b/src/llama_stack/providers/inline/inference/meta_reference/inference.py index 76d3fdd50..ef21132a0 100644 --- a/src/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/src/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -5,12 +5,19 @@ # the root directory of this source tree. 
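The `LlamaGenerator.chat_completion` rewrite above folds the OpenAI request fields (`temperature`, `top_p`, `max_tokens`) into the existing `SamplingParams` machinery. A standalone sketch of that mapping, pulled out here for clarity; the helper name is ours, not part of the diff:

```python
from llama_stack.apis.inference import SamplingParams, TopPSamplingStrategy


def sampling_params_from_openai(
    temperature: float | None,
    top_p: float | None,
    max_tokens: int | None,
) -> SamplingParams:
    """Mirror of the mapping inside LlamaGenerator.chat_completion (sketch)."""
    params = SamplingParams()
    if temperature is not None or top_p is not None:
        params.strategy = TopPSamplingStrategy(
            temperature=temperature if temperature is not None else 1.0,
            top_p=top_p if top_p is not None else 1.0,
        )
    if max_tokens:
        params.max_tokens = max_tokens
    return params


# e.g. an OpenAI request carrying temperature=0.7 and max_tokens=128:
sp = sampling_params_from_openai(0.7, None, 128)
```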
import asyncio +import time +import uuid from collections.abc import AsyncIterator from llama_stack.apis.inference import ( InferenceProvider, + OpenAIAssistantMessageParam, OpenAIChatCompletionRequestWithExtraBody, + OpenAIChatCompletionUsage, + OpenAIChoice, OpenAICompletionRequestWithExtraBody, + OpenAIUserMessageParam, + ToolChoice, ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, @@ -19,12 +26,20 @@ from llama_stack.apis.inference.inference import ( ) from llama_stack.apis.models import Model, ModelType from llama_stack.log import get_logger +from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat +from llama_stack.models.llama.llama3.prompt_templates import ( + JsonCustomToolGenerator, + SystemDefaultGenerator, +) from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat +from llama_stack.models.llama.llama4.prompt_templates.system_prompts import ( + PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4, +) from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer from llama_stack.models.llama.sku_list import resolve_model -from llama_stack.models.llama.sku_types import ModelFamily +from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, @@ -44,6 +59,170 @@ log = get_logger(__name__, category="inference") SEMAPHORE = asyncio.Semaphore(1) +def _convert_openai_tool_to_tool_definition(tool) -> ToolDefinition: + """Convert OpenAI tool format to ToolDefinition format.""" + # OpenAI tools have function.name and function.parameters + return ToolDefinition( + tool_name=tool.function.name, + description=tool.function.description or "", + parameters=tool.function.parameters or {}, + ) + + +def _get_tool_choice_prompt(tool_choice, tools) -> str: + """Generate prompt text for tool_choice behavior.""" + if not tool_choice or tool_choice == ToolChoice.auto or tool_choice == "auto": + return "" + elif tool_choice == ToolChoice.required or tool_choice == "required": + return "You MUST use one of the provided functions/tools to answer the user query." + elif tool_choice == ToolChoice.none or tool_choice == "none": + return "" + else: + # Specific tool specified + return f"You MUST use the tool `{tool_choice}` to answer the user query." 
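A short usage sketch of the two helpers just added. The `SimpleNamespace` tool is a stand-in for whatever tool object the OpenAI request carries (only `.function.name`, `.function.description`, and `.function.parameters` are read); running it assumes the meta-reference provider's dependencies are installed:

```python
from types import SimpleNamespace

from llama_stack.providers.inline.inference.meta_reference.inference import (
    _convert_openai_tool_to_tool_definition,
    _get_tool_choice_prompt,
)

# Stand-in for an OpenAI tool entry on the request (illustrative only).
tool = SimpleNamespace(
    function=SimpleNamespace(
        name="get_weather",
        description="Look up the current weather for a city.",
        parameters={"type": "object", "properties": {"city": {"type": "string"}}},
    )
)

tool_def = _convert_openai_tool_to_tool_definition(tool)
print(tool_def.tool_name)  # -> get_weather

print(_get_tool_choice_prompt("required", [tool]))
# -> You MUST use one of the provided functions/tools to answer the user query.
print(repr(_get_tool_choice_prompt("auto", [tool])))
# -> '' (auto adds no extra instruction)
```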
+ + +def _raw_content_as_str(content) -> str: + """Convert RawContent to string for system messages.""" + if isinstance(content, str): + return content + elif isinstance(content, RawTextItem): + return content.text + elif isinstance(content, list): + return "\n".join(_raw_content_as_str(c) for c in content) + else: + return "" + + +def _augment_raw_messages_for_tools_llama_3_1( + raw_messages: list[RawMessage], + tools: list, + tool_choice, +) -> list[RawMessage]: + """Augment raw messages with tool definitions for Llama 3.1 style models.""" + messages = raw_messages.copy() + existing_system_message = None + if messages and messages[0].role == "system": + existing_system_message = messages.pop(0) + + sys_content = "" + + # Add tool definitions first (if present) + if tools: + # Convert OpenAI tools to ToolDefinitions + tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools] + + # For OpenAI format, all tools are custom (have string names) + tool_gen = JsonCustomToolGenerator() + tool_template = tool_gen.gen(tool_definitions) + sys_content += tool_template.render() + sys_content += "\n" + + # Add default system prompt + default_gen = SystemDefaultGenerator() + default_template = default_gen.gen() + sys_content += default_template.render() + + # Add existing system message if present + if existing_system_message: + sys_content += "\n" + _raw_content_as_str(existing_system_message.content) + + # Add tool choice prompt if needed + if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools): + sys_content += "\n" + tool_choice_prompt + + # Create new system message + new_system_message = RawMessage( + role="system", + content=[RawTextItem(text=sys_content.strip())], + ) + + return [new_system_message] + messages + + +def _augment_raw_messages_for_tools_llama_4( + raw_messages: list[RawMessage], + tools: list, + tool_choice, +) -> list[RawMessage]: + """Augment raw messages with tool definitions for Llama 4/3.2/3.3 style models.""" + messages = raw_messages.copy() + existing_system_message = None + if messages and messages[0].role == "system": + existing_system_message = messages.pop(0) + + sys_content = "" + + # Add tool definitions if present + if tools: + # Convert OpenAI tools to ToolDefinitions + tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools] + + # Use python_list format for Llama 4 + tool_gen = PythonListCustomToolGeneratorLlama4() + system_prompt = None + if existing_system_message: + system_prompt = _raw_content_as_str(existing_system_message.content) + + tool_template = tool_gen.gen(tool_definitions, system_prompt) + sys_content = tool_template.render() + elif existing_system_message: + # No tools, just use existing system message + sys_content = _raw_content_as_str(existing_system_message.content) + + # Add tool choice prompt if needed + if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools): + sys_content += "\n" + tool_choice_prompt + + if sys_content: + new_system_message = RawMessage( + role="system", + content=[RawTextItem(text=sys_content.strip())], + ) + return [new_system_message] + messages + + return messages + + +def augment_raw_messages_for_tools( + raw_messages: list[RawMessage], + params: OpenAIChatCompletionRequestWithExtraBody, + llama_model, +) -> list[RawMessage]: + """Augment raw messages with tool definitions based on model family.""" + if not params.tools: + return raw_messages + + # Determine augmentation strategy based on model family + if llama_model.model_family == 
ModelFamily.llama3_1 or ( + llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id) + ): + # Llama 3.1 and Llama 3.2 multimodal use JSON format + return _augment_raw_messages_for_tools_llama_3_1( + raw_messages, + params.tools, + params.tool_choice, + ) + elif llama_model.model_family in ( + ModelFamily.llama3_2, + ModelFamily.llama3_3, + ModelFamily.llama4, + ): + # Llama 3.2/3.3/4 use python_list format + return _augment_raw_messages_for_tools_llama_4( + raw_messages, + params.tools, + params.tool_choice, + ) + else: + # Default to Llama 3.1 style + return _augment_raw_messages_for_tools_llama_3_1( + raw_messages, + params.tools, + params.tool_choice, + ) + + def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator: return LlamaGenerator(config, model_id, llama_model) @@ -136,10 +315,13 @@ class MetaReferenceInferenceImpl( self.llama_model = llama_model log.info("Warming up...") + await self.openai_chat_completion( - model=model_id, - messages=[{"role": "user", "content": "Hi how are you?"}], - max_tokens=20, + params=OpenAIChatCompletionRequestWithExtraBody( + model=model_id, + messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")], + max_tokens=20, + ) ) log.info("Warmed up!") @@ -155,4 +337,207 @@ class MetaReferenceInferenceImpl( self, params: OpenAIChatCompletionRequestWithExtraBody, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider") + self.check_model(params) + + # Convert OpenAI messages to RawMessages + from llama_stack.models.llama.datatypes import StopReason + from llama_stack.providers.utils.inference.prompt_adapter import ( + convert_openai_message_to_raw_message, + decode_assistant_message, + ) + + raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages] + + # Augment messages with tool definitions if tools are present + raw_messages = augment_raw_messages_for_tools(raw_messages, params, self.llama_model) + + # Call generator's chat_completion method (works for both single-GPU and model-parallel) + if isinstance(self.generator, LlamaGenerator): + generator = self.generator.chat_completion(params, raw_messages) + else: + # Model parallel: submit task to process group + generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages])) + + # Check if streaming is requested + if params.stream: + return self._stream_chat_completion(generator, params) + + # Non-streaming: collect all generated text + generated_text = "" + for result_batch in generator: + for result in result_batch: + if not result.ignore_token and result.source == "output": + generated_text += result.text + + # Decode assistant message to extract tool calls and determine stop_reason + # Default to end_of_turn if generation completed normally + decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn) + + # Convert tool calls to OpenAI format + openai_tool_calls = None + if decoded_message.tool_calls: + from llama_stack.apis.inference import ( + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + ) + + openai_tool_calls = [ + OpenAIChatCompletionToolCall( + # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative. 
+ id=f"call_{uuid.uuid4().hex[:24]}", + type="function", + function=OpenAIChatCompletionToolCallFunction( + name=str(tc.tool_name), + arguments=tc.arguments, + ), + ) + for tc in decoded_message.tool_calls + ] + + # Determine finish_reason based on whether tool calls are present + finish_reason = "tool_calls" if openai_tool_calls else "stop" + + # Extract content from decoded message + content = "" + if isinstance(decoded_message.content, str): + content = decoded_message.content + elif isinstance(decoded_message.content, list): + for item in decoded_message.content: + if isinstance(item, RawTextItem): + content += item.text + + # Create OpenAI response + # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative. + response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" + created = int(time.time()) + + return OpenAIChatCompletion( + id=response_id, + object="chat.completion", + created=created, + model=params.model, + choices=[ + OpenAIChoice( + index=0, + message=OpenAIAssistantMessageParam( + role="assistant", + content=content, + tool_calls=openai_tool_calls, + ), + finish_reason=finish_reason, + logprobs=None, + ) + ], + usage=OpenAIChatCompletionUsage( + prompt_tokens=0, # TODO: calculate properly + completion_tokens=0, # TODO: calculate properly + total_tokens=0, # TODO: calculate properly + ), + ) + + async def _stream_chat_completion( + self, + generator, + params: OpenAIChatCompletionRequestWithExtraBody, + ) -> AsyncIterator[OpenAIChatCompletionChunk]: + """Stream chat completion chunks as they're generated.""" + from llama_stack.apis.inference import ( + OpenAIChatCompletionChunk, + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + OpenAIChoiceDelta, + OpenAIChunkChoice, + ) + from llama_stack.models.llama.datatypes import StopReason + from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message + + response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" + created = int(time.time()) + generated_text = "" + + # Yield chunks as tokens are generated + for result_batch in generator: + for result in result_batch: + if result.ignore_token or result.source != "output": + continue + + generated_text += result.text + + # Yield delta chunk with the new text + chunk = OpenAIChatCompletionChunk( + id=response_id, + object="chat.completion.chunk", + created=created, + model=params.model, + choices=[ + OpenAIChunkChoice( + index=0, + delta=OpenAIChoiceDelta( + role="assistant", + content=result.text, + ), + finish_reason="", + logprobs=None, + ) + ], + ) + yield chunk + + # After generation completes, decode the full message to extract tool calls + decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn) + + # If tool calls are present, yield a final chunk with tool_calls + if decoded_message.tool_calls: + openai_tool_calls = [ + OpenAIChatCompletionToolCall( + # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative. 
+ id=f"call_{uuid.uuid4().hex[:24]}", + type="function", + function=OpenAIChatCompletionToolCallFunction( + name=str(tc.tool_name), + arguments=tc.arguments, + ), + ) + for tc in decoded_message.tool_calls + ] + + # Yield chunk with tool_calls + chunk = OpenAIChatCompletionChunk( + id=response_id, + object="chat.completion.chunk", + created=created, + model=params.model, + choices=[ + OpenAIChunkChoice( + index=0, + delta=OpenAIChoiceDelta( + role="assistant", + tool_calls=openai_tool_calls, + ), + finish_reason="", + logprobs=None, + ) + ], + ) + yield chunk + + finish_reason = "tool_calls" + else: + finish_reason = "stop" + + # Yield final chunk with finish_reason + final_chunk = OpenAIChatCompletionChunk( + id=response_id, + object="chat.completion.chunk", + created=created, + model=params.model, + choices=[ + OpenAIChunkChoice( + index=0, + delta=OpenAIChoiceDelta(), + finish_reason=finish_reason, + logprobs=None, + ) + ], + ) + yield final_chunk diff --git a/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py index 9d0295d65..f50b41f34 100644 --- a/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py +++ b/src/llama_stack/providers/inline/inference/meta_reference/model_parallel.py @@ -4,17 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from collections.abc import Callable, Generator -from copy import deepcopy +from collections.abc import Callable from functools import partial from typing import Any from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat -from llama_stack.providers.utils.inference.prompt_adapter import ( - ChatCompletionRequestWithRawContent, - CompletionRequestWithRawContent, -) from .parallel_utils import ModelParallelProcessGroup @@ -23,12 +18,14 @@ class ModelRunner: def __init__(self, llama): self.llama = llama - # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()` def __call__(self, task: Any): - if task[0] == "chat_completion": - return self.llama.chat_completion(task[1]) + task_type = task[0] + if task_type == "chat_completion": + # task[1] is [params, raw_messages] + params, raw_messages = task[1] + return self.llama.chat_completion(params, raw_messages) else: - raise ValueError(f"Unexpected task type {task[0]}") + raise ValueError(f"Unexpected task type {task_type}") def init_model_cb( @@ -78,19 +75,3 @@ class LlamaModelParallelGenerator: def __exit__(self, exc_type, exc_value, exc_traceback): self.group.stop() - - def completion( - self, - request_batch: list[CompletionRequestWithRawContent], - ) -> Generator: - req_obj = deepcopy(request_batch) - gen = self.group.run_inference(("completion", req_obj)) - yield from gen - - def chat_completion( - self, - request_batch: list[ChatCompletionRequestWithRawContent], - ) -> Generator: - req_obj = deepcopy(request_batch) - gen = self.group.run_inference(("chat_completion", req_obj)) - yield from gen diff --git a/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index bb6a1bd03..663e4793b 100644 --- a/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ 
b/src/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -33,10 +33,6 @@ from torch.distributed.launcher.api import LaunchConfig, elastic_launch from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import GenerationResult -from llama_stack.providers.utils.inference.prompt_adapter import ( - ChatCompletionRequestWithRawContent, - CompletionRequestWithRawContent, -) log = get_logger(name=__name__, category="inference") @@ -69,10 +65,7 @@ class CancelSentinel(BaseModel): class TaskRequest(BaseModel): type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request - task: tuple[ - str, - list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent], - ] + task: tuple[str, list] class TaskResponse(BaseModel): @@ -328,10 +321,7 @@ class ModelParallelProcessGroup: def run_inference( self, - req: tuple[ - str, - list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent], - ], + req: tuple[str, list], ) -> Generator: assert not self.running, "inference already running" diff --git a/src/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/src/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index cb72aa13a..e6dcf3ae7 100644 --- a/src/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/src/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -22,9 +22,6 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, ) -from llama_stack.providers.utils.inference.openai_compat import ( - OpenAIChatCompletionToLlamaStackMixin, -) from .config import SentenceTransformersInferenceConfig @@ -32,7 +29,6 @@ log = get_logger(name=__name__, category="inference") class SentenceTransformersInferenceImpl( - OpenAIChatCompletionToLlamaStackMixin, SentenceTransformerEmbeddingMixin, InferenceProvider, ModelsProtocolPrivate, diff --git a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 223497fb8..a793c499e 100644 --- a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -11,9 +11,7 @@ from collections.abc import AsyncIterator import litellm from llama_stack.apis.inference import ( - ChatCompletionRequest, InferenceProvider, - JsonSchemaResponseFormat, OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAIChatCompletionRequestWithExtraBody, @@ -23,15 +21,11 @@ from llama_stack.apis.inference import ( OpenAIEmbeddingsRequestWithExtraBody, OpenAIEmbeddingsResponse, OpenAIEmbeddingUsage, - ToolChoice, ) from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry from llama_stack.providers.utils.inference.openai_compat import ( - convert_message_to_openai_dict_new, - convert_tooldef_to_openai_tool, - get_sampling_options, prepare_openai_completion_params, ) @@ -127,51 +121,6 @@ class LiteLLMOpenAIMixin( return schema - async def _get_params(self, request: ChatCompletionRequest) -> dict: - from typing import Any - - input_dict: dict[str, Any] = {} - - input_dict["messages"] = [ - await convert_message_to_openai_dict_new(m, 
download_images=self.download_images) for m in request.messages - ] - if fmt := request.response_format: - if not isinstance(fmt, JsonSchemaResponseFormat): - raise ValueError( - f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported." - ) - - # Convert to dict for manipulation - fmt_dict = dict(fmt.json_schema) - name = fmt_dict["title"] - del fmt_dict["title"] - fmt_dict["additionalProperties"] = False - - # Apply additionalProperties: False recursively to all objects - fmt_dict = self._add_additional_properties_recursive(fmt_dict) - - input_dict["response_format"] = { - "type": "json_schema", - "json_schema": { - "name": name, - "schema": fmt_dict, - "strict": self.json_schema_strict, - }, - } - if request.tools: - input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools] - if request.tool_config and (tool_choice := request.tool_config.tool_choice): - input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice - - return { - "model": request.model, - "api_key": self.get_api_key(), - "api_base": self.api_base, - **input_dict, - "stream": request.stream, - **get_sampling_options(request.sampling_params), - } - def get_api_key(self) -> str: provider_data = self.get_request_provider_data() key_field = self.provider_data_api_key_field diff --git a/src/llama_stack/providers/utils/inference/openai_compat.py b/src/llama_stack/providers/utils/inference/openai_compat.py index aabcb50f8..c2e6829e0 100644 --- a/src/llama_stack/providers/utils/inference/openai_compat.py +++ b/src/llama_stack/providers/utils/inference/openai_compat.py @@ -3,31 +3,14 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
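With the bespoke stream adapters in this file going away (see the removals of `process_completion_stream_response` and `process_chat_completion_stream_response` below), callers consume the provider's native OpenAI-style chunks directly. A hedged sketch of that consumer loop; `provider` and `params` are assumed to be set up as in the earlier request example, with `params.stream` enabled:

```python
from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody


async def collect_stream(provider, params: OpenAIChatCompletionRequestWithExtraBody) -> str:
    """Accumulate text from OpenAI-style chunks; the final chunk carries
    finish_reason ("stop" or "tool_calls")."""
    text = ""
    stream = await provider.openai_chat_completion(params)
    async for chunk in stream:
        choice = chunk.choices[0]
        if choice.delta.content:
            text += choice.delta.content
        if choice.finish_reason:
            print(f"finish_reason={choice.finish_reason}")
    return text
```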
-import json -import time -import uuid -import warnings -from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Iterable +from collections.abc import Iterable from typing import ( Any, ) -from openai import AsyncStream -from openai.types.chat import ( - ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, -) -from openai.types.chat import ( - ChatCompletionChunk as OpenAIChatCompletionChunk, -) -from openai.types.chat import ( - ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, -) from openai.types.chat import ( ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, ) -from openai.types.chat import ( - ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, -) try: from openai.types.chat import ( @@ -37,84 +20,24 @@ except ImportError: from openai.types.chat.chat_completion_message_tool_call import ( ChatCompletionMessageToolCall as OpenAIChatCompletionMessageFunctionToolCall, ) -from openai.types.chat import ( - ChatCompletionMessageParam as OpenAIChatCompletionMessage, -) from openai.types.chat import ( ChatCompletionMessageToolCall, ) -from openai.types.chat import ( - ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, -) -from openai.types.chat import ( - ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, -) -from openai.types.chat import ( - ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, -) -from openai.types.chat.chat_completion import ( - Choice as OpenAIChoice, -) -from openai.types.chat.chat_completion import ( - ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs -) -from openai.types.chat.chat_completion_chunk import ( - Choice as OpenAIChatCompletionChunkChoice, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDelta as OpenAIChoiceDelta, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall, -) -from openai.types.chat.chat_completion_chunk import ( - ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction, -) -from openai.types.chat.chat_completion_content_part_image_param import ( - ImageURL as OpenAIImageURL, -) -from openai.types.chat.chat_completion_message_tool_call import ( - Function as OpenAIFunction, -) from pydantic import BaseModel from llama_stack.apis.common.content_types import ( URL, ImageContentItem, - InterleavedContent, TextContentItem, - TextDelta, - ToolCallDelta, - ToolCallParseStatus, _URLOrData, ) from llama_stack.apis.inference import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseEvent, - ChatCompletionResponseEventType, - ChatCompletionResponseStreamChunk, - CompletionMessage, - CompletionResponse, - CompletionResponseStreamChunk, GreedySamplingStrategy, JsonSchemaResponseFormat, - Message, - OpenAIChatCompletion, - OpenAIMessageParam, OpenAIResponseFormatParam, SamplingParams, - SystemMessage, - TokenLogProbs, - ToolChoice, - ToolConfig, - ToolResponseMessage, TopKSamplingStrategy, TopPSamplingStrategy, - UserMessage, -) -from llama_stack.apis.inference import ( - OpenAIChoice as OpenAIChatCompletionChoice, ) from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( @@ -123,10 +46,6 @@ from llama_stack.models.llama.datatypes import ( ToolCall, ToolDefinition, ) -from llama_stack.providers.utils.inference.prompt_adapter import ( - convert_image_content_to_url, - decode_assistant_message, -) logger = 
get_logger(name=__name__, category="providers::utils") @@ -213,345 +132,6 @@ def get_stop_reason(finish_reason: str) -> StopReason: return StopReason.out_of_tokens -def convert_openai_completion_logprobs( - logprobs: OpenAICompatLogprobs | None, -) -> list[TokenLogProbs] | None: - if not logprobs: - return None - if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs: - return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs] - - # Together supports logprobs with top_k=1 only. This means for each token position, - # they return only the logprobs for the selected token (vs. the top n most likely tokens). - # Here we construct the response by matching the selected token with the logprobs. - if logprobs.tokens and logprobs.token_logprobs: - return [ - TokenLogProbs(logprobs_by_token={token: token_lp}) - for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False) - ] - return None - - -def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenAICompatLogprobs | None): - if logprobs is None: - return None - if isinstance(logprobs, float): - # Adapt response from Together CompletionChoicesChunk - return [TokenLogProbs(logprobs_by_token={text: logprobs})] - if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs: - return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs] - return None - - -def process_completion_response( - response: OpenAICompatCompletionResponse, -) -> CompletionResponse: - choice = response.choices[0] - text = choice.text or "" - # drop suffix if present and return stop reason as end of turn - if text.endswith("<|eot_id|>"): - return CompletionResponse( - stop_reason=StopReason.end_of_turn, - content=text[: -len("<|eot_id|>")], - logprobs=convert_openai_completion_logprobs(choice.logprobs), - ) - # drop suffix if present and return stop reason as end of message - if text.endswith("<|eom_id|>"): - return CompletionResponse( - stop_reason=StopReason.end_of_message, - content=text[: -len("<|eom_id|>")], - logprobs=convert_openai_completion_logprobs(choice.logprobs), - ) - return CompletionResponse( - stop_reason=get_stop_reason(choice.finish_reason or "stop"), - content=text, - logprobs=convert_openai_completion_logprobs(choice.logprobs), - ) - - -def process_chat_completion_response( - response: OpenAICompatCompletionResponse, - request: ChatCompletionRequest, -) -> ChatCompletionResponse: - choice = response.choices[0] - if choice.finish_reason == "tool_calls": - if not hasattr(choice, "message") or not choice.message or not choice.message.tool_calls: # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed - raise ValueError("Tool calls are not present in the response") - - tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed - if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls): - # If we couldn't parse a tool call, jsonify the tool calls and return them - return ChatCompletionResponse( - completion_message=CompletionMessage( - stop_reason=StopReason.end_of_turn, - content=json.dumps(tool_calls, default=lambda x: x.model_dump()), - ), - logprobs=None, - ) - else: - # Otherwise, return tool calls as normal - # Filter to only valid ToolCall objects - valid_tool_calls = [tc for tc in tool_calls if isinstance(tc, ToolCall)] - return ChatCompletionResponse( - completion_message=CompletionMessage( - 
tool_calls=valid_tool_calls, - stop_reason=StopReason.end_of_turn, - # Content is not optional - content="", - ), - logprobs=None, - ) - - # TODO: This does not work well with tool calls for vLLM remote provider - # Ref: https://github.com/meta-llama/llama-stack/issues/1058 - raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason or "stop")) - - # NOTE: If we do not set tools in chat-completion request, we should not - # expect the ToolCall in the response. Instead, we should return the raw - # response from the model. - if raw_message.tool_calls: - if not request.tools: - raw_message.tool_calls = [] - raw_message.content = text_from_choice(choice) - else: - # only return tool_calls if provided in the request - new_tool_calls = [] - request_tools = {t.tool_name: t for t in request.tools} - for t in raw_message.tool_calls: - if t.tool_name in request_tools: - new_tool_calls.append(t) - else: - logger.warning(f"Tool {t.tool_name} not found in request tools") - - if len(new_tool_calls) < len(raw_message.tool_calls): - raw_message.tool_calls = new_tool_calls - raw_message.content = text_from_choice(choice) - - return ChatCompletionResponse( - completion_message=CompletionMessage( - content=raw_message.content, # type: ignore[arg-type] # decode_assistant_message returns Union[str, InterleavedContent] - stop_reason=raw_message.stop_reason or StopReason.end_of_turn, - tool_calls=raw_message.tool_calls, - ), - logprobs=None, - ) - - -async def process_completion_stream_response( - stream: AsyncGenerator[OpenAICompatCompletionResponse, None], -) -> AsyncGenerator[CompletionResponseStreamChunk, None]: - stop_reason = None - - async for chunk in stream: - choice = chunk.choices[0] - finish_reason = choice.finish_reason - - text = text_from_choice(choice) - if text == "<|eot_id|>": - stop_reason = StopReason.end_of_turn - text = "" - continue - elif text == "<|eom_id|>": - stop_reason = StopReason.end_of_message - text = "" - continue - yield CompletionResponseStreamChunk( - delta=text, - stop_reason=stop_reason, - logprobs=convert_openai_completion_logprobs_stream(text, choice.logprobs), - ) - if finish_reason: - if finish_reason in ["stop", "eos", "eos_token"]: - stop_reason = StopReason.end_of_turn - elif finish_reason == "length": - stop_reason = StopReason.out_of_tokens - break - - yield CompletionResponseStreamChunk( - delta="", - stop_reason=stop_reason, - ) - - -async def process_chat_completion_stream_response( - stream: AsyncGenerator[OpenAICompatCompletionResponse, None], - request: ChatCompletionRequest, -) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta=TextDelta(text=""), - ) - ) - - buffer = "" - ipython = False - stop_reason = None - - async for chunk in stream: - choice = chunk.choices[0] - finish_reason = choice.finish_reason - - if finish_reason: - if stop_reason is None and finish_reason in ["stop", "eos", "eos_token"]: - stop_reason = StopReason.end_of_turn - elif stop_reason is None and finish_reason == "length": - stop_reason = StopReason.out_of_tokens - break - - text = text_from_choice(choice) - if not text: - # Sometimes you get empty chunks from providers - continue - - # check if its a tool call ( aka starts with <|python_tag|> ) - if not ipython and text.startswith("<|python_tag|>"): - ipython = True - yield ChatCompletionResponseStreamChunk( - 
event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call="", - parse_status=ToolCallParseStatus.started, - ), - ) - ) - buffer += text - continue - - if text == "<|eot_id|>": - stop_reason = StopReason.end_of_turn - text = "" - continue - elif text == "<|eom_id|>": - stop_reason = StopReason.end_of_message - text = "" - continue - - if ipython: - buffer += text - delta = ToolCallDelta( - tool_call=text, - parse_status=ToolCallParseStatus.in_progress, - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=delta, - stop_reason=stop_reason, - ) - ) - else: - buffer += text - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=TextDelta(text=text), - stop_reason=stop_reason, - ) - ) - - # parse tool calls and report errors - message = decode_assistant_message(buffer, stop_reason or StopReason.end_of_turn) - - parsed_tool_calls = len(message.tool_calls) > 0 - if ipython and not parsed_tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call="", - parse_status=ToolCallParseStatus.failed, - ), - stop_reason=stop_reason, - ) - ) - - request_tools = {t.tool_name: t for t in (request.tools or [])} - for tool_call in message.tool_calls: - if tool_call.tool_name in request_tools: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call=tool_call, - parse_status=ToolCallParseStatus.succeeded, - ), - stop_reason=stop_reason, - ) - ) - else: - logger.warning(f"Tool {tool_call.tool_name} not found in request tools") - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - # Parsing tool call failed due to tool call not being found in request tools, - # We still add the raw message text inside tool_call for responding back to the user - tool_call=buffer, - parse_status=ToolCallParseStatus.failed, - ), - stop_reason=stop_reason, - ) - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta=TextDelta(text=""), - stop_reason=stop_reason, - ) - ) - - -async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict: - async def _convert_content(content) -> dict: - if isinstance(content, ImageContentItem): - return { - "type": "image_url", - "image_url": { - "url": await convert_image_content_to_url(content, download=download), - }, - } - else: - text = content.text if isinstance(content, TextContentItem) else content - assert isinstance(text, str) - return {"type": "text", "text": text} - - if isinstance(message.content, list): - content = [await _convert_content(c) for c in message.content] - else: - content = [await _convert_content(message.content)] - - result = { - "role": message.role, - "content": content, - } - - if hasattr(message, "tool_calls") and message.tool_calls: - tool_calls_list = [] - for tc in message.tool_calls: - # The tool.tool_name can be a str or a BuiltinTool enum. If - # it's the latter, convert to a string. 
- tool_name = tc.tool_name - if isinstance(tool_name, BuiltinTool): - tool_name = tool_name.value - - tool_calls_list.append( - { - "id": tc.call_id, - "type": "function", - "function": { - "name": tool_name, - "arguments": tc.arguments, - }, - } - ) - result["tool_calls"] = tool_calls_list # type: ignore[assignment] # dict allows Any value, stricter type expected - return result - - class UnparseableToolCall(BaseModel): """ A ToolCall with arguments that are not valid JSON. @@ -563,112 +143,6 @@ class UnparseableToolCall(BaseModel): arguments: str = "" -async def convert_message_to_openai_dict_new( - message: Message | dict, - download_images: bool = False, -) -> OpenAIChatCompletionMessage: - """ - Convert a Message to an OpenAI API-compatible dictionary. - """ - # users can supply a dict instead of a Message object, we'll - # convert it to a Message object and proceed with some type safety. - if isinstance(message, dict): - if "role" not in message: - raise ValueError("role is required in message") - if message["role"] == "user": - message = UserMessage(**message) - elif message["role"] == "assistant": - message = CompletionMessage(**message) - elif message["role"] == "tool": - message = ToolResponseMessage(**message) - elif message["role"] == "system": - message = SystemMessage(**message) - else: - raise ValueError(f"Unsupported message role: {message['role']}") - - # Map Llama Stack spec to OpenAI spec - - # str -> str - # {"type": "text", "text": ...} -> {"type": "text", "text": ...} - # {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}} - # {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}} - # List[...] -> List[...] - async def _convert_message_content( - content: InterleavedContent, - ) -> str | Iterable[OpenAIChatCompletionContentPartParam]: - async def impl( - content_: InterleavedContent, - ) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]: - # Llama Stack and OpenAI spec match for str and text input - if isinstance(content_, str): - return content_ - elif isinstance(content_, TextContentItem): - return OpenAIChatCompletionContentPartTextParam( - type="text", - text=content_.text, - ) - elif isinstance(content_, ImageContentItem): - return OpenAIChatCompletionContentPartImageParam( - type="image_url", - image_url=OpenAIImageURL( - url=await convert_image_content_to_url(content_, download=download_images) - ), - ) - elif isinstance(content_, list): - return [await impl(item) for item in content_] # type: ignore[misc] # recursive list comprehension confuses mypy's type narrowing - else: - raise ValueError(f"Unsupported content type: {type(content_)}") - - ret = await impl(content) - - # OpenAI*Message expects a str or list - if isinstance(ret, str) or isinstance(ret, list): - return ret - else: - return [ret] - - out: OpenAIChatCompletionMessage - if isinstance(message, UserMessage): - out = OpenAIChatCompletionUserMessage( - role="user", - content=await _convert_message_content(message.content), - ) - elif isinstance(message, CompletionMessage): - tool_calls = [ - OpenAIChatCompletionMessageFunctionToolCall( - id=tool.call_id, - function=OpenAIFunction( - name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), - arguments=tool.arguments, # Already a JSON string, don't double-encode - ), - type="function", - ) - for tool in (message.tool_calls or []) - ] - params = {} - if tool_calls: 
- params["tool_calls"] = tool_calls - out = OpenAIChatCompletionAssistantMessage( - role="assistant", - content=await _convert_message_content(message.content), - **params, # type: ignore[typeddict-item] # tool_calls dict expansion conflicts with TypedDict optional field - ) - elif isinstance(message, ToolResponseMessage): - out = OpenAIChatCompletionToolMessage( - role="tool", - tool_call_id=message.call_id, - content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement - ) - elif isinstance(message, SystemMessage): - out = OpenAIChatCompletionSystemMessage( - role="system", - content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement - ) - else: - raise ValueError(f"Unsupported message type: {type(message)}") - - return out - - def convert_tool_call( tool_call: ChatCompletionMessageToolCall, ) -> ToolCall | UnparseableToolCall: @@ -817,17 +291,6 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason: }.get(finish_reason, StopReason.end_of_turn) -def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None = None) -> ToolConfig: - tool_config = ToolConfig() - if tool_choice: - try: - tool_choice = ToolChoice(tool_choice) # type: ignore[assignment] # reassigning to enum narrows union but mypy can't track after exception - except ValueError: - pass - tool_config.tool_choice = tool_choice # type: ignore[assignment] # ToolConfig.tool_choice accepts Union[ToolChoice, dict] but mypy tracks narrower type - return tool_config - - def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]: lls_tools: list[ToolDefinition] = [] if not tools: @@ -898,40 +361,6 @@ def _convert_openai_tool_calls( ] -def _convert_openai_logprobs( - logprobs: OpenAIChoiceLogprobs, -) -> list[TokenLogProbs] | None: - """ - Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs. - - OpenAI ChoiceLogprobs: - content: Optional[List[ChatCompletionTokenLogprob]] - - OpenAI ChatCompletionTokenLogprob: - token: str - logprob: float - top_logprobs: List[TopLogprob] - - OpenAI TopLogprob: - token: str - logprob: float - - -> - - TokenLogProbs: - logprobs_by_token: Dict[str, float] - - token, logprob - - """ - if not logprobs or not logprobs.content: - return None - - return [ - TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs}) - for content in logprobs.content - ] - - def _convert_openai_sampling_params( max_tokens: int | None = None, temperature: float | None = None, @@ -956,37 +385,6 @@ def _convert_openai_sampling_params( return sampling_params -def openai_messages_to_messages( - messages: list[OpenAIMessageParam], -) -> list[Message]: - """ - Convert a list of OpenAIChatCompletionMessage into a list of Message. 
- """ - converted_messages: list[Message] = [] - for message in messages: - converted_message: Message - if message.role == "system": - converted_message = SystemMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types - elif message.role == "user": - converted_message = UserMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types - elif message.role == "assistant": - converted_message = CompletionMessage( - content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types - tool_calls=_convert_openai_tool_calls(message.tool_calls) if message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls type incompatible with conversion function - stop_reason=StopReason.end_of_turn, - ) - elif message.role == "tool": - converted_message = ToolResponseMessage( - role="tool", - call_id=message.tool_call_id, - content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types - ) - else: - raise ValueError(f"Unknown role {message.role}") - converted_messages.append(converted_message) - return converted_messages - - def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionContentPartParam] | None): if content is None: return "" @@ -1005,216 +403,6 @@ def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionConten raise ValueError(f"Unknown content type: {content}") -def convert_openai_chat_completion_choice( - choice: OpenAIChoice, -) -> ChatCompletionResponse: - """ - Convert an OpenAI Choice into a ChatCompletionResponse. 
- - OpenAI Choice: - message: ChatCompletionMessage - finish_reason: str - logprobs: Optional[ChoiceLogprobs] - - OpenAI ChatCompletionMessage: - role: Literal["assistant"] - content: Optional[str] - tool_calls: Optional[List[ChatCompletionMessageToolCall]] - - -> - - ChatCompletionResponse: - completion_message: CompletionMessage - logprobs: Optional[List[TokenLogProbs]] - - CompletionMessage: - role: Literal["assistant"] - content: str | ImageMedia | List[str | ImageMedia] - stop_reason: StopReason - tool_calls: List[ToolCall] - - class StopReason(Enum): - end_of_turn = "end_of_turn" - end_of_message = "end_of_message" - out_of_tokens = "out_of_tokens" - """ - assert hasattr(choice, "message") and choice.message, "error in server response: message not found" - assert hasattr(choice, "finish_reason") and choice.finish_reason, ( - "error in server response: finish_reason not found" - ) - - return ChatCompletionResponse( - completion_message=CompletionMessage( - content=choice.message.content or "", # CompletionMessage content is not optional - stop_reason=_convert_openai_finish_reason(choice.finish_reason), - tool_calls=_convert_openai_tool_calls(choice.message.tool_calls) if choice.message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls Optional type broadens union - ), - logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)), # type: ignore[arg-type] # getattr returns Any, can't narrow without inspection - ) - - -async def convert_openai_chat_completion_stream( - stream: AsyncStream[OpenAIChatCompletionChunk], - enable_incremental_tool_calls: bool, -) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: - """ - Convert a stream of OpenAI chat completion chunks into a stream - of ChatCompletionResponseStreamChunk. - """ - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta=TextDelta(text=""), - ) - ) - event_type = ChatCompletionResponseEventType.progress - - stop_reason = None - tool_call_idx_to_buffer = {} - - async for chunk in stream: - choice = chunk.choices[0] # assuming only one choice per chunk - - # we assume there's only one finish_reason in the stream - stop_reason = _convert_openai_finish_reason(choice.finish_reason) if choice.finish_reason else stop_reason - logprobs = getattr(choice, "logprobs", None) - - # if there's a tool call, emit an event for each tool in the list - # if tool call and content, emit both separately - if choice.delta.tool_calls: - # the call may have content and a tool call. 
ChatCompletionResponseEvent - # does not support both, so we emit the content first - if choice.delta.content: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=TextDelta(text=choice.delta.content), - logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result - ) - ) - - # it is possible to have parallel tool calls in stream, but - # ChatCompletionResponseEvent only supports one per stream - if len(choice.delta.tool_calls) > 1: - warnings.warn( - "multiple tool calls found in a single delta, using the first, ignoring the rest", - stacklevel=2, - ) - - if not enable_incremental_tool_calls: - for tool_call in choice.delta.tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=ToolCallDelta( - tool_call=_convert_openai_tool_calls([tool_call])[0], # type: ignore[arg-type, list-item] # delta tool_call type differs from complete tool_call - parse_status=ToolCallParseStatus.succeeded, - ), - logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result - ) - ) - else: - for tool_call in choice.delta.tool_calls: - idx = tool_call.index if hasattr(tool_call, "index") else 0 - - if idx not in tool_call_idx_to_buffer: - tool_call_idx_to_buffer[idx] = { - "call_id": tool_call.id, - "name": None, - "arguments": "", - "content": "", - } - - buffer = tool_call_idx_to_buffer[idx] - - if tool_call.function: - if tool_call.function.name: - buffer["name"] = tool_call.function.name - delta = f"{buffer['name']}(" - if buffer["content"] is not None: - buffer["content"] += delta - - if tool_call.function.arguments: - delta = tool_call.function.arguments - if buffer["arguments"] is not None and delta: - buffer["arguments"] += delta - if buffer["content"] is not None and delta: - buffer["content"] += delta - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=ToolCallDelta( - tool_call=delta, - parse_status=ToolCallParseStatus.in_progress, - ), - logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result - ) - ) - elif choice.delta.content: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=TextDelta(text=choice.delta.content or ""), - logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result - ) - ) - - for idx, buffer in tool_call_idx_to_buffer.items(): - logger.debug(f"toolcall_buffer[{idx}]: {buffer}") - if buffer["name"]: - delta = ")" - if buffer["content"] is not None: - buffer["content"] += delta - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=event_type, - delta=ToolCallDelta( - tool_call=delta, - parse_status=ToolCallParseStatus.in_progress, - ), - logprobs=None, - ) - ) - - try: - parsed_tool_call = ToolCall( - call_id=buffer["call_id"] or "", - tool_name=buffer["name"] or "", - arguments=buffer["arguments"] or "", - ) - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call=parsed_tool_call, # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall] - parse_status=ToolCallParseStatus.succeeded, - ), - stop_reason=stop_reason, - ) - ) - except 
json.JSONDecodeError as e: - print(f"Failed to parse arguments: {e}") - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call=buffer["content"], # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall] - parse_status=ToolCallParseStatus.failed, - ), - stop_reason=stop_reason, - ) - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta=TextDelta(text=""), - stop_reason=stop_reason, - ) - ) - - async def prepare_openai_completion_params(**params): async def _prepare_value(value: Any) -> Any: new_value = value @@ -1233,163 +421,6 @@ async def prepare_openai_completion_params(**params): return completion_params -class OpenAIChatCompletionToLlamaStackMixin: - async def openai_chat_completion( - self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, - ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - messages = openai_messages_to_messages(messages) # type: ignore[assignment] # converted from OpenAI to LlamaStack message format - response_format = _convert_openai_request_response_format(response_format) - sampling_params = _convert_openai_sampling_params( - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - ) - tool_config = _convert_openai_request_tool_config(tool_choice) - - tools = _convert_openai_request_tools(tools) # type: ignore[assignment] # converted from OpenAI to LlamaStack tool format - if tool_config.tool_choice == ToolChoice.none: - tools = [] # type: ignore[assignment] # empty list narrows return type but mypy tracks broader type - - outstanding_responses = [] - # "n" is the number of completions to generate per prompt - n = n or 1 - for _i in range(0, n): - response = self.chat_completion( # type: ignore[attr-defined] # mixin expects class to implement chat_completion - model_id=model, - messages=messages, - sampling_params=sampling_params, - response_format=response_format, - stream=stream, - tool_config=tool_config, - tools=tools, - ) - outstanding_responses.append(response) - - if stream: - return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses) # type: ignore[no-any-return] # mixin async generator return type too complex for mypy - - return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response( - self, model, outstanding_responses - ) - - async def _process_stream_response( - self, - model: str, - outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]], - ): - id = f"chatcmpl-{uuid.uuid4()}" 
- for i, outstanding_response in enumerate(outstanding_responses): - response = await outstanding_response - async for chunk in response: - event = chunk.event - finish_reason = ( - _convert_stop_reason_to_openai_finish_reason(event.stop_reason) if event.stop_reason else None - ) - - if isinstance(event.delta, TextDelta): - text_delta = event.delta.text - delta = OpenAIChoiceDelta(content=text_delta) - yield OpenAIChatCompletionChunk( - id=id, - choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)], # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union - created=int(time.time()), - model=model, - object="chat.completion.chunk", - ) - elif isinstance(event.delta, ToolCallDelta): - if event.delta.parse_status == ToolCallParseStatus.succeeded: - tool_call = event.delta.tool_call - if isinstance(tool_call, str): - continue - - # First chunk includes full structure - openai_tool_call = OpenAIChoiceDeltaToolCall( - index=0, - id=tool_call.call_id, - function=OpenAIChoiceDeltaToolCallFunction( - name=tool_call.tool_name - if isinstance(tool_call.tool_name, str) - else tool_call.tool_name.value, # type: ignore[arg-type] # enum .value extraction on Union confuses mypy - arguments="", - ), - ) - delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call]) - yield OpenAIChatCompletionChunk( - id=id, - choices=[ - OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union - ], - created=int(time.time()), - model=model, - object="chat.completion.chunk", - ) - # arguments - openai_tool_call = OpenAIChoiceDeltaToolCall( - index=0, - function=OpenAIChoiceDeltaToolCallFunction( - arguments=tool_call.arguments, - ), - ) - delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call]) - yield OpenAIChatCompletionChunk( - id=id, - choices=[ - OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union - ], - created=int(time.time()), - model=model, - object="chat.completion.chunk", - ) - - async def _process_non_stream_response( - self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]] - ) -> OpenAIChatCompletion: - choices: list[OpenAIChatCompletionChoice] = [] - for outstanding_response in outstanding_responses: - response = await outstanding_response - completion_message = response.completion_message - message = await convert_message_to_openai_dict_new(completion_message) - finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason) - - choice = OpenAIChatCompletionChoice( - index=len(choices), - message=message, # type: ignore[arg-type] # OpenAIChatCompletionMessage union incompatible with narrower Message type - finish_reason=finish_reason, - ) - choices.append(choice) # type: ignore[arg-type] # OpenAIChatCompletionChoice type annotation mismatch - - return OpenAIChatCompletion( - id=f"chatcmpl-{uuid.uuid4()}", - choices=choices, # type: ignore[arg-type] # list[OpenAIChatCompletionChoice] union incompatible - created=int(time.time()), - model=model, - object="chat.completion", - ) - - def prepare_openai_embeddings_params( model: str, input: str | list[str], diff --git a/src/llama_stack/providers/utils/inference/prompt_adapter.py b/src/llama_stack/providers/utils/inference/prompt_adapter.py index d06b7454d..35a7b3484 100644 --- 
a/src/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/src/llama_stack/providers/utils/inference/prompt_adapter.py @@ -21,19 +21,18 @@ from llama_stack.apis.common.content_types import ( TextContentItem, ) from llama_stack.apis.inference import ( - ChatCompletionRequest, CompletionRequest, - Message, + OpenAIAssistantMessageParam, OpenAIChatCompletionContentPartImageParam, OpenAIChatCompletionContentPartTextParam, OpenAIFile, + OpenAIMessageParam, + OpenAISystemMessageParam, + OpenAIToolMessageParam, + OpenAIUserMessageParam, ResponseFormat, ResponseFormatType, - SystemMessage, - SystemMessageBehavior, ToolChoice, - ToolDefinition, - UserMessage, ) from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( @@ -42,33 +41,19 @@ from llama_stack.models.llama.datatypes import ( RawMediaItem, RawMessage, RawTextItem, - Role, StopReason, + ToolCall, + ToolDefinition, ToolPromptFormat, ) from llama_stack.models.llama.llama3.chat_format import ChatFormat -from llama_stack.models.llama.llama3.prompt_templates import ( - BuiltinToolGenerator, - FunctionTagCustomToolGenerator, - JsonCustomToolGenerator, - PythonListCustomToolGenerator, - SystemDefaultGenerator, -) from llama_stack.models.llama.llama3.tokenizer import Tokenizer -from llama_stack.models.llama.llama4.prompt_templates.system_prompts import ( - PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4, -) from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal -from llama_stack.providers.utils.inference import supported_inference_models log = get_logger(name=__name__, category="providers::utils") -class ChatCompletionRequestWithRawContent(ChatCompletionRequest): - messages: list[RawMessage] - - class CompletionRequestWithRawContent(CompletionRequest): content: RawContent @@ -103,28 +88,6 @@ def interleaved_content_as_str( return _process(content) -async def convert_request_to_raw( - request: ChatCompletionRequest | CompletionRequest, -) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent: - if isinstance(request, ChatCompletionRequest): - messages = [] - for m in request.messages: - content = await interleaved_content_convert_to_raw(m.content) - d = m.model_dump() - d["content"] = content - messages.append(RawMessage(**d)) - - d = request.model_dump() - d["messages"] = messages - request = ChatCompletionRequestWithRawContent(**d) - else: - d = request.model_dump() - d["content"] = await interleaved_content_convert_to_raw(request.content) - request = CompletionRequestWithRawContent(**d) - - return request - - async def interleaved_content_convert_to_raw( content: InterleavedContent, ) -> RawContent: @@ -171,6 +134,36 @@ async def interleaved_content_convert_to_raw( return await _localize_single(content) +async def convert_openai_message_to_raw_message(message: OpenAIMessageParam) -> RawMessage: + """Convert OpenAI message format to RawMessage format used by Llama formatters.""" + if isinstance(message, OpenAIUserMessageParam): + content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type] + return RawMessage(role="user", content=content) + elif isinstance(message, OpenAISystemMessageParam): + content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type] + return RawMessage(role="system", content=content) + elif isinstance(message, OpenAIAssistantMessageParam): + content = await interleaved_content_convert_to_raw(message.content 
or "") # type: ignore[arg-type] + tool_calls = [] + if message.tool_calls: + for tc in message.tool_calls: + if tc.function: + tool_calls.append( + ToolCall( + call_id=tc.id or "", + tool_name=tc.function.name or "", + arguments=tc.function.arguments or "{}", + ) + ) + return RawMessage(role="assistant", content=content, tool_calls=tool_calls) + elif isinstance(message, OpenAIToolMessageParam): + content = await interleaved_content_convert_to_raw(message.content) # type: ignore[arg-type] + return RawMessage(role="tool", content=content) + else: + # Handle OpenAIDeveloperMessageParam if needed + raise ValueError(f"Unsupported message type: {type(message)}") + + def content_has_media(content: InterleavedContent): def _has_media_content(c): return isinstance(c, ImageContentItem) @@ -181,17 +174,6 @@ def content_has_media(content: InterleavedContent): return _has_media_content(content) -def messages_have_media(messages: list[Message]): - return any(content_has_media(m.content) for m in messages) - - -def request_has_media(request: ChatCompletionRequest | CompletionRequest): - if isinstance(request, ChatCompletionRequest): - return messages_have_media(request.messages) - else: - return content_has_media(request.content) - - async def localize_image_content(uri: str) -> tuple[bytes, str] | None: if uri.startswith("http"): async with httpx.AsyncClient() as client: @@ -253,79 +235,6 @@ def augment_content_with_response_format_prompt(response_format, content): return content -async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str: - messages = chat_completion_request_to_messages(request, llama_model) - request.messages = messages - request = await convert_request_to_raw(request) - - formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) - model_input = formatter.encode_dialog_prompt( - request.messages, - tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model), - ) - return formatter.tokenizer.decode(model_input.tokens) - - -async def chat_completion_request_to_model_input_info( - request: ChatCompletionRequest, llama_model: str -) -> tuple[str, int]: - messages = chat_completion_request_to_messages(request, llama_model) - request.messages = messages - request = await convert_request_to_raw(request) - - formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) - model_input = formatter.encode_dialog_prompt( - request.messages, - tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model), - ) - return ( - formatter.tokenizer.decode(model_input.tokens), - len(model_input.tokens), - ) - - -def chat_completion_request_to_messages( - request: ChatCompletionRequest, - llama_model: str, -) -> list[Message]: - """Reads chat completion request and augments the messages to handle tools. - For eg. for llama_3_1, add system message with the appropriate tools or - add user messsage for custom tools, etc. - """ - assert llama_model is not None, "llama_model is required" - model = resolve_model(llama_model) - if model is None: - log.error(f"Could not resolve model {llama_model}") - return request.messages - - allowed_models = supported_inference_models() - descriptors = [m.descriptor() for m in allowed_models] - if model.descriptor() not in descriptors: - log.error(f"Unsupported inference model? 
{model.descriptor()}") - return request.messages - - if model.model_family == ModelFamily.llama3_1 or ( - model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id) - ): - # llama3.1 and llama3.2 multimodal models follow the same tool prompt format - messages = augment_messages_for_tools_llama_3_1(request) - elif model.model_family in ( - ModelFamily.llama3_2, - ModelFamily.llama3_3, - ): - # llama3.2, llama3.3 follow the same tool prompt format - messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator) - elif model.model_family == ModelFamily.llama4: - messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4) - else: - messages = request.messages - - if fmt_prompt := response_format_prompt(request.response_format): - messages.append(UserMessage(content=fmt_prompt)) - - return messages - - def response_format_prompt(fmt: ResponseFormat | None): if not fmt: return None @@ -338,128 +247,6 @@ def response_format_prompt(fmt: ResponseFormat | None): raise ValueError(f"Unknown response format {fmt.type}") -def augment_messages_for_tools_llama_3_1( - request: ChatCompletionRequest, -) -> list[Message]: - existing_messages = request.messages - existing_system_message = None - if existing_messages[0].role == Role.system.value: - existing_system_message = existing_messages.pop(0) - - assert existing_messages[0].role != Role.system.value, "Should only have 1 system message" - - messages = [] - - default_gen = SystemDefaultGenerator() - default_template = default_gen.gen() - - sys_content = "" - - tool_template = None - if request.tools: - tool_gen = BuiltinToolGenerator() - tool_template = tool_gen.gen(request.tools) - - sys_content += tool_template.render() - sys_content += "\n" - - sys_content += default_template.render() - - if existing_system_message: - # TODO: this fn is needed in many places - def _process(c): - if isinstance(c, str): - return c - else: - return "" - - sys_content += "\n" - - if isinstance(existing_system_message.content, str): - sys_content += _process(existing_system_message.content) - elif isinstance(existing_system_message.content, list): - sys_content += "\n".join([_process(c) for c in existing_system_message.content]) - - tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools) - if tool_choice_prompt: - sys_content += "\n" + tool_choice_prompt - - messages.append(SystemMessage(content=sys_content)) - - has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools) - if has_custom_tools: - fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json - if fmt == ToolPromptFormat.json: - tool_gen = JsonCustomToolGenerator() - elif fmt == ToolPromptFormat.function_tag: - tool_gen = FunctionTagCustomToolGenerator() - else: - raise ValueError(f"Non supported ToolPromptFormat {fmt}") - - custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)] - custom_template = tool_gen.gen(custom_tools) - messages.append(UserMessage(content=custom_template.render())) - - # Add back existing messages from the request - messages += existing_messages - - return messages - - -def augment_messages_for_tools_llama( - request: ChatCompletionRequest, - custom_tool_prompt_generator, -) -> list[Message]: - existing_messages = request.messages - existing_system_message = None - if existing_messages[0].role == Role.system.value: - existing_system_message = existing_messages.pop(0) - - assert 
existing_messages[0].role != Role.system.value, "Should only have 1 system message" - - sys_content = "" - custom_tools, builtin_tools = [], [] - for t in request.tools: - if isinstance(t.tool_name, str): - custom_tools.append(t) - else: - builtin_tools.append(t) - - if builtin_tools: - tool_gen = BuiltinToolGenerator() - tool_template = tool_gen.gen(builtin_tools) - - sys_content += tool_template.render() - sys_content += "\n" - - custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)] - if custom_tools: - fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list - if fmt != ToolPromptFormat.python_list: - raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}") - - system_prompt = None - if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace: - system_prompt = existing_system_message.content - - tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt) - - sys_content += tool_template.render() - sys_content += "\n" - - if existing_system_message and ( - request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools - ): - sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n") - - tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools) - if tool_choice_prompt: - sys_content += "\n" + tool_choice_prompt - - messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages] - return messages - - def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str: if tool_choice == ToolChoice.auto: return "" diff --git a/tests/unit/models/test_prompt_adapter.py b/tests/unit/models/test_prompt_adapter.py deleted file mode 100644 index d31426135..000000000 --- a/tests/unit/models/test_prompt_adapter.py +++ /dev/null @@ -1,303 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - - -from llama_stack.apis.inference import ( - ChatCompletionRequest, - CompletionMessage, - StopReason, - SystemMessage, - SystemMessageBehavior, - ToolCall, - ToolConfig, - UserMessage, -) -from llama_stack.models.llama.datatypes import ( - BuiltinTool, - ToolDefinition, - ToolPromptFormat, -) -from llama_stack.providers.utils.inference.prompt_adapter import ( - chat_completion_request_to_messages, - chat_completion_request_to_prompt, - interleaved_content_as_str, -) - -MODEL = "Llama3.1-8B-Instruct" -MODEL3_2 = "Llama3.2-3B-Instruct" - - -async def test_system_default(): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 2 - assert messages[-1].content == content - assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content) - - -async def test_system_builtin_only(): - content = "Hello !" 
- request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 2 - assert messages[-1].content == content - assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content) - assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content) - - -async def test_system_custom_only(): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ) - ], - tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json), - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 3 - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - - assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content) - assert messages[-1].content == content - - -async def test_system_custom_and_builtin(): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 3 - - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content) - - assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content) - assert messages[-1].content == content - - -async def test_completion_message_encoding(): - request = ChatCompletionRequest( - model=MODEL3_2, - messages=[ - UserMessage(content="hello"), - CompletionMessage( - content="", - stop_reason=StopReason.end_of_turn, - tool_calls=[ - ToolCall( - tool_name="custom1", - arguments='{"param1": "value1"}', # arguments must be a JSON string - call_id="123", - ) - ], - ), - ], - tools=[ - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ), - ], - tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list), - ) - prompt = await chat_completion_request_to_prompt(request, request.model) - assert '[custom1(param1="value1")]' in prompt - - request.model = MODEL - request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json) - prompt = await chat_completion_request_to_prompt(request, request.model) - assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt - - -async def test_user_provided_system_message(): - content = "Hello !" 
- system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - assert len(messages) == 2 - assert interleaved_content_as_str(messages[0].content).endswith(system_prompt) - - assert messages[-1].content == content - - -async def test_replace_system_message_behavior_builtin_tools(): - content = "Hello !" - system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format=ToolPromptFormat.python_list, - system_message_behavior=SystemMessageBehavior.replace, - ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - assert len(messages) == 2 - assert interleaved_content_as_str(messages[0].content).endswith(system_prompt) - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - assert messages[-1].content == content - - -async def test_replace_system_message_behavior_custom_tools(): - content = "Hello !" - system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format=ToolPromptFormat.python_list, - system_message_behavior=SystemMessageBehavior.replace, - ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - - assert len(messages) == 2 - assert interleaved_content_as_str(messages[0].content).endswith(system_prompt) - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - assert messages[-1].content == content - - -async def test_replace_system_message_behavior_custom_tools_with_template(): - content = "Hello !" 
- system_prompt = "You are a pirate {{ function_description }}" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - input_schema={ - "type": "object", - "properties": { - "param1": { - "type": "str", - "description": "param1 description", - }, - }, - "required": ["param1"], - }, - ), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format=ToolPromptFormat.python_list, - system_message_behavior=SystemMessageBehavior.replace, - ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - - assert len(messages) == 2 - assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - assert "You are a pirate" in interleaved_content_as_str(messages[0].content) - # function description is present in the system prompt - assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content) - assert messages[-1].content == content diff --git a/tests/unit/providers/inline/inference/__init__.py b/tests/unit/providers/inline/inference/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/unit/providers/inline/inference/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/unit/providers/inline/inference/test_meta_reference.py b/tests/unit/providers/inline/inference/test_meta_reference.py new file mode 100644 index 000000000..381836397 --- /dev/null +++ b/tests/unit/providers/inline/inference/test_meta_reference.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from unittest.mock import Mock + +import pytest + +from llama_stack.providers.inline.inference.meta_reference.model_parallel import ( + ModelRunner, +) + + +class TestModelRunner: + """Test ModelRunner task dispatching for model-parallel inference.""" + + def test_chat_completion_task_dispatch(self): + """Verify ModelRunner correctly dispatches chat_completion tasks.""" + # Create a mock generator + mock_generator = Mock() + mock_generator.chat_completion = Mock(return_value=iter([])) + + runner = ModelRunner(mock_generator) + + # Create a chat_completion task + fake_params = {"model": "test"} + fake_messages = [{"role": "user", "content": "test"}] + task = ("chat_completion", [fake_params, fake_messages]) + + # Execute task + runner(task) + + # Verify chat_completion was called with correct arguments + mock_generator.chat_completion.assert_called_once_with(fake_params, fake_messages) + + def test_invalid_task_type_raises_error(self): + """Verify ModelRunner rejects invalid task types.""" + mock_generator = Mock() + runner = ModelRunner(mock_generator) + + with pytest.raises(ValueError, match="Unexpected task type"): + runner(("invalid_task", [])) diff --git a/tests/unit/providers/nvidia/test_safety.py b/tests/unit/providers/nvidia/test_safety.py index 922d7f61f..622302630 100644 --- a/tests/unit/providers/nvidia/test_safety.py +++ b/tests/unit/providers/nvidia/test_safety.py @@ -10,11 +10,13 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from llama_stack.apis.inference import CompletionMessage, UserMessage +from llama_stack.apis.inference import ( + OpenAIAssistantMessageParam, + OpenAIUserMessageParam, +) from llama_stack.apis.resource import ResourceType from llama_stack.apis.safety import RunShieldResponse, ViolationLevel from llama_stack.apis.shields import Shield -from llama_stack.models.llama.datatypes import StopReason from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter @@ -136,11 +138,9 @@ async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post): # Run the shield messages = [ - UserMessage(role="user", content="Hello, how are you?"), - CompletionMessage( - role="assistant", + OpenAIUserMessageParam(content="Hello, how are you?"), + OpenAIAssistantMessageParam( content="I'm doing well, thank you for asking!", - stop_reason=StopReason.end_of_message, tool_calls=[], ), ] @@ -191,13 +191,10 @@ async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post): # Mock Guardrails API response mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}} - # Run the shield messages = [ - UserMessage(role="user", content="Hello, how are you?"), - CompletionMessage( - role="assistant", + OpenAIUserMessageParam(content="Hello, how are you?"), + OpenAIAssistantMessageParam( content="I'm doing well, thank you for asking!", - stop_reason=StopReason.end_of_message, tool_calls=[], ), ] @@ -243,7 +240,7 @@ async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post): adapter.shield_store.get_shield.return_value = None messages = [ - UserMessage(role="user", content="Hello, how are you?"), + OpenAIUserMessageParam(content="Hello, how are you?"), ] with pytest.raises(ValueError): @@ -274,11 +271,9 @@ async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post): # Running the shield should raise an exception messages = [ - UserMessage(role="user", content="Hello, how are 
you?"), - CompletionMessage( - role="assistant", + OpenAIUserMessageParam(content="Hello, how are you?"), + OpenAIAssistantMessageParam( content="I'm doing well, thank you for asking!", - stop_reason=StopReason.end_of_message, tool_calls=[], ), ] diff --git a/tests/unit/providers/utils/inference/test_openai_compat.py b/tests/unit/providers/utils/inference/test_openai_compat.py deleted file mode 100644 index c200c4395..000000000 --- a/tests/unit/providers/utils/inference/test_openai_compat.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import pytest -from pydantic import ValidationError - -from llama_stack.apis.common.content_types import TextContentItem -from llama_stack.apis.inference import ( - CompletionMessage, - OpenAIAssistantMessageParam, - OpenAIChatCompletionContentPartImageParam, - OpenAIChatCompletionContentPartTextParam, - OpenAIDeveloperMessageParam, - OpenAIImageURL, - OpenAISystemMessageParam, - OpenAIToolMessageParam, - OpenAIUserMessageParam, - SystemMessage, - UserMessage, -) -from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall -from llama_stack.providers.utils.inference.openai_compat import ( - convert_message_to_openai_dict, - convert_message_to_openai_dict_new, - openai_messages_to_messages, -) - - -async def test_convert_message_to_openai_dict(): - message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user") - assert await convert_message_to_openai_dict(message) == { - "role": "user", - "content": [{"type": "text", "text": "Hello, world!"}], - } - - -# Test convert_message_to_openai_dict with a tool call -async def test_convert_message_to_openai_dict_with_tool_call(): - message = CompletionMessage( - content="", - tool_calls=[ToolCall(call_id="123", tool_name="test_tool", arguments='{"foo": "bar"}')], - stop_reason=StopReason.end_of_turn, - ) - - openai_dict = await convert_message_to_openai_dict(message) - - assert openai_dict == { - "role": "assistant", - "content": [{"type": "text", "text": ""}], - "tool_calls": [ - {"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}} - ], - } - - -async def test_convert_message_to_openai_dict_with_builtin_tool_call(): - message = CompletionMessage( - content="", - tool_calls=[ - ToolCall( - call_id="123", - tool_name=BuiltinTool.brave_search, - arguments='{"foo": "bar"}', - ) - ], - stop_reason=StopReason.end_of_turn, - ) - - openai_dict = await convert_message_to_openai_dict(message) - - assert openai_dict == { - "role": "assistant", - "content": [{"type": "text", "text": ""}], - "tool_calls": [ - {"id": "123", "type": "function", "function": {"name": "brave_search", "arguments": '{"foo": "bar"}'}} - ], - } - - -async def test_openai_messages_to_messages_with_content_str(): - openai_messages = [ - OpenAISystemMessageParam(content="system message"), - OpenAIUserMessageParam(content="user message"), - OpenAIAssistantMessageParam(content="assistant message"), - ] - - llama_messages = openai_messages_to_messages(openai_messages) - assert len(llama_messages) == 3 - assert isinstance(llama_messages[0], SystemMessage) - assert isinstance(llama_messages[1], UserMessage) - assert isinstance(llama_messages[2], CompletionMessage) - assert llama_messages[0].content == "system message" - assert llama_messages[1].content == "user message" - assert 
llama_messages[2].content == "assistant message" - - -async def test_openai_messages_to_messages_with_content_list(): - openai_messages = [ - OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]), - OpenAIUserMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="user message")]), - OpenAIAssistantMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="assistant message")]), - ] - - llama_messages = openai_messages_to_messages(openai_messages) - assert len(llama_messages) == 3 - assert isinstance(llama_messages[0], SystemMessage) - assert isinstance(llama_messages[1], UserMessage) - assert isinstance(llama_messages[2], CompletionMessage) - assert llama_messages[0].content[0].text == "system message" - assert llama_messages[1].content[0].text == "user message" - assert llama_messages[2].content[0].text == "assistant message" - - -@pytest.mark.parametrize( - "message_class,kwargs", - [ - (OpenAISystemMessageParam, {}), - (OpenAIAssistantMessageParam, {}), - (OpenAIDeveloperMessageParam, {}), - (OpenAIUserMessageParam, {}), - (OpenAIToolMessageParam, {"tool_call_id": "call_123"}), - ], -) -def test_message_accepts_text_string(message_class, kwargs): - """Test that messages accept string text content.""" - msg = message_class(content="Test message", **kwargs) - assert msg.content == "Test message" - - -@pytest.mark.parametrize( - "message_class,kwargs", - [ - (OpenAISystemMessageParam, {}), - (OpenAIAssistantMessageParam, {}), - (OpenAIDeveloperMessageParam, {}), - (OpenAIUserMessageParam, {}), - (OpenAIToolMessageParam, {"tool_call_id": "call_123"}), - ], -) -def test_message_accepts_text_list(message_class, kwargs): - """Test that messages accept list of text content parts.""" - content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")] - msg = message_class(content=content_list, **kwargs) - assert len(msg.content) == 1 - assert msg.content[0].text == "Test message" - - -@pytest.mark.parametrize( - "message_class,kwargs", - [ - (OpenAISystemMessageParam, {}), - (OpenAIAssistantMessageParam, {}), - (OpenAIDeveloperMessageParam, {}), - (OpenAIToolMessageParam, {"tool_call_id": "call_123"}), - ], -) -def test_message_rejects_images(message_class, kwargs): - """Test that system, assistant, developer, and tool messages reject image content.""" - with pytest.raises(ValidationError): - message_class( - content=[ - OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")) - ], - **kwargs, - ) - - -def test_user_message_accepts_images(): - """Test that user messages accept image content (unlike other message types).""" - # List with images should work - msg = OpenAIUserMessageParam( - content=[ - OpenAIChatCompletionContentPartTextParam(text="Describe this image:"), - OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")), - ] - ) - assert len(msg.content) == 2 - assert msg.content[0].text == "Describe this image:" - assert msg.content[1].image_url.url == "http://example.com/image.jpg" - - -async def test_convert_message_to_openai_dict_new_user_message(): - """Test convert_message_to_openai_dict_new with UserMessage.""" - message = UserMessage(content="Hello, world!", role="user") - result = await convert_message_to_openai_dict_new(message) - - assert result["role"] == "user" - assert result["content"] == "Hello, world!" 
- - -async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls(): - """Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls.""" - message = CompletionMessage( - content="I'll help you find the weather.", - tool_calls=[ - ToolCall( - call_id="call_123", - tool_name="get_weather", - arguments='{"city": "Sligo"}', - ) - ], - stop_reason=StopReason.end_of_turn, - ) - result = await convert_message_to_openai_dict_new(message) - - # This would have failed with "Cannot instantiate typing.Union" before the fix - assert result["role"] == "assistant" - assert result["content"] == "I'll help you find the weather." - assert "tool_calls" in result - assert result["tool_calls"] is not None - assert len(result["tool_calls"]) == 1 - - tool_call = result["tool_calls"][0] - assert tool_call.id == "call_123" - assert tool_call.type == "function" - assert tool_call.function.name == "get_weather" - assert tool_call.function.arguments == '{"city": "Sligo"}' diff --git a/tests/unit/providers/utils/inference/test_prompt_adapter.py b/tests/unit/providers/utils/inference/test_prompt_adapter.py new file mode 100644 index 000000000..62c8db74d --- /dev/null +++ b/tests/unit/providers/utils/inference/test_prompt_adapter.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.inference import ( + OpenAIAssistantMessageParam, + OpenAIUserMessageParam, +) +from llama_stack.models.llama.datatypes import RawTextItem +from llama_stack.providers.utils.inference.prompt_adapter import ( + convert_openai_message_to_raw_message, +) + + +class TestConvertOpenAIMessageToRawMessage: + """Test conversion of OpenAI message types to RawMessage format.""" + + async def test_user_message_conversion(self): + msg = OpenAIUserMessageParam(role="user", content="Hello world") + raw_msg = await convert_openai_message_to_raw_message(msg) + + assert raw_msg.role == "user" + assert isinstance(raw_msg.content, RawTextItem) + assert raw_msg.content.text == "Hello world" + + async def test_assistant_message_conversion(self): + msg = OpenAIAssistantMessageParam(role="assistant", content="Hi there!") + raw_msg = await convert_openai_message_to_raw_message(msg) + + assert raw_msg.role == "assistant" + assert isinstance(raw_msg.content, RawTextItem) + assert raw_msg.content.text == "Hi there!" + assert raw_msg.tool_calls == []
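
Taken together, the patch routes prompt rendering for the meta-reference provider through the new convert_openai_message_to_raw_message helper instead of the removed chat_completion_request_to_prompt path. Below is a minimal sketch of that flow, not part of the patch itself: it assumes the Llama 3 formatter is driven the same way the removed helper drove it, and the message contents plus the choice of ToolPromptFormat.json are illustrative only.

import asyncio

from llama_stack.apis.inference import (
    OpenAIAssistantMessageParam,
    OpenAISystemMessageParam,
    OpenAIUserMessageParam,
)
from llama_stack.models.llama.datatypes import ToolPromptFormat
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_openai_message_to_raw_message,
)


async def render_prompt() -> str:
    # OpenAI-style message params are now the input type; the bespoke Message union is gone.
    openai_messages = [
        OpenAISystemMessageParam(content="You are a helpful assistant."),
        OpenAIUserMessageParam(content="What is the capital of France?"),
        OpenAIAssistantMessageParam(content="The capital of France is Paris."),
    ]
    # Conversion step introduced by this patch: OpenAI params -> RawMessage.
    raw_messages = [await convert_openai_message_to_raw_message(m) for m in openai_messages]

    # Same formatter path the removed chat_completion_request_to_prompt helper used.
    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
    model_input = formatter.encode_dialog_prompt(raw_messages, tool_prompt_format=ToolPromptFormat.json)
    return formatter.tokenizer.decode(model_input.tokens)


if __name__ == "__main__":
    print(asyncio.run(render_prompt()))

The RawMessage list is the same shape ChatFormat.encode_dialog_prompt consumed before this change (the removed convert_request_to_raw produced it from the old Message types), so only the message-conversion step in the middle is new.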