mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
refactor: remove dead inference API code and clean up imports
Delete ~1,300 lines of dead code from the old bespoke inference API that was replaced by OpenAI-only API. This includes removing unused type conversion functions, dead provider methods, and event_logger.py. Clean up imports across the codebase to remove references to deleted types. This eliminates unnecessary code and dependencies, helping isolate the API package as a self-contained module. This is the last interdependency between the .api package and "exterior" packages, meaning that now every other package in llama stack imports the API, not the other way around. The API is now self contained and can be moved into its own package. Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
8f4c431370
commit
54754cb2a7
18 changed files with 34 additions and 2141 deletions
|
|
@ -1,43 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_stack.apis.inference import (
|
|
||||||
ChatCompletionResponseEventType,
|
|
||||||
ChatCompletionResponseStreamChunk,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LogEvent:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
content: str = "",
|
|
||||||
end: str = "\n",
|
|
||||||
color="white",
|
|
||||||
):
|
|
||||||
self.content = content
|
|
||||||
self.color = color
|
|
||||||
self.end = "\n" if end is None else end
|
|
||||||
|
|
||||||
def print(self, flush=True):
|
|
||||||
cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
|
|
||||||
|
|
||||||
|
|
||||||
class EventLogger:
|
|
||||||
async def log(self, event_generator):
|
|
||||||
async for chunk in event_generator:
|
|
||||||
if isinstance(chunk, ChatCompletionResponseStreamChunk):
|
|
||||||
event = chunk.event
|
|
||||||
if event.event_type == ChatCompletionResponseEventType.start:
|
|
||||||
yield LogEvent("Assistant> ", color="cyan", end="")
|
|
||||||
elif event.event_type == ChatCompletionResponseEventType.progress:
|
|
||||||
yield LogEvent(event.delta, color="yellow", end="")
|
|
||||||
elif event.event_type == ChatCompletionResponseEventType.complete:
|
|
||||||
yield LogEvent("")
|
|
||||||
else:
|
|
||||||
yield LogEvent("Assistant> ", color="cyan", end="")
|
|
||||||
yield LogEvent(chunk.completion_message.content, color="yellow")
|
|
||||||
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from enum import Enum
|
from enum import Enum, StrEnum
|
||||||
from typing import (
|
from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
|
|
@ -15,28 +15,18 @@ from typing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from fastapi import Body
|
from fastapi import Body
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
|
from llama_stack.apis.common.content_types import InterleavedContent
|
||||||
from llama_stack.apis.common.responses import MetricResponseMixin, Order
|
from llama_stack.apis.common.responses import (
|
||||||
|
Order,
|
||||||
|
)
|
||||||
from llama_stack.apis.common.tracing import telemetry_traceable
|
from llama_stack.apis.common.tracing import telemetry_traceable
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||||
from llama_stack.models.llama.datatypes import (
|
|
||||||
BuiltinTool,
|
|
||||||
StopReason,
|
|
||||||
ToolCall,
|
|
||||||
ToolDefinition,
|
|
||||||
ToolPromptFormat,
|
|
||||||
)
|
|
||||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
|
||||||
register_schema(ToolCall)
|
|
||||||
register_schema(ToolDefinition)
|
|
||||||
|
|
||||||
from enum import StrEnum
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class GreedySamplingStrategy(BaseModel):
|
class GreedySamplingStrategy(BaseModel):
|
||||||
|
|
@ -201,58 +191,6 @@ class ToolResponseMessage(BaseModel):
|
||||||
content: InterleavedContent
|
content: InterleavedContent
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class CompletionMessage(BaseModel):
|
|
||||||
"""A message containing the model's (assistant) response in a chat conversation.
|
|
||||||
|
|
||||||
:param role: Must be "assistant" to identify this as the model's response
|
|
||||||
:param content: The content of the model's response
|
|
||||||
:param stop_reason: Reason why the model stopped generating. Options are:
|
|
||||||
- `StopReason.end_of_turn`: The model finished generating the entire response.
|
|
||||||
- `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response.
|
|
||||||
- `StopReason.out_of_tokens`: The model ran out of token budget.
|
|
||||||
:param tool_calls: List of tool calls. Each tool call is a ToolCall object.
|
|
||||||
"""
|
|
||||||
|
|
||||||
role: Literal["assistant"] = "assistant"
|
|
||||||
content: InterleavedContent
|
|
||||||
stop_reason: StopReason
|
|
||||||
tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])
|
|
||||||
|
|
||||||
|
|
||||||
Message = Annotated[
|
|
||||||
UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
|
|
||||||
Field(discriminator="role"),
|
|
||||||
]
|
|
||||||
register_schema(Message, name="Message")
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ToolResponse(BaseModel):
|
|
||||||
"""Response from a tool invocation.
|
|
||||||
|
|
||||||
:param call_id: Unique identifier for the tool call this response is for
|
|
||||||
:param tool_name: Name of the tool that was invoked
|
|
||||||
:param content: The response content from the tool
|
|
||||||
:param metadata: (Optional) Additional metadata about the tool response
|
|
||||||
"""
|
|
||||||
|
|
||||||
call_id: str
|
|
||||||
tool_name: BuiltinTool | str
|
|
||||||
content: InterleavedContent
|
|
||||||
metadata: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
@field_validator("tool_name", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def validate_field(cls, v):
|
|
||||||
if isinstance(v, str):
|
|
||||||
try:
|
|
||||||
return BuiltinTool(v)
|
|
||||||
except ValueError:
|
|
||||||
return v
|
|
||||||
return v
|
|
||||||
|
|
||||||
|
|
||||||
class ToolChoice(Enum):
|
class ToolChoice(Enum):
|
||||||
"""Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.
|
"""Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.
|
||||||
|
|
||||||
|
|
@ -289,22 +227,6 @@ class ChatCompletionResponseEventType(Enum):
|
||||||
progress = "progress"
|
progress = "progress"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ChatCompletionResponseEvent(BaseModel):
|
|
||||||
"""An event during chat completion generation.
|
|
||||||
|
|
||||||
:param event_type: Type of the event
|
|
||||||
:param delta: Content generated since last event. This can be one or more tokens, or a tool call.
|
|
||||||
:param logprobs: Optional log probabilities for generated tokens
|
|
||||||
:param stop_reason: Optional reason why generation stopped, if complete
|
|
||||||
"""
|
|
||||||
|
|
||||||
event_type: ChatCompletionResponseEventType
|
|
||||||
delta: ContentDelta
|
|
||||||
logprobs: list[TokenLogProbs] | None = None
|
|
||||||
stop_reason: StopReason | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseFormatType(StrEnum):
|
class ResponseFormatType(StrEnum):
|
||||||
"""Types of formats for structured (guided) decoding.
|
"""Types of formats for structured (guided) decoding.
|
||||||
|
|
||||||
|
|
@ -357,34 +279,6 @@ class CompletionRequest(BaseModel):
|
||||||
logprobs: LogProbConfig | None = None
|
logprobs: LogProbConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class CompletionResponse(MetricResponseMixin):
|
|
||||||
"""Response from a completion request.
|
|
||||||
|
|
||||||
:param content: The generated completion text
|
|
||||||
:param stop_reason: Reason why generation stopped
|
|
||||||
:param logprobs: Optional log probabilities for generated tokens
|
|
||||||
"""
|
|
||||||
|
|
||||||
content: str
|
|
||||||
stop_reason: StopReason
|
|
||||||
logprobs: list[TokenLogProbs] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class CompletionResponseStreamChunk(MetricResponseMixin):
|
|
||||||
"""A chunk of a streamed completion response.
|
|
||||||
|
|
||||||
:param delta: New content generated since last chunk. This can be one or more tokens.
|
|
||||||
:param stop_reason: Optional reason why generation stopped, if complete
|
|
||||||
:param logprobs: Optional log probabilities for generated tokens
|
|
||||||
"""
|
|
||||||
|
|
||||||
delta: str
|
|
||||||
stop_reason: StopReason | None = None
|
|
||||||
logprobs: list[TokenLogProbs] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class SystemMessageBehavior(Enum):
|
class SystemMessageBehavior(Enum):
|
||||||
"""Config for how to override the default system prompt.
|
"""Config for how to override the default system prompt.
|
||||||
|
|
||||||
|
|
@ -398,70 +292,6 @@ class SystemMessageBehavior(Enum):
|
||||||
replace = "replace"
|
replace = "replace"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ToolConfig(BaseModel):
|
|
||||||
"""Configuration for tool use.
|
|
||||||
|
|
||||||
:param tool_choice: (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
|
|
||||||
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
|
|
||||||
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
|
|
||||||
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
|
|
||||||
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
|
|
||||||
:param system_message_behavior: (Optional) Config for how to override the default system prompt.
|
|
||||||
- `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt.
|
|
||||||
- `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string
|
|
||||||
'{{function_definitions}}' to indicate where the function definitions should be inserted.
|
|
||||||
"""
|
|
||||||
|
|
||||||
tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
|
|
||||||
tool_prompt_format: ToolPromptFormat | None = Field(default=None)
|
|
||||||
system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
|
|
||||||
|
|
||||||
def model_post_init(self, __context: Any) -> None:
|
|
||||||
if isinstance(self.tool_choice, str):
|
|
||||||
try:
|
|
||||||
self.tool_choice = ToolChoice[self.tool_choice]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# This is an internally used class
|
|
||||||
@json_schema_type
|
|
||||||
class ChatCompletionRequest(BaseModel):
|
|
||||||
model: str
|
|
||||||
messages: list[Message]
|
|
||||||
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
|
|
||||||
|
|
||||||
tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
|
|
||||||
tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
|
|
||||||
|
|
||||||
response_format: ResponseFormat | None = None
|
|
||||||
stream: bool | None = False
|
|
||||||
logprobs: LogProbConfig | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
|
|
||||||
"""A chunk of a streamed chat completion response.
|
|
||||||
|
|
||||||
:param event: The event containing the new content
|
|
||||||
"""
|
|
||||||
|
|
||||||
event: ChatCompletionResponseEvent
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ChatCompletionResponse(MetricResponseMixin):
|
|
||||||
"""Response from a chat completion request.
|
|
||||||
|
|
||||||
:param completion_message: The complete response message
|
|
||||||
:param logprobs: Optional log probabilities for generated tokens
|
|
||||||
"""
|
|
||||||
|
|
||||||
completion_message: CompletionMessage
|
|
||||||
logprobs: list[TokenLogProbs] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class EmbeddingsResponse(BaseModel):
|
class EmbeddingsResponse(BaseModel):
|
||||||
"""Response containing generated embeddings.
|
"""Response containing generated embeddings.
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.inference import Message
|
from llama_stack.apis.inference import OpenAIMessageParam
|
||||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||||
from llama_stack.apis.safety.safety import ModerationObject
|
from llama_stack.apis.safety.safety import ModerationObject
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
|
|
@ -52,7 +52,7 @@ class SafetyRouter(Safety):
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
messages: list[Message],
|
messages: list[OpenAIMessageParam],
|
||||||
params: dict[str, Any] = None,
|
params: dict[str, Any] = None,
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||||
|
|
|
||||||
|
|
@ -26,8 +26,10 @@ from fairscale.nn.model_parallel.initialize import (
|
||||||
)
|
)
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
|
from llama_stack.models.llama.datatypes import ToolPromptFormat
|
||||||
|
|
||||||
from ..checkpoint import maybe_reshard_state_dict
|
from ..checkpoint import maybe_reshard_state_dict
|
||||||
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
|
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage
|
||||||
from .args import ModelArgs
|
from .args import ModelArgs
|
||||||
from .chat_format import ChatFormat, LLMInput
|
from .chat_format import ChatFormat, LLMInput
|
||||||
from .model import Transformer
|
from .model import Transformer
|
||||||
|
|
|
||||||
|
|
@ -15,13 +15,10 @@ from pathlib import Path
|
||||||
|
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
|
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall, ToolDefinition, ToolPromptFormat
|
||||||
|
|
||||||
from ..datatypes import (
|
from ..datatypes import (
|
||||||
BuiltinTool,
|
|
||||||
RawMessage,
|
RawMessage,
|
||||||
StopReason,
|
|
||||||
ToolCall,
|
|
||||||
ToolDefinition,
|
|
||||||
ToolPromptFormat,
|
|
||||||
)
|
)
|
||||||
from . import template_data
|
from . import template_data
|
||||||
from .chat_format import ChatFormat
|
from .chat_format import ChatFormat
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ import textwrap
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.models.llama.datatypes import (
|
||||||
BuiltinTool,
|
BuiltinTool,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,9 @@ import json
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolPromptFormat
|
||||||
|
|
||||||
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
from ..datatypes import RecursiveType
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="models::llama")
|
logger = get_logger(name=__name__, category="models::llama")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@
|
||||||
|
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
from llama_stack.apis.inference import ToolDefinition
|
from llama_stack.models.llama.datatypes import ToolDefinition
|
||||||
from llama_stack.models.llama.llama3.prompt_templates.base import (
|
from llama_stack.models.llama.llama3.prompt_templates.base import (
|
||||||
PromptTemplate,
|
PromptTemplate,
|
||||||
PromptTemplateGeneratorBase,
|
PromptTemplateGeneratorBase,
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from collections.abc import Generator
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -24,11 +23,6 @@ from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokeniz
|
||||||
from llama_stack.models.llama.llama4.generation import Llama4
|
from llama_stack.models.llama.llama4.generation import Llama4
|
||||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||||
from llama_stack.models.llama.sku_types import Model, ModelFamily
|
from llama_stack.models.llama.sku_types import Model, ModelFamily
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
|
||||||
ChatCompletionRequestWithRawContent,
|
|
||||||
CompletionRequestWithRawContent,
|
|
||||||
get_default_tool_prompt_format,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .common import model_checkpoint_dir
|
from .common import model_checkpoint_dir
|
||||||
from .config import MetaReferenceInferenceConfig
|
from .config import MetaReferenceInferenceConfig
|
||||||
|
|
@ -106,14 +100,6 @@ def _infer_sampling_params(sampling_params: SamplingParams):
|
||||||
return temperature, top_p
|
return temperature, top_p
|
||||||
|
|
||||||
|
|
||||||
def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
|
|
||||||
tool_config = request.tool_config
|
|
||||||
if tool_config is not None and tool_config.tool_prompt_format is not None:
|
|
||||||
return tool_config.tool_prompt_format
|
|
||||||
else:
|
|
||||||
return get_default_tool_prompt_format(request.model)
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaGenerator:
|
class LlamaGenerator:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -156,56 +142,3 @@ class LlamaGenerator:
|
||||||
self.tokenizer = self.inner_generator.tokenizer
|
self.tokenizer = self.inner_generator.tokenizer
|
||||||
self.args = self.inner_generator.args
|
self.args = self.inner_generator.args
|
||||||
self.formatter = self.inner_generator.formatter
|
self.formatter = self.inner_generator.formatter
|
||||||
|
|
||||||
def completion(
|
|
||||||
self,
|
|
||||||
request_batch: list[CompletionRequestWithRawContent],
|
|
||||||
) -> Generator:
|
|
||||||
first_request = request_batch[0]
|
|
||||||
sampling_params = first_request.sampling_params or SamplingParams()
|
|
||||||
max_gen_len = sampling_params.max_tokens
|
|
||||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
|
||||||
max_gen_len = self.args.max_seq_len - 1
|
|
||||||
|
|
||||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
|
||||||
yield from self.inner_generator.generate(
|
|
||||||
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
|
|
||||||
max_gen_len=max_gen_len,
|
|
||||||
temperature=temperature,
|
|
||||||
top_p=top_p,
|
|
||||||
logprobs=bool(first_request.logprobs),
|
|
||||||
echo=False,
|
|
||||||
logits_processor=get_logits_processor(
|
|
||||||
self.tokenizer,
|
|
||||||
self.args.vocab_size,
|
|
||||||
first_request.response_format,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def chat_completion(
|
|
||||||
self,
|
|
||||||
request_batch: list[ChatCompletionRequestWithRawContent],
|
|
||||||
) -> Generator:
|
|
||||||
first_request = request_batch[0]
|
|
||||||
sampling_params = first_request.sampling_params or SamplingParams()
|
|
||||||
max_gen_len = sampling_params.max_tokens
|
|
||||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
|
||||||
max_gen_len = self.args.max_seq_len - 1
|
|
||||||
|
|
||||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
|
||||||
yield from self.inner_generator.generate(
|
|
||||||
llm_inputs=[
|
|
||||||
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
|
|
||||||
for request in request_batch
|
|
||||||
],
|
|
||||||
max_gen_len=max_gen_len,
|
|
||||||
temperature=temperature,
|
|
||||||
top_p=top_p,
|
|
||||||
logprobs=bool(first_request.logprobs),
|
|
||||||
echo=False,
|
|
||||||
logits_processor=get_logits_processor(
|
|
||||||
self.tokenizer,
|
|
||||||
self.args.vocab_size,
|
|
||||||
first_request.response_format,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -4,17 +4,12 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from collections.abc import Callable, Generator
|
from collections.abc import Callable
|
||||||
from copy import deepcopy
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
||||||
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
|
||||||
ChatCompletionRequestWithRawContent,
|
|
||||||
CompletionRequestWithRawContent,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .parallel_utils import ModelParallelProcessGroup
|
from .parallel_utils import ModelParallelProcessGroup
|
||||||
|
|
||||||
|
|
@ -23,11 +18,7 @@ class ModelRunner:
|
||||||
def __init__(self, llama):
|
def __init__(self, llama):
|
||||||
self.llama = llama
|
self.llama = llama
|
||||||
|
|
||||||
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
|
|
||||||
def __call__(self, task: Any):
|
def __call__(self, task: Any):
|
||||||
if task[0] == "chat_completion":
|
|
||||||
return self.llama.chat_completion(task[1])
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unexpected task type {task[0]}")
|
raise ValueError(f"Unexpected task type {task[0]}")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -78,19 +69,3 @@ class LlamaModelParallelGenerator:
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||||
self.group.stop()
|
self.group.stop()
|
||||||
|
|
||||||
def completion(
|
|
||||||
self,
|
|
||||||
request_batch: list[CompletionRequestWithRawContent],
|
|
||||||
) -> Generator:
|
|
||||||
req_obj = deepcopy(request_batch)
|
|
||||||
gen = self.group.run_inference(("completion", req_obj))
|
|
||||||
yield from gen
|
|
||||||
|
|
||||||
def chat_completion(
|
|
||||||
self,
|
|
||||||
request_batch: list[ChatCompletionRequestWithRawContent],
|
|
||||||
) -> Generator:
|
|
||||||
req_obj = deepcopy(request_batch)
|
|
||||||
gen = self.group.run_inference(("chat_completion", req_obj))
|
|
||||||
yield from gen
|
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,6 @@ from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.models.llama.datatypes import GenerationResult
|
from llama_stack.models.llama.datatypes import GenerationResult
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
ChatCompletionRequestWithRawContent,
|
|
||||||
CompletionRequestWithRawContent,
|
CompletionRequestWithRawContent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -71,7 +70,7 @@ class TaskRequest(BaseModel):
|
||||||
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
|
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
|
||||||
task: tuple[
|
task: tuple[
|
||||||
str,
|
str,
|
||||||
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
|
list[CompletionRequestWithRawContent],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -330,7 +329,7 @@ class ModelParallelProcessGroup:
|
||||||
self,
|
self,
|
||||||
req: tuple[
|
req: tuple[
|
||||||
str,
|
str,
|
||||||
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
|
list[CompletionRequestWithRawContent],
|
||||||
],
|
],
|
||||||
) -> Generator:
|
) -> Generator:
|
||||||
assert not self.running, "inference already running"
|
assert not self.running, "inference already running"
|
||||||
|
|
|
||||||
|
|
@ -22,9 +22,6 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||||
SentenceTransformerEmbeddingMixin,
|
SentenceTransformerEmbeddingMixin,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
|
||||||
OpenAIChatCompletionToLlamaStackMixin,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .config import SentenceTransformersInferenceConfig
|
from .config import SentenceTransformersInferenceConfig
|
||||||
|
|
||||||
|
|
@ -32,7 +29,6 @@ log = get_logger(name=__name__, category="inference")
|
||||||
|
|
||||||
|
|
||||||
class SentenceTransformersInferenceImpl(
|
class SentenceTransformersInferenceImpl(
|
||||||
OpenAIChatCompletionToLlamaStackMixin,
|
|
||||||
SentenceTransformerEmbeddingMixin,
|
SentenceTransformerEmbeddingMixin,
|
||||||
InferenceProvider,
|
InferenceProvider,
|
||||||
ModelsProtocolPrivate,
|
ModelsProtocolPrivate,
|
||||||
|
|
|
||||||
|
|
@ -11,9 +11,7 @@ from collections.abc import AsyncIterator
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionRequest,
|
|
||||||
InferenceProvider,
|
InferenceProvider,
|
||||||
JsonSchemaResponseFormat,
|
|
||||||
OpenAIChatCompletion,
|
OpenAIChatCompletion,
|
||||||
OpenAIChatCompletionChunk,
|
OpenAIChatCompletionChunk,
|
||||||
OpenAIChatCompletionRequestWithExtraBody,
|
OpenAIChatCompletionRequestWithExtraBody,
|
||||||
|
|
@ -23,15 +21,11 @@ from llama_stack.apis.inference import (
|
||||||
OpenAIEmbeddingsRequestWithExtraBody,
|
OpenAIEmbeddingsRequestWithExtraBody,
|
||||||
OpenAIEmbeddingsResponse,
|
OpenAIEmbeddingsResponse,
|
||||||
OpenAIEmbeddingUsage,
|
OpenAIEmbeddingUsage,
|
||||||
ToolChoice,
|
|
||||||
)
|
)
|
||||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
|
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
convert_message_to_openai_dict_new,
|
|
||||||
convert_tooldef_to_openai_tool,
|
|
||||||
get_sampling_options,
|
|
||||||
prepare_openai_completion_params,
|
prepare_openai_completion_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -127,51 +121,6 @@ class LiteLLMOpenAIMixin(
|
||||||
|
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
input_dict: dict[str, Any] = {}
|
|
||||||
|
|
||||||
input_dict["messages"] = [
|
|
||||||
await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
|
|
||||||
]
|
|
||||||
if fmt := request.response_format:
|
|
||||||
if not isinstance(fmt, JsonSchemaResponseFormat):
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert to dict for manipulation
|
|
||||||
fmt_dict = dict(fmt.json_schema)
|
|
||||||
name = fmt_dict["title"]
|
|
||||||
del fmt_dict["title"]
|
|
||||||
fmt_dict["additionalProperties"] = False
|
|
||||||
|
|
||||||
# Apply additionalProperties: False recursively to all objects
|
|
||||||
fmt_dict = self._add_additional_properties_recursive(fmt_dict)
|
|
||||||
|
|
||||||
input_dict["response_format"] = {
|
|
||||||
"type": "json_schema",
|
|
||||||
"json_schema": {
|
|
||||||
"name": name,
|
|
||||||
"schema": fmt_dict,
|
|
||||||
"strict": self.json_schema_strict,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if request.tools:
|
|
||||||
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
|
|
||||||
if request.tool_config and (tool_choice := request.tool_config.tool_choice):
|
|
||||||
input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice
|
|
||||||
|
|
||||||
return {
|
|
||||||
"model": request.model,
|
|
||||||
"api_key": self.get_api_key(),
|
|
||||||
"api_base": self.api_base,
|
|
||||||
**input_dict,
|
|
||||||
"stream": request.stream,
|
|
||||||
**get_sampling_options(request.sampling_params),
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_api_key(self) -> str:
|
def get_api_key(self) -> str:
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
key_field = self.provider_data_api_key_field
|
key_field = self.provider_data_api_key_field
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -21,19 +21,13 @@ from llama_stack.apis.common.content_types import (
|
||||||
TextContentItem,
|
TextContentItem,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionRequest,
|
|
||||||
CompletionRequest,
|
CompletionRequest,
|
||||||
Message,
|
|
||||||
OpenAIChatCompletionContentPartImageParam,
|
OpenAIChatCompletionContentPartImageParam,
|
||||||
OpenAIChatCompletionContentPartTextParam,
|
OpenAIChatCompletionContentPartTextParam,
|
||||||
OpenAIFile,
|
OpenAIFile,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
ResponseFormatType,
|
ResponseFormatType,
|
||||||
SystemMessage,
|
|
||||||
SystemMessageBehavior,
|
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolDefinition,
|
|
||||||
UserMessage,
|
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.models.llama.datatypes import (
|
from llama_stack.models.llama.datatypes import (
|
||||||
|
|
@ -42,33 +36,18 @@ from llama_stack.models.llama.datatypes import (
|
||||||
RawMediaItem,
|
RawMediaItem,
|
||||||
RawMessage,
|
RawMessage,
|
||||||
RawTextItem,
|
RawTextItem,
|
||||||
Role,
|
|
||||||
StopReason,
|
StopReason,
|
||||||
|
ToolDefinition,
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||||
from llama_stack.models.llama.llama3.prompt_templates import (
|
|
||||||
BuiltinToolGenerator,
|
|
||||||
FunctionTagCustomToolGenerator,
|
|
||||||
JsonCustomToolGenerator,
|
|
||||||
PythonListCustomToolGenerator,
|
|
||||||
SystemDefaultGenerator,
|
|
||||||
)
|
|
||||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||||
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
|
||||||
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
|
|
||||||
)
|
|
||||||
from llama_stack.models.llama.sku_list import resolve_model
|
from llama_stack.models.llama.sku_list import resolve_model
|
||||||
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
||||||
from llama_stack.providers.utils.inference import supported_inference_models
|
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="providers::utils")
|
log = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
|
|
||||||
messages: list[RawMessage]
|
|
||||||
|
|
||||||
|
|
||||||
class CompletionRequestWithRawContent(CompletionRequest):
|
class CompletionRequestWithRawContent(CompletionRequest):
|
||||||
content: RawContent
|
content: RawContent
|
||||||
|
|
||||||
|
|
@ -103,28 +82,6 @@ def interleaved_content_as_str(
|
||||||
return _process(content)
|
return _process(content)
|
||||||
|
|
||||||
|
|
||||||
async def convert_request_to_raw(
|
|
||||||
request: ChatCompletionRequest | CompletionRequest,
|
|
||||||
) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent:
|
|
||||||
if isinstance(request, ChatCompletionRequest):
|
|
||||||
messages = []
|
|
||||||
for m in request.messages:
|
|
||||||
content = await interleaved_content_convert_to_raw(m.content)
|
|
||||||
d = m.model_dump()
|
|
||||||
d["content"] = content
|
|
||||||
messages.append(RawMessage(**d))
|
|
||||||
|
|
||||||
d = request.model_dump()
|
|
||||||
d["messages"] = messages
|
|
||||||
request = ChatCompletionRequestWithRawContent(**d)
|
|
||||||
else:
|
|
||||||
d = request.model_dump()
|
|
||||||
d["content"] = await interleaved_content_convert_to_raw(request.content)
|
|
||||||
request = CompletionRequestWithRawContent(**d)
|
|
||||||
|
|
||||||
return request
|
|
||||||
|
|
||||||
|
|
||||||
async def interleaved_content_convert_to_raw(
|
async def interleaved_content_convert_to_raw(
|
||||||
content: InterleavedContent,
|
content: InterleavedContent,
|
||||||
) -> RawContent:
|
) -> RawContent:
|
||||||
|
|
@ -181,17 +138,6 @@ def content_has_media(content: InterleavedContent):
|
||||||
return _has_media_content(content)
|
return _has_media_content(content)
|
||||||
|
|
||||||
|
|
||||||
def messages_have_media(messages: list[Message]):
|
|
||||||
return any(content_has_media(m.content) for m in messages)
|
|
||||||
|
|
||||||
|
|
||||||
def request_has_media(request: ChatCompletionRequest | CompletionRequest):
|
|
||||||
if isinstance(request, ChatCompletionRequest):
|
|
||||||
return messages_have_media(request.messages)
|
|
||||||
else:
|
|
||||||
return content_has_media(request.content)
|
|
||||||
|
|
||||||
|
|
||||||
async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
|
async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
|
||||||
if uri.startswith("http"):
|
if uri.startswith("http"):
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
|
|
@ -253,79 +199,6 @@ def augment_content_with_response_format_prompt(response_format, content):
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
|
|
||||||
messages = chat_completion_request_to_messages(request, llama_model)
|
|
||||||
request.messages = messages
|
|
||||||
request = await convert_request_to_raw(request)
|
|
||||||
|
|
||||||
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
|
|
||||||
model_input = formatter.encode_dialog_prompt(
|
|
||||||
request.messages,
|
|
||||||
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
|
|
||||||
)
|
|
||||||
return formatter.tokenizer.decode(model_input.tokens)
|
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_request_to_model_input_info(
|
|
||||||
request: ChatCompletionRequest, llama_model: str
|
|
||||||
) -> tuple[str, int]:
|
|
||||||
messages = chat_completion_request_to_messages(request, llama_model)
|
|
||||||
request.messages = messages
|
|
||||||
request = await convert_request_to_raw(request)
|
|
||||||
|
|
||||||
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
|
|
||||||
model_input = formatter.encode_dialog_prompt(
|
|
||||||
request.messages,
|
|
||||||
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
formatter.tokenizer.decode(model_input.tokens),
|
|
||||||
len(model_input.tokens),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def chat_completion_request_to_messages(
|
|
||||||
request: ChatCompletionRequest,
|
|
||||||
llama_model: str,
|
|
||||||
) -> list[Message]:
|
|
||||||
"""Reads chat completion request and augments the messages to handle tools.
|
|
||||||
For eg. for llama_3_1, add system message with the appropriate tools or
|
|
||||||
add user messsage for custom tools, etc.
|
|
||||||
"""
|
|
||||||
assert llama_model is not None, "llama_model is required"
|
|
||||||
model = resolve_model(llama_model)
|
|
||||||
if model is None:
|
|
||||||
log.error(f"Could not resolve model {llama_model}")
|
|
||||||
return request.messages
|
|
||||||
|
|
||||||
allowed_models = supported_inference_models()
|
|
||||||
descriptors = [m.descriptor() for m in allowed_models]
|
|
||||||
if model.descriptor() not in descriptors:
|
|
||||||
log.error(f"Unsupported inference model? {model.descriptor()}")
|
|
||||||
return request.messages
|
|
||||||
|
|
||||||
if model.model_family == ModelFamily.llama3_1 or (
|
|
||||||
model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
|
|
||||||
):
|
|
||||||
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
|
|
||||||
messages = augment_messages_for_tools_llama_3_1(request)
|
|
||||||
elif model.model_family in (
|
|
||||||
ModelFamily.llama3_2,
|
|
||||||
ModelFamily.llama3_3,
|
|
||||||
):
|
|
||||||
# llama3.2, llama3.3 follow the same tool prompt format
|
|
||||||
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
|
|
||||||
elif model.model_family == ModelFamily.llama4:
|
|
||||||
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
|
|
||||||
else:
|
|
||||||
messages = request.messages
|
|
||||||
|
|
||||||
if fmt_prompt := response_format_prompt(request.response_format):
|
|
||||||
messages.append(UserMessage(content=fmt_prompt))
|
|
||||||
|
|
||||||
return messages
|
|
||||||
|
|
||||||
|
|
||||||
def response_format_prompt(fmt: ResponseFormat | None):
|
def response_format_prompt(fmt: ResponseFormat | None):
|
||||||
if not fmt:
|
if not fmt:
|
||||||
return None
|
return None
|
||||||
|
|
@ -338,128 +211,6 @@ def response_format_prompt(fmt: ResponseFormat | None):
|
||||||
raise ValueError(f"Unknown response format {fmt.type}")
|
raise ValueError(f"Unknown response format {fmt.type}")
|
||||||
|
|
||||||
|
|
||||||
def augment_messages_for_tools_llama_3_1(
|
|
||||||
request: ChatCompletionRequest,
|
|
||||||
) -> list[Message]:
|
|
||||||
existing_messages = request.messages
|
|
||||||
existing_system_message = None
|
|
||||||
if existing_messages[0].role == Role.system.value:
|
|
||||||
existing_system_message = existing_messages.pop(0)
|
|
||||||
|
|
||||||
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
|
|
||||||
|
|
||||||
messages = []
|
|
||||||
|
|
||||||
default_gen = SystemDefaultGenerator()
|
|
||||||
default_template = default_gen.gen()
|
|
||||||
|
|
||||||
sys_content = ""
|
|
||||||
|
|
||||||
tool_template = None
|
|
||||||
if request.tools:
|
|
||||||
tool_gen = BuiltinToolGenerator()
|
|
||||||
tool_template = tool_gen.gen(request.tools)
|
|
||||||
|
|
||||||
sys_content += tool_template.render()
|
|
||||||
sys_content += "\n"
|
|
||||||
|
|
||||||
sys_content += default_template.render()
|
|
||||||
|
|
||||||
if existing_system_message:
|
|
||||||
# TODO: this fn is needed in many places
|
|
||||||
def _process(c):
|
|
||||||
if isinstance(c, str):
|
|
||||||
return c
|
|
||||||
else:
|
|
||||||
return "<media>"
|
|
||||||
|
|
||||||
sys_content += "\n"
|
|
||||||
|
|
||||||
if isinstance(existing_system_message.content, str):
|
|
||||||
sys_content += _process(existing_system_message.content)
|
|
||||||
elif isinstance(existing_system_message.content, list):
|
|
||||||
sys_content += "\n".join([_process(c) for c in existing_system_message.content])
|
|
||||||
|
|
||||||
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
|
|
||||||
if tool_choice_prompt:
|
|
||||||
sys_content += "\n" + tool_choice_prompt
|
|
||||||
|
|
||||||
messages.append(SystemMessage(content=sys_content))
|
|
||||||
|
|
||||||
has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools)
|
|
||||||
if has_custom_tools:
|
|
||||||
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
|
|
||||||
if fmt == ToolPromptFormat.json:
|
|
||||||
tool_gen = JsonCustomToolGenerator()
|
|
||||||
elif fmt == ToolPromptFormat.function_tag:
|
|
||||||
tool_gen = FunctionTagCustomToolGenerator()
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Non supported ToolPromptFormat {fmt}")
|
|
||||||
|
|
||||||
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
|
|
||||||
custom_template = tool_gen.gen(custom_tools)
|
|
||||||
messages.append(UserMessage(content=custom_template.render()))
|
|
||||||
|
|
||||||
# Add back existing messages from the request
|
|
||||||
messages += existing_messages
|
|
||||||
|
|
||||||
return messages
|
|
||||||
|
|
||||||
|
|
||||||
def augment_messages_for_tools_llama(
|
|
||||||
request: ChatCompletionRequest,
|
|
||||||
custom_tool_prompt_generator,
|
|
||||||
) -> list[Message]:
|
|
||||||
existing_messages = request.messages
|
|
||||||
existing_system_message = None
|
|
||||||
if existing_messages[0].role == Role.system.value:
|
|
||||||
existing_system_message = existing_messages.pop(0)
|
|
||||||
|
|
||||||
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
|
|
||||||
|
|
||||||
sys_content = ""
|
|
||||||
custom_tools, builtin_tools = [], []
|
|
||||||
for t in request.tools:
|
|
||||||
if isinstance(t.tool_name, str):
|
|
||||||
custom_tools.append(t)
|
|
||||||
else:
|
|
||||||
builtin_tools.append(t)
|
|
||||||
|
|
||||||
if builtin_tools:
|
|
||||||
tool_gen = BuiltinToolGenerator()
|
|
||||||
tool_template = tool_gen.gen(builtin_tools)
|
|
||||||
|
|
||||||
sys_content += tool_template.render()
|
|
||||||
sys_content += "\n"
|
|
||||||
|
|
||||||
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
|
|
||||||
if custom_tools:
|
|
||||||
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
|
|
||||||
if fmt != ToolPromptFormat.python_list:
|
|
||||||
raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
|
|
||||||
|
|
||||||
system_prompt = None
|
|
||||||
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
|
|
||||||
system_prompt = existing_system_message.content
|
|
||||||
|
|
||||||
tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
|
|
||||||
|
|
||||||
sys_content += tool_template.render()
|
|
||||||
sys_content += "\n"
|
|
||||||
|
|
||||||
if existing_system_message and (
|
|
||||||
request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
|
|
||||||
):
|
|
||||||
sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
|
|
||||||
|
|
||||||
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
|
|
||||||
if tool_choice_prompt:
|
|
||||||
sys_content += "\n" + tool_choice_prompt
|
|
||||||
|
|
||||||
messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages]
|
|
||||||
return messages
|
|
||||||
|
|
||||||
|
|
||||||
def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str:
|
def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str:
|
||||||
if tool_choice == ToolChoice.auto:
|
if tool_choice == ToolChoice.auto:
|
||||||
return ""
|
return ""
|
||||||
|
|
|
||||||
|
|
@ -1,303 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
|
|
||||||
from llama_stack.apis.inference import (
|
|
||||||
ChatCompletionRequest,
|
|
||||||
CompletionMessage,
|
|
||||||
StopReason,
|
|
||||||
SystemMessage,
|
|
||||||
SystemMessageBehavior,
|
|
||||||
ToolCall,
|
|
||||||
ToolConfig,
|
|
||||||
UserMessage,
|
|
||||||
)
|
|
||||||
from llama_stack.models.llama.datatypes import (
|
|
||||||
BuiltinTool,
|
|
||||||
ToolDefinition,
|
|
||||||
ToolPromptFormat,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
|
||||||
chat_completion_request_to_messages,
|
|
||||||
chat_completion_request_to_prompt,
|
|
||||||
interleaved_content_as_str,
|
|
||||||
)
|
|
||||||
|
|
||||||
MODEL = "Llama3.1-8B-Instruct"
|
|
||||||
MODEL3_2 = "Llama3.2-3B-Instruct"
|
|
||||||
|
|
||||||
|
|
||||||
async def test_system_default():
|
|
||||||
content = "Hello !"
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model=MODEL,
|
|
||||||
messages=[
|
|
||||||
UserMessage(content=content),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
messages = chat_completion_request_to_messages(request, MODEL)
|
|
||||||
assert len(messages) == 2
|
|
||||||
assert messages[-1].content == content
|
|
||||||
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_system_builtin_only():
|
|
||||||
content = "Hello !"
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model=MODEL,
|
|
||||||
messages=[
|
|
||||||
UserMessage(content=content),
|
|
||||||
],
|
|
||||||
tools=[
|
|
||||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
|
||||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
messages = chat_completion_request_to_messages(request, MODEL)
|
|
||||||
assert len(messages) == 2
|
|
||||||
assert messages[-1].content == content
|
|
||||||
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
|
|
||||||
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_system_custom_only():
|
|
||||||
content = "Hello !"
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model=MODEL,
|
|
||||||
messages=[
|
|
||||||
UserMessage(content=content),
|
|
||||||
],
|
|
||||||
tools=[
|
|
||||||
ToolDefinition(
|
|
||||||
tool_name="custom1",
|
|
||||||
description="custom1 tool",
|
|
||||||
input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"param1": {
|
|
||||||
"type": "str",
|
|
||||||
"description": "param1 description",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["param1"],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
|
|
||||||
)
|
|
||||||
messages = chat_completion_request_to_messages(request, MODEL)
|
|
||||||
assert len(messages) == 3
|
|
||||||
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
|
|
||||||
|
|
||||||
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
|
|
||||||
assert messages[-1].content == content
|
|
||||||
|
|
||||||
|
|
||||||
async def test_system_custom_and_builtin():
|
|
||||||
content = "Hello !"
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model=MODEL,
|
|
||||||
messages=[
|
|
||||||
UserMessage(content=content),
|
|
||||||
],
|
|
||||||
tools=[
|
|
||||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
|
||||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
|
||||||
ToolDefinition(
|
|
||||||
tool_name="custom1",
|
|
||||||
description="custom1 tool",
|
|
||||||
input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"param1": {
|
|
||||||
"type": "str",
|
|
||||||
"description": "param1 description",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["param1"],
|
|
||||||
},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
messages = chat_completion_request_to_messages(request, MODEL)
|
|
||||||
assert len(messages) == 3
|
|
||||||
|
|
||||||
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
|
|
||||||
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
|
|
||||||
|
|
||||||
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
|
|
||||||
assert messages[-1].content == content
|
|
||||||
|
|
||||||
|
|
||||||
async def test_completion_message_encoding():
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model=MODEL3_2,
|
|
||||||
messages=[
|
|
||||||
UserMessage(content="hello"),
|
|
||||||
CompletionMessage(
|
|
||||||
content="",
|
|
||||||
stop_reason=StopReason.end_of_turn,
|
|
||||||
tool_calls=[
|
|
||||||
ToolCall(
|
|
||||||
tool_name="custom1",
|
|
||||||
arguments='{"param1": "value1"}', # arguments must be a JSON string
|
|
||||||
call_id="123",
|
|
||||||
)
|
|
||||||
],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
tools=[
|
|
||||||
ToolDefinition(
|
|
||||||
tool_name="custom1",
|
|
||||||
description="custom1 tool",
|
|
||||||
input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"param1": {
|
|
||||||
"type": "str",
|
|
||||||
"description": "param1 description",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["param1"],
|
|
||||||
},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
|
|
||||||
)
|
|
||||||
prompt = await chat_completion_request_to_prompt(request, request.model)
|
|
||||||
assert '[custom1(param1="value1")]' in prompt
|
|
||||||
|
|
||||||
request.model = MODEL
|
|
||||||
request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
|
|
||||||
prompt = await chat_completion_request_to_prompt(request, request.model)
|
|
||||||
assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
|
|
||||||
|
|
||||||
|
|
||||||
async def test_user_provided_system_message():
|
|
||||||
content = "Hello !"
|
|
||||||
system_prompt = "You are a pirate"
|
|
||||||
request = ChatCompletionRequest(
|
|
||||||
model=MODEL,
|
|
||||||
messages=[
|
|
||||||
SystemMessage(content=system_prompt),
|
|
||||||
UserMessage(content=content),
|
|
||||||
],
|
|
||||||
tools=[
|
|
||||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
-    messages = chat_completion_request_to_messages(request, MODEL)
-    assert len(messages) == 2
-    assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
-
-    assert messages[-1].content == content
-
-
-async def test_replace_system_message_behavior_builtin_tools():
-    content = "Hello !"
-    system_prompt = "You are a pirate"
-    request = ChatCompletionRequest(
-        model=MODEL,
-        messages=[
-            SystemMessage(content=system_prompt),
-            UserMessage(content=content),
-        ],
-        tools=[
-            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-        ],
-        tool_config=ToolConfig(
-            tool_choice="auto",
-            tool_prompt_format=ToolPromptFormat.python_list,
-            system_message_behavior=SystemMessageBehavior.replace,
-        ),
-    )
-    messages = chat_completion_request_to_messages(request, MODEL3_2)
-    assert len(messages) == 2
-    assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
-    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
-    assert messages[-1].content == content
-
-
-async def test_replace_system_message_behavior_custom_tools():
-    content = "Hello !"
-    system_prompt = "You are a pirate"
-    request = ChatCompletionRequest(
-        model=MODEL,
-        messages=[
-            SystemMessage(content=system_prompt),
-            UserMessage(content=content),
-        ],
-        tools=[
-            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-            ToolDefinition(
-                tool_name="custom1",
-                description="custom1 tool",
-                input_schema={
-                    "type": "object",
-                    "properties": {
-                        "param1": {
-                            "type": "str",
-                            "description": "param1 description",
-                        },
-                    },
-                    "required": ["param1"],
-                },
-            ),
-        ],
-        tool_config=ToolConfig(
-            tool_choice="auto",
-            tool_prompt_format=ToolPromptFormat.python_list,
-            system_message_behavior=SystemMessageBehavior.replace,
-        ),
-    )
-    messages = chat_completion_request_to_messages(request, MODEL3_2)
-
-    assert len(messages) == 2
-    assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
-    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
-    assert messages[-1].content == content
-
-
-async def test_replace_system_message_behavior_custom_tools_with_template():
-    content = "Hello !"
-    system_prompt = "You are a pirate {{ function_description }}"
-    request = ChatCompletionRequest(
-        model=MODEL,
-        messages=[
-            SystemMessage(content=system_prompt),
-            UserMessage(content=content),
-        ],
-        tools=[
-            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-            ToolDefinition(
-                tool_name="custom1",
-                description="custom1 tool",
-                input_schema={
-                    "type": "object",
-                    "properties": {
-                        "param1": {
-                            "type": "str",
-                            "description": "param1 description",
-                        },
-                    },
-                    "required": ["param1"],
-                },
-            ),
-        ],
-        tool_config=ToolConfig(
-            tool_choice="auto",
-            tool_prompt_format=ToolPromptFormat.python_list,
-            system_message_behavior=SystemMessageBehavior.replace,
-        ),
-    )
-    messages = chat_completion_request_to_messages(request, MODEL3_2)
-
-    assert len(messages) == 2
-    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
-    assert "You are a pirate" in interleaved_content_as_str(messages[0].content)
-    # function description is present in the system prompt
-    assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content)
-    assert messages[-1].content == content
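The deleted tests above exercised the legacy prompt-adapter path, which turns a ChatCompletionRequest into model messages. As a reference for what they asserted, here is a minimal sketch of the "replace" system-message behavior; it assumes the legacy symbols are still importable from their pre-removal locations, and the model identifier is a placeholder, not a value taken from the diff.

# Minimal sketch of the legacy "replace" system-message behavior.
# Assumption: the pre-removal module layout (apis.inference, models.llama.datatypes,
# providers.utils.inference.prompt_adapter) is still available.
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    SystemMessage,
    SystemMessageBehavior,
    ToolConfig,
    UserMessage,
)
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
)

request = ChatCompletionRequest(
    model="Llama3.2-3B-Instruct",  # placeholder model ID
    messages=[
        # With system_message_behavior=replace, this prompt replaces the default
        # system prompt rather than being appended to it.
        SystemMessage(content="You are a pirate"),
        UserMessage(content="Hello !"),
    ],
    tools=[ToolDefinition(tool_name=BuiltinTool.code_interpreter)],
    tool_config=ToolConfig(
        tool_choice="auto",
        tool_prompt_format=ToolPromptFormat.python_list,
        system_message_behavior=SystemMessageBehavior.replace,
    ),
)
messages = chat_completion_request_to_messages(request, "Llama3.2-3B-Instruct")
# messages[0] carries the rewritten system prompt (builtin tools still inject
# "Environment: ipython"); messages[-1] is the untouched user turn.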
@@ -10,11 +10,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 
-from llama_stack.apis.inference import CompletionMessage, UserMessage
+from llama_stack.apis.inference import (
+    OpenAIAssistantMessageParam,
+    OpenAIUserMessageParam,
+)
 from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
 from llama_stack.apis.shields import Shield
-from llama_stack.models.llama.datatypes import StopReason
 from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
 from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
 
@@ -136,11 +138,9 @@ async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):
 
     # Run the shield
     messages = [
-        UserMessage(role="user", content="Hello, how are you?"),
-        CompletionMessage(
-            role="assistant",
+        OpenAIUserMessageParam(content="Hello, how are you?"),
+        OpenAIAssistantMessageParam(
             content="I'm doing well, thank you for asking!",
-            stop_reason=StopReason.end_of_message,
             tool_calls=[],
         ),
     ]
@@ -191,13 +191,10 @@ async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):
     # Mock Guardrails API response
     mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
 
-    # Run the shield
     messages = [
-        UserMessage(role="user", content="Hello, how are you?"),
-        CompletionMessage(
-            role="assistant",
+        OpenAIUserMessageParam(content="Hello, how are you?"),
+        OpenAIAssistantMessageParam(
             content="I'm doing well, thank you for asking!",
-            stop_reason=StopReason.end_of_message,
             tool_calls=[],
         ),
     ]
@@ -243,7 +240,7 @@ async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post):
     adapter.shield_store.get_shield.return_value = None
 
     messages = [
-        UserMessage(role="user", content="Hello, how are you?"),
+        OpenAIUserMessageParam(content="Hello, how are you?"),
     ]
 
     with pytest.raises(ValueError):
@@ -274,11 +271,9 @@ async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):
 
     # Running the shield should raise an exception
    messages = [
-        UserMessage(role="user", content="Hello, how are you?"),
-        CompletionMessage(
-            role="assistant",
+        OpenAIUserMessageParam(content="Hello, how are you?"),
+        OpenAIAssistantMessageParam(
             content="I'm doing well, thank you for asking!",
-            stop_reason=StopReason.end_of_message,
             tool_calls=[],
         ),
     ]
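For reference, the updated NVIDIA safety tests now build their input with the OpenAI-style message params instead of the legacy UserMessage/CompletionMessage types. A minimal sketch of that construction, using only fields that appear in the hunks above; the run_shield call at the end is an assumption about the adapter's Safety API and the shield ID is a placeholder.

# Sketch of the new message construction used by the migrated tests.
from llama_stack.apis.inference import (
    OpenAIAssistantMessageParam,
    OpenAIUserMessageParam,
)

messages = [
    # The role is implied by the param class, so no explicit role argument is passed.
    OpenAIUserMessageParam(content="Hello, how are you?"),
    # The assistant message no longer carries stop_reason; tool_calls is passed
    # explicitly here to mirror the updated tests.
    OpenAIAssistantMessageParam(
        content="I'm doing well, thank you for asking!",
        tool_calls=[],
    ),
]

# Hypothetical usage against the adapter under test (signature and shield ID assumed):
# response = await adapter.run_shield(shield_id="self-check", messages=messages)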
@@ -1,220 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import pytest
-from pydantic import ValidationError
-
-from llama_stack.apis.common.content_types import TextContentItem
-from llama_stack.apis.inference import (
-    CompletionMessage,
-    OpenAIAssistantMessageParam,
-    OpenAIChatCompletionContentPartImageParam,
-    OpenAIChatCompletionContentPartTextParam,
-    OpenAIDeveloperMessageParam,
-    OpenAIImageURL,
-    OpenAISystemMessageParam,
-    OpenAIToolMessageParam,
-    OpenAIUserMessageParam,
-    SystemMessage,
-    UserMessage,
-)
-from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
-from llama_stack.providers.utils.inference.openai_compat import (
-    convert_message_to_openai_dict,
-    convert_message_to_openai_dict_new,
-    openai_messages_to_messages,
-)
-
-
-async def test_convert_message_to_openai_dict():
-    message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
-    assert await convert_message_to_openai_dict(message) == {
-        "role": "user",
-        "content": [{"type": "text", "text": "Hello, world!"}],
-    }
-
-
-# Test convert_message_to_openai_dict with a tool call
-async def test_convert_message_to_openai_dict_with_tool_call():
-    message = CompletionMessage(
-        content="",
-        tool_calls=[ToolCall(call_id="123", tool_name="test_tool", arguments='{"foo": "bar"}')],
-        stop_reason=StopReason.end_of_turn,
-    )
-
-    openai_dict = await convert_message_to_openai_dict(message)
-
-    assert openai_dict == {
-        "role": "assistant",
-        "content": [{"type": "text", "text": ""}],
-        "tool_calls": [
-            {"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}}
-        ],
-    }
-
-
-async def test_convert_message_to_openai_dict_with_builtin_tool_call():
-    message = CompletionMessage(
-        content="",
-        tool_calls=[
-            ToolCall(
-                call_id="123",
-                tool_name=BuiltinTool.brave_search,
-                arguments='{"foo": "bar"}',
-            )
-        ],
-        stop_reason=StopReason.end_of_turn,
-    )
-
-    openai_dict = await convert_message_to_openai_dict(message)
-
-    assert openai_dict == {
-        "role": "assistant",
-        "content": [{"type": "text", "text": ""}],
-        "tool_calls": [
-            {"id": "123", "type": "function", "function": {"name": "brave_search", "arguments": '{"foo": "bar"}'}}
-        ],
-    }
-
-
-async def test_openai_messages_to_messages_with_content_str():
-    openai_messages = [
-        OpenAISystemMessageParam(content="system message"),
-        OpenAIUserMessageParam(content="user message"),
-        OpenAIAssistantMessageParam(content="assistant message"),
-    ]
-
-    llama_messages = openai_messages_to_messages(openai_messages)
-    assert len(llama_messages) == 3
-    assert isinstance(llama_messages[0], SystemMessage)
-    assert isinstance(llama_messages[1], UserMessage)
-    assert isinstance(llama_messages[2], CompletionMessage)
-    assert llama_messages[0].content == "system message"
-    assert llama_messages[1].content == "user message"
-    assert llama_messages[2].content == "assistant message"
-
-
-async def test_openai_messages_to_messages_with_content_list():
-    openai_messages = [
-        OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),
-        OpenAIUserMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="user message")]),
-        OpenAIAssistantMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="assistant message")]),
-    ]
-
-    llama_messages = openai_messages_to_messages(openai_messages)
-    assert len(llama_messages) == 3
-    assert isinstance(llama_messages[0], SystemMessage)
-    assert isinstance(llama_messages[1], UserMessage)
-    assert isinstance(llama_messages[2], CompletionMessage)
-    assert llama_messages[0].content[0].text == "system message"
-    assert llama_messages[1].content[0].text == "user message"
-    assert llama_messages[2].content[0].text == "assistant message"
-
-
-@pytest.mark.parametrize(
-    "message_class,kwargs",
-    [
-        (OpenAISystemMessageParam, {}),
-        (OpenAIAssistantMessageParam, {}),
-        (OpenAIDeveloperMessageParam, {}),
-        (OpenAIUserMessageParam, {}),
-        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
-    ],
-)
-def test_message_accepts_text_string(message_class, kwargs):
-    """Test that messages accept string text content."""
-    msg = message_class(content="Test message", **kwargs)
-    assert msg.content == "Test message"
-
-
-@pytest.mark.parametrize(
-    "message_class,kwargs",
-    [
-        (OpenAISystemMessageParam, {}),
-        (OpenAIAssistantMessageParam, {}),
-        (OpenAIDeveloperMessageParam, {}),
-        (OpenAIUserMessageParam, {}),
-        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
-    ],
-)
-def test_message_accepts_text_list(message_class, kwargs):
-    """Test that messages accept list of text content parts."""
-    content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
-    msg = message_class(content=content_list, **kwargs)
-    assert len(msg.content) == 1
-    assert msg.content[0].text == "Test message"
-
-
-@pytest.mark.parametrize(
-    "message_class,kwargs",
-    [
-        (OpenAISystemMessageParam, {}),
-        (OpenAIAssistantMessageParam, {}),
-        (OpenAIDeveloperMessageParam, {}),
-        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
-    ],
-)
-def test_message_rejects_images(message_class, kwargs):
-    """Test that system, assistant, developer, and tool messages reject image content."""
-    with pytest.raises(ValidationError):
-        message_class(
-            content=[
-                OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
-            ],
-            **kwargs,
-        )
-
-
-def test_user_message_accepts_images():
-    """Test that user messages accept image content (unlike other message types)."""
-    # List with images should work
-    msg = OpenAIUserMessageParam(
-        content=[
-            OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
-            OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
-        ]
-    )
-    assert len(msg.content) == 2
-    assert msg.content[0].text == "Describe this image:"
-    assert msg.content[1].image_url.url == "http://example.com/image.jpg"
-
-
-async def test_convert_message_to_openai_dict_new_user_message():
-    """Test convert_message_to_openai_dict_new with UserMessage."""
-    message = UserMessage(content="Hello, world!", role="user")
-    result = await convert_message_to_openai_dict_new(message)
-
-    assert result["role"] == "user"
-    assert result["content"] == "Hello, world!"
-
-
-async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
-    """Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
-    message = CompletionMessage(
-        content="I'll help you find the weather.",
-        tool_calls=[
-            ToolCall(
-                call_id="call_123",
-                tool_name="get_weather",
-                arguments='{"city": "Sligo"}',
-            )
-        ],
-        stop_reason=StopReason.end_of_turn,
-    )
-    result = await convert_message_to_openai_dict_new(message)
-
-    # This would have failed with "Cannot instantiate typing.Union" before the fix
-    assert result["role"] == "assistant"
-    assert result["content"] == "I'll help you find the weather."
-    assert "tool_calls" in result
-    assert result["tool_calls"] is not None
-    assert len(result["tool_calls"]) == 1
-
-    tool_call = result["tool_calls"][0]
-    assert tool_call.id == "call_123"
-    assert tool_call.type == "function"
-    assert tool_call.function.name == "get_weather"
-    assert tool_call.function.arguments == '{"city": "Sligo"}'