Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)

Commit 30a544fb8c — Merge branch 'main' into add-mcp-authentication-param
27 changed files with 599 additions and 2148 deletions
@@ -963,7 +963,7 @@ paths:
           Optional filter to control which routes are returned. Can be an API level
           ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
           or 'deprecated' to show deprecated routes across all levels. If not specified,
-          returns only non-deprecated v1 routes.
+          returns all non-deprecated routes.
         required: false
         schema:
           type: string
docs/static/llama-stack-spec.yaml (vendored) — 2 changes

@@ -960,7 +960,7 @@ paths:
           Optional filter to control which routes are returned. Can be an API level
           ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
           or 'deprecated' to show deprecated routes across all levels. If not specified,
-          returns only non-deprecated v1 routes.
+          returns all non-deprecated routes.
         required: false
         schema:
           type: string
docs/static/stainless-llama-stack-spec.yaml (vendored) — 2 changes

@@ -963,7 +963,7 @@ paths:
           Optional filter to control which routes are returned. Can be an API level
           ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
           or 'deprecated' to show deprecated routes across all levels. If not specified,
-          returns only non-deprecated v1 routes.
+          returns all non-deprecated routes.
         required: false
         schema:
           type: string
@@ -1,43 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from termcolor import cprint
-
-from llama_stack.apis.inference import (
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-)
-
-
-class LogEvent:
-    def __init__(
-        self,
-        content: str = "",
-        end: str = "\n",
-        color="white",
-    ):
-        self.content = content
-        self.color = color
-        self.end = "\n" if end is None else end
-
-    def print(self, flush=True):
-        cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)
-
-
-class EventLogger:
-    async def log(self, event_generator):
-        async for chunk in event_generator:
-            if isinstance(chunk, ChatCompletionResponseStreamChunk):
-                event = chunk.event
-                if event.event_type == ChatCompletionResponseEventType.start:
-                    yield LogEvent("Assistant> ", color="cyan", end="")
-                elif event.event_type == ChatCompletionResponseEventType.progress:
-                    yield LogEvent(event.delta, color="yellow", end="")
-                elif event.event_type == ChatCompletionResponseEventType.complete:
-                    yield LogEvent("")
-            else:
-                yield LogEvent("Assistant> ", color="cyan", end="")
-                yield LogEvent(chunk.completion_message.content, color="yellow")
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from collections.abc import AsyncIterator
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import (
     Annotated,
     Any,
@@ -15,28 +15,18 @@ from typing import (
 )

 from fastapi import Body
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field
 from typing_extensions import TypedDict

-from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
-from llama_stack.apis.common.responses import MetricResponseMixin, Order
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.responses import (
+    Order,
+)
 from llama_stack.apis.common.tracing import telemetry_traceable
 from llama_stack.apis.models import Model
 from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
-from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
-    StopReason,
-    ToolCall,
-    ToolDefinition,
-    ToolPromptFormat,
-)
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

-register_schema(ToolCall)
-register_schema(ToolDefinition)
-
-from enum import StrEnum
-

 @json_schema_type
 class GreedySamplingStrategy(BaseModel):
@@ -201,58 +191,6 @@ class ToolResponseMessage(BaseModel):
     content: InterleavedContent


-@json_schema_type
-class CompletionMessage(BaseModel):
-    """A message containing the model's (assistant) response in a chat conversation.
-
-    :param role: Must be "assistant" to identify this as the model's response
-    :param content: The content of the model's response
-    :param stop_reason: Reason why the model stopped generating. Options are:
-        - `StopReason.end_of_turn`: The model finished generating the entire response.
-        - `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response.
-        - `StopReason.out_of_tokens`: The model ran out of token budget.
-    :param tool_calls: List of tool calls. Each tool call is a ToolCall object.
-    """
-
-    role: Literal["assistant"] = "assistant"
-    content: InterleavedContent
-    stop_reason: StopReason
-    tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])
-
-
-Message = Annotated[
-    UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
-    Field(discriminator="role"),
-]
-register_schema(Message, name="Message")
-
-
-@json_schema_type
-class ToolResponse(BaseModel):
-    """Response from a tool invocation.
-
-    :param call_id: Unique identifier for the tool call this response is for
-    :param tool_name: Name of the tool that was invoked
-    :param content: The response content from the tool
-    :param metadata: (Optional) Additional metadata about the tool response
-    """
-
-    call_id: str
-    tool_name: BuiltinTool | str
-    content: InterleavedContent
-    metadata: dict[str, Any] | None = None
-
-    @field_validator("tool_name", mode="before")
-    @classmethod
-    def validate_field(cls, v):
-        if isinstance(v, str):
-            try:
-                return BuiltinTool(v)
-            except ValueError:
-                return v
-        return v
-
-
 class ToolChoice(Enum):
     """Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.
@@ -289,22 +227,6 @@ class ChatCompletionResponseEventType(Enum):
     progress = "progress"


-@json_schema_type
-class ChatCompletionResponseEvent(BaseModel):
-    """An event during chat completion generation.
-
-    :param event_type: Type of the event
-    :param delta: Content generated since last event. This can be one or more tokens, or a tool call.
-    :param logprobs: Optional log probabilities for generated tokens
-    :param stop_reason: Optional reason why generation stopped, if complete
-    """
-
-    event_type: ChatCompletionResponseEventType
-    delta: ContentDelta
-    logprobs: list[TokenLogProbs] | None = None
-    stop_reason: StopReason | None = None
-
-
 class ResponseFormatType(StrEnum):
     """Types of formats for structured (guided) decoding.
@@ -357,34 +279,6 @@ class CompletionRequest(BaseModel):
     logprobs: LogProbConfig | None = None


-@json_schema_type
-class CompletionResponse(MetricResponseMixin):
-    """Response from a completion request.
-
-    :param content: The generated completion text
-    :param stop_reason: Reason why generation stopped
-    :param logprobs: Optional log probabilities for generated tokens
-    """
-
-    content: str
-    stop_reason: StopReason
-    logprobs: list[TokenLogProbs] | None = None
-
-
-@json_schema_type
-class CompletionResponseStreamChunk(MetricResponseMixin):
-    """A chunk of a streamed completion response.
-
-    :param delta: New content generated since last chunk. This can be one or more tokens.
-    :param stop_reason: Optional reason why generation stopped, if complete
-    :param logprobs: Optional log probabilities for generated tokens
-    """
-
-    delta: str
-    stop_reason: StopReason | None = None
-    logprobs: list[TokenLogProbs] | None = None
-
-
 class SystemMessageBehavior(Enum):
     """Config for how to override the default system prompt.
@@ -398,70 +292,6 @@ class SystemMessageBehavior(Enum):
     replace = "replace"


-@json_schema_type
-class ToolConfig(BaseModel):
-    """Configuration for tool use.
-
-    :param tool_choice: (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
-    :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
-        - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
-        - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
-        - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
-    :param system_message_behavior: (Optional) Config for how to override the default system prompt.
-        - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt.
-        - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string
-          '{{function_definitions}}' to indicate where the function definitions should be inserted.
-    """
-
-    tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
-    tool_prompt_format: ToolPromptFormat | None = Field(default=None)
-    system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
-
-    def model_post_init(self, __context: Any) -> None:
-        if isinstance(self.tool_choice, str):
-            try:
-                self.tool_choice = ToolChoice[self.tool_choice]
-            except KeyError:
-                pass
-
-
-# This is an internally used class
-@json_schema_type
-class ChatCompletionRequest(BaseModel):
-    model: str
-    messages: list[Message]
-    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
-
-    tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
-    tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
-
-    response_format: ResponseFormat | None = None
-    stream: bool | None = False
-    logprobs: LogProbConfig | None = None
-
-
-@json_schema_type
-class ChatCompletionResponseStreamChunk(MetricResponseMixin):
-    """A chunk of a streamed chat completion response.
-
-    :param event: The event containing the new content
-    """
-
-    event: ChatCompletionResponseEvent
-
-
-@json_schema_type
-class ChatCompletionResponse(MetricResponseMixin):
-    """Response from a chat completion request.
-
-    :param completion_message: The complete response message
-    :param logprobs: Optional log probabilities for generated tokens
-    """
-
-    completion_message: CompletionMessage
-    logprobs: list[TokenLogProbs] | None = None
-
-
 @json_schema_type
 class EmbeddingsResponse(BaseModel):
     """Response containing generated embeddings.
@@ -76,7 +76,7 @@ class Inspect(Protocol):

         List all available API routes with their methods and implementing providers.

-        :param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.
+        :param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns all non-deprecated routes.
         :returns: Response containing information about all available routes.
         """
         ...
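A minimal sketch of exercising the new default behavior against a running Llama Stack server. The /v1/inspect/routes path, the default port, and the shape of the JSON response are assumptions here, not something this diff guarantees.

```python
import requests

BASE_URL = "http://localhost:8321"  # assumed default server address

# No filter: with this change, all non-deprecated routes (v1, v1alpha, v1beta) come back.
all_routes = requests.get(f"{BASE_URL}/v1/inspect/routes").json()

# Explicit filters still narrow the listing to one level, or to deprecated routes only.
v1_routes = requests.get(f"{BASE_URL}/v1/inspect/routes", params={"api_filter": "v1"}).json()
deprecated = requests.get(f"{BASE_URL}/v1/inspect/routes", params={"api_filter": "deprecated"}).json()
```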
@@ -15,7 +15,6 @@ from llama_stack.apis.inspect import (
     RouteInfo,
     VersionInfo,
 )
-from llama_stack.apis.version import LLAMA_STACK_API_V1
 from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.core.external import load_external_apis
 from llama_stack.core.server.routes import get_all_api_routes
@@ -46,8 +45,8 @@ class DistributionInspectImpl(Inspect):
         # Helper function to determine if a route should be included based on api_filter
         def should_include_route(webmethod) -> bool:
             if api_filter is None:
-                # Default: only non-deprecated v1 APIs
-                return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1
+                # Default: only non-deprecated APIs
+                return not webmethod.deprecated
             elif api_filter == "deprecated":
                 # Special filter: show deprecated routes regardless of their actual level
                 return bool(webmethod.deprecated)
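For clarity, here is the filtering rule written as a standalone predicate so its behavior is easy to check. WebMethod below is a stand-in dataclass, not the real llama_stack type, and the level-match branch is reconstructed from the surrounding code rather than shown in this hunk.

```python
from dataclasses import dataclass


@dataclass
class WebMethod:
    level: str
    deprecated: bool = False


def should_include_route(webmethod: WebMethod, api_filter: str | None) -> bool:
    if api_filter is None:
        # New default: every non-deprecated route, regardless of API level.
        return not webmethod.deprecated
    if api_filter == "deprecated":
        # Deprecated routes across all levels.
        return bool(webmethod.deprecated)
    # Otherwise treat the filter as an API level such as "v1" or "v1beta".
    return not webmethod.deprecated and webmethod.level == api_filter


assert should_include_route(WebMethod(level="v1alpha"), api_filter=None)
assert not should_include_route(WebMethod(level="v1", deprecated=True), api_filter=None)
assert should_include_route(WebMethod(level="v1", deprecated=True), api_filter="deprecated")
```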
@@ -6,7 +6,7 @@

 from typing import Any

-from llama_stack.apis.inference import Message
+from llama_stack.apis.inference import OpenAIMessageParam
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.safety.safety import ModerationObject
 from llama_stack.apis.shields import Shield
@@ -52,7 +52,7 @@ class SafetyRouter(Safety):
     async def run_shield(
         self,
         shield_id: str,
-        messages: list[Message],
+        messages: list[OpenAIMessageParam],
         params: dict[str, Any] = None,
     ) -> RunShieldResponse:
         logger.debug(f"SafetyRouter.run_shield: {shield_id}")
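A hedged sketch of calling the updated run_shield signature with OpenAI-style messages. `safety` stands in for a SafetyRouter or Safety client obtained from a running stack, and "llama-guard" is a placeholder shield id.

```python
from llama_stack.apis.inference import OpenAIUserMessageParam


async def is_allowed(safety, text: str) -> bool:
    response = await safety.run_shield(
        shield_id="llama-guard",  # hypothetical shield identifier
        messages=[OpenAIUserMessageParam(role="user", content=text)],
        params={},
    )
    # RunShieldResponse carries an optional violation when the shield flags the input.
    return response.violation is None
```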
@@ -26,8 +26,10 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from termcolor import cprint

+from llama_stack.models.llama.datatypes import ToolPromptFormat
+
 from ..checkpoint import maybe_reshard_state_dict
-from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
+from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage
 from .args import ModelArgs
 from .chat_format import ChatFormat, LLMInput
 from .model import Transformer
@@ -15,13 +15,10 @@ from pathlib import Path

 from termcolor import colored

+from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall, ToolDefinition, ToolPromptFormat
+
 from ..datatypes import (
-    BuiltinTool,
     RawMessage,
-    StopReason,
-    ToolCall,
-    ToolDefinition,
-    ToolPromptFormat,
 )
 from . import template_data
 from .chat_format import ChatFormat
@@ -15,7 +15,7 @@ import textwrap
 from datetime import datetime
 from typing import Any

-from llama_stack.apis.inference import (
+from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     ToolDefinition,
 )
@@ -8,8 +8,9 @@ import json
 import re

 from llama_stack.log import get_logger
+from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolPromptFormat

-from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
+from ..datatypes import RecursiveType

 logger = get_logger(name=__name__, category="models::llama")
@@ -13,7 +13,7 @@

 import textwrap

-from llama_stack.apis.inference import ToolDefinition
+from llama_stack.models.llama.datatypes import ToolDefinition
 from llama_stack.models.llama.llama3.prompt_templates.base import (
     PromptTemplate,
     PromptTemplateGeneratorBase,
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import math
-from collections.abc import Generator
 from typing import Optional

 import torch
@@ -14,21 +13,19 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
 from llama_stack.apis.inference import (
     GreedySamplingStrategy,
     JsonSchemaResponseFormat,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIResponseFormatJSONSchema,
     ResponseFormat,
+    ResponseFormatType,
     SamplingParams,
     TopPSamplingStrategy,
 )
-from llama_stack.models.llama.datatypes import QuantizationMode
+from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat
 from llama_stack.models.llama.llama3.generation import Llama3
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
 from llama_stack.models.llama.llama4.generation import Llama4
 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
 from llama_stack.models.llama.sku_types import Model, ModelFamily
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    ChatCompletionRequestWithRawContent,
-    CompletionRequestWithRawContent,
-    get_default_tool_prompt_format,
-)

 from .common import model_checkpoint_dir
 from .config import MetaReferenceInferenceConfig
@@ -106,14 +103,6 @@ def _infer_sampling_params(sampling_params: SamplingParams):
     return temperature, top_p


-def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
-    tool_config = request.tool_config
-    if tool_config is not None and tool_config.tool_prompt_format is not None:
-        return tool_config.tool_prompt_format
-    else:
-        return get_default_tool_prompt_format(request.model)
-
-
 class LlamaGenerator:
     def __init__(
         self,
@@ -157,55 +146,56 @@ class LlamaGenerator:
         self.args = self.inner_generator.args
         self.formatter = self.inner_generator.formatter

-    def completion(
-        self,
-        request_batch: list[CompletionRequestWithRawContent],
-    ) -> Generator:
-        first_request = request_batch[0]
-        sampling_params = first_request.sampling_params or SamplingParams()
-        max_gen_len = sampling_params.max_tokens
-        if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
-            max_gen_len = self.args.max_seq_len - 1
-
-        temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
-            max_gen_len=max_gen_len,
-            temperature=temperature,
-            top_p=top_p,
-            logprobs=bool(first_request.logprobs),
-            echo=False,
-            logits_processor=get_logits_processor(
-                self.tokenizer,
-                self.args.vocab_size,
-                first_request.response_format,
-            ),
-        )
-
     def chat_completion(
         self,
-        request_batch: list[ChatCompletionRequestWithRawContent],
-    ) -> Generator:
-        first_request = request_batch[0]
-        sampling_params = first_request.sampling_params or SamplingParams()
+        request: OpenAIChatCompletionRequestWithExtraBody,
+        raw_messages: list,
+    ):
+        """Generate chat completion using OpenAI request format.
+
+        Args:
+            request: OpenAI chat completion request
+            raw_messages: Pre-converted list of RawMessage objects
+        """
+        # Determine tool prompt format
+        tool_prompt_format = ToolPromptFormat.json if request.tools else ToolPromptFormat.json
+
+        # Prepare sampling params
+        sampling_params = SamplingParams()
+        if request.temperature is not None or request.top_p is not None:
+            sampling_params.strategy = TopPSamplingStrategy(
+                temperature=request.temperature if request.temperature is not None else 1.0,
+                top_p=request.top_p if request.top_p is not None else 1.0,
+            )
+        if request.max_tokens:
+            sampling_params.max_tokens = request.max_tokens
+
         max_gen_len = sampling_params.max_tokens
         if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
             max_gen_len = self.args.max_seq_len - 1

         temperature, top_p = _infer_sampling_params(sampling_params)
+
+        # Get logits processor for response format
+        logits_processor = None
+        if request.response_format:
+            if isinstance(request.response_format, OpenAIResponseFormatJSONSchema):
+                # Extract the actual schema from OpenAIJSONSchema TypedDict
+                schema_dict = request.response_format.json_schema.get("schema") or {}
+                json_schema_format = JsonSchemaResponseFormat(
+                    type=ResponseFormatType.json_schema,
+                    json_schema=schema_dict,
+                )
+                logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format)
+
+        # Generate
         yield from self.inner_generator.generate(
-            llm_inputs=[
-                self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
-                for request in request_batch
-            ],
+            llm_inputs=[self.formatter.encode_dialog_prompt(raw_messages, tool_prompt_format)],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
-            logprobs=bool(first_request.logprobs),
+            logprobs=False,
             echo=False,
-            logits_processor=get_logits_processor(
-                self.tokenizer,
-                self.args.vocab_size,
-                first_request.response_format,
-            ),
+            logits_processor=logits_processor,
         )
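For context, a minimal sketch of how the rewritten method is consumed, mirroring the provider code later in this diff: the caller passes the OpenAI-format request plus pre-converted RawMessage objects and concatenates the yielded output tokens. `generator`, `params`, and `raw_messages` are assumed to already exist.

```python
# Assumes `generator` is a LlamaGenerator and `params` / `raw_messages` were built
# by the meta-reference provider (see the inference.py changes further down).
generated_text = ""
for result_batch in generator.chat_completion(params, raw_messages):
    for result in result_batch:
        # GenerationResult exposes ignore_token / source / text, as used by the provider.
        if not result.ignore_token and result.source == "output":
            generated_text += result.text
print(generated_text)
```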
@@ -5,12 +5,19 @@
 # the root directory of this source tree.

 import asyncio
+import time
+import uuid
 from collections.abc import AsyncIterator

 from llama_stack.apis.inference import (
     InferenceProvider,
+    OpenAIAssistantMessageParam,
     OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIChatCompletionUsage,
+    OpenAIChoice,
     OpenAICompletionRequestWithExtraBody,
+    OpenAIUserMessageParam,
+    ToolChoice,
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
@@ -19,12 +26,20 @@ from llama_stack.apis.inference.inference import (
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.log import get_logger
+from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition
 from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
+from llama_stack.models.llama.llama3.prompt_templates import (
+    JsonCustomToolGenerator,
+    SystemDefaultGenerator,
+)
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
 from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
+from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
+    PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
+)
 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.models.llama.sku_types import ModelFamily
+from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
@@ -44,6 +59,170 @@ log = get_logger(__name__, category="inference")
 SEMAPHORE = asyncio.Semaphore(1)


+def _convert_openai_tool_to_tool_definition(tool) -> ToolDefinition:
+    """Convert OpenAI tool format to ToolDefinition format."""
+    # OpenAI tools have function.name and function.parameters
+    return ToolDefinition(
+        tool_name=tool.function.name,
+        description=tool.function.description or "",
+        parameters=tool.function.parameters or {},
+    )
+
+
+def _get_tool_choice_prompt(tool_choice, tools) -> str:
+    """Generate prompt text for tool_choice behavior."""
+    if not tool_choice or tool_choice == ToolChoice.auto or tool_choice == "auto":
+        return ""
+    elif tool_choice == ToolChoice.required or tool_choice == "required":
+        return "You MUST use one of the provided functions/tools to answer the user query."
+    elif tool_choice == ToolChoice.none or tool_choice == "none":
+        return ""
+    else:
+        # Specific tool specified
+        return f"You MUST use the tool `{tool_choice}` to answer the user query."
+
+
+def _raw_content_as_str(content) -> str:
+    """Convert RawContent to string for system messages."""
+    if isinstance(content, str):
+        return content
+    elif isinstance(content, RawTextItem):
+        return content.text
+    elif isinstance(content, list):
+        return "\n".join(_raw_content_as_str(c) for c in content)
+    else:
+        return "<media>"
+
+
+def _augment_raw_messages_for_tools_llama_3_1(
+    raw_messages: list[RawMessage],
+    tools: list,
+    tool_choice,
+) -> list[RawMessage]:
+    """Augment raw messages with tool definitions for Llama 3.1 style models."""
+    messages = raw_messages.copy()
+    existing_system_message = None
+    if messages and messages[0].role == "system":
+        existing_system_message = messages.pop(0)
+
+    sys_content = ""
+
+    # Add tool definitions first (if present)
+    if tools:
+        # Convert OpenAI tools to ToolDefinitions
+        tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
+
+        # For OpenAI format, all tools are custom (have string names)
+        tool_gen = JsonCustomToolGenerator()
+        tool_template = tool_gen.gen(tool_definitions)
+        sys_content += tool_template.render()
+        sys_content += "\n"
+
+    # Add default system prompt
+    default_gen = SystemDefaultGenerator()
+    default_template = default_gen.gen()
+    sys_content += default_template.render()
+
+    # Add existing system message if present
+    if existing_system_message:
+        sys_content += "\n" + _raw_content_as_str(existing_system_message.content)
+
+    # Add tool choice prompt if needed
+    if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
+        sys_content += "\n" + tool_choice_prompt
+
+    # Create new system message
+    new_system_message = RawMessage(
+        role="system",
+        content=[RawTextItem(text=sys_content.strip())],
+    )
+
+    return [new_system_message] + messages
+
+
+def _augment_raw_messages_for_tools_llama_4(
+    raw_messages: list[RawMessage],
+    tools: list,
+    tool_choice,
+) -> list[RawMessage]:
+    """Augment raw messages with tool definitions for Llama 4/3.2/3.3 style models."""
+    messages = raw_messages.copy()
+    existing_system_message = None
+    if messages and messages[0].role == "system":
+        existing_system_message = messages.pop(0)
+
+    sys_content = ""
+
+    # Add tool definitions if present
+    if tools:
+        # Convert OpenAI tools to ToolDefinitions
+        tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
+
+        # Use python_list format for Llama 4
+        tool_gen = PythonListCustomToolGeneratorLlama4()
+        system_prompt = None
+        if existing_system_message:
+            system_prompt = _raw_content_as_str(existing_system_message.content)
+
+        tool_template = tool_gen.gen(tool_definitions, system_prompt)
+        sys_content = tool_template.render()
+    elif existing_system_message:
+        # No tools, just use existing system message
+        sys_content = _raw_content_as_str(existing_system_message.content)
+
+    # Add tool choice prompt if needed
+    if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
+        sys_content += "\n" + tool_choice_prompt
+
+    if sys_content:
+        new_system_message = RawMessage(
+            role="system",
+            content=[RawTextItem(text=sys_content.strip())],
+        )
+        return [new_system_message] + messages
+
+    return messages
+
+
+def augment_raw_messages_for_tools(
+    raw_messages: list[RawMessage],
+    params: OpenAIChatCompletionRequestWithExtraBody,
+    llama_model,
+) -> list[RawMessage]:
+    """Augment raw messages with tool definitions based on model family."""
+    if not params.tools:
+        return raw_messages
+
+    # Determine augmentation strategy based on model family
+    if llama_model.model_family == ModelFamily.llama3_1 or (
+        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
+    ):
+        # Llama 3.1 and Llama 3.2 multimodal use JSON format
+        return _augment_raw_messages_for_tools_llama_3_1(
+            raw_messages,
+            params.tools,
+            params.tool_choice,
+        )
+    elif llama_model.model_family in (
+        ModelFamily.llama3_2,
+        ModelFamily.llama3_3,
+        ModelFamily.llama4,
+    ):
+        # Llama 3.2/3.3/4 use python_list format
+        return _augment_raw_messages_for_tools_llama_4(
+            raw_messages,
+            params.tools,
+            params.tool_choice,
+        )
+    else:
+        # Default to Llama 3.1 style
+        return _augment_raw_messages_for_tools_llama_3_1(
+            raw_messages,
+            params.tools,
+            params.tool_choice,
+        )
+
+
 def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
     return LlamaGenerator(config, model_id, llama_model)
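A small illustration of the tool conversion step above, using SimpleNamespace stand-ins for the OpenAI tool object just to show the field mapping; the import path for the private helper is a guess at where the meta-reference provider module lives and may need adjusting.

```python
from types import SimpleNamespace

# Hypothetical import path for the private helper defined in the hunk above.
from llama_stack.providers.inline.inference.meta_reference.inference import (
    _convert_openai_tool_to_tool_definition,
)

openai_tool = SimpleNamespace(
    type="function",
    function=SimpleNamespace(
        name="get_weather",
        description="Get the current weather for a city",
        parameters={
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    ),
)

tool_def = _convert_openai_tool_to_tool_definition(openai_tool)
# tool_def.tool_name == "get_weather"; description and parameters carry over
# unchanged, and augment_raw_messages_for_tools() then renders them into the
# system prompt in the format the target model family expects.
```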
@@ -136,10 +315,13 @@ class MetaReferenceInferenceImpl(
         self.llama_model = llama_model

         log.info("Warming up...")

         await self.openai_chat_completion(
-            model=model_id,
-            messages=[{"role": "user", "content": "Hi how are you?"}],
-            max_tokens=20,
+            params=OpenAIChatCompletionRequestWithExtraBody(
+                model=model_id,
+                messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")],
+                max_tokens=20,
+            )
         )
         log.info("Warmed up!")
@@ -155,4 +337,207 @@ class MetaReferenceInferenceImpl(
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
+        self.check_model(params)
+
+        # Convert OpenAI messages to RawMessages
+        from llama_stack.models.llama.datatypes import StopReason
+        from llama_stack.providers.utils.inference.prompt_adapter import (
+            convert_openai_message_to_raw_message,
+            decode_assistant_message,
+        )
+
+        raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages]
+
+        # Augment messages with tool definitions if tools are present
+        raw_messages = augment_raw_messages_for_tools(raw_messages, params, self.llama_model)
+
+        # Call generator's chat_completion method (works for both single-GPU and model-parallel)
+        if isinstance(self.generator, LlamaGenerator):
+            generator = self.generator.chat_completion(params, raw_messages)
+        else:
+            # Model parallel: submit task to process group
+            generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages]))
+
+        # Check if streaming is requested
+        if params.stream:
+            return self._stream_chat_completion(generator, params)
+
+        # Non-streaming: collect all generated text
+        generated_text = ""
+        for result_batch in generator:
+            for result in result_batch:
+                if not result.ignore_token and result.source == "output":
+                    generated_text += result.text
+
+        # Decode assistant message to extract tool calls and determine stop_reason
+        # Default to end_of_turn if generation completed normally
+        decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
+
+        # Convert tool calls to OpenAI format
+        openai_tool_calls = None
+        if decoded_message.tool_calls:
+            from llama_stack.apis.inference import (
+                OpenAIChatCompletionToolCall,
+                OpenAIChatCompletionToolCallFunction,
+            )
+
+            openai_tool_calls = [
+                OpenAIChatCompletionToolCall(
+                    # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
+                    id=f"call_{uuid.uuid4().hex[:24]}",
+                    type="function",
+                    function=OpenAIChatCompletionToolCallFunction(
+                        name=str(tc.tool_name),
+                        arguments=tc.arguments,
+                    ),
+                )
+                for tc in decoded_message.tool_calls
+            ]
+
+        # Determine finish_reason based on whether tool calls are present
+        finish_reason = "tool_calls" if openai_tool_calls else "stop"
+
+        # Extract content from decoded message
+        content = ""
+        if isinstance(decoded_message.content, str):
+            content = decoded_message.content
+        elif isinstance(decoded_message.content, list):
+            for item in decoded_message.content:
+                if isinstance(item, RawTextItem):
+                    content += item.text
+
+        # Create OpenAI response
+        # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
+        response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
+        created = int(time.time())
+
+        return OpenAIChatCompletion(
+            id=response_id,
+            object="chat.completion",
+            created=created,
+            model=params.model,
+            choices=[
+                OpenAIChoice(
+                    index=0,
+                    message=OpenAIAssistantMessageParam(
+                        role="assistant",
+                        content=content,
+                        tool_calls=openai_tool_calls,
+                    ),
+                    finish_reason=finish_reason,
+                    logprobs=None,
+                )
+            ],
+            usage=OpenAIChatCompletionUsage(
+                prompt_tokens=0,  # TODO: calculate properly
+                completion_tokens=0,  # TODO: calculate properly
+                total_tokens=0,  # TODO: calculate properly
+            ),
+        )
+
+    async def _stream_chat_completion(
+        self,
+        generator,
+        params: OpenAIChatCompletionRequestWithExtraBody,
+    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
+        """Stream chat completion chunks as they're generated."""
+        from llama_stack.apis.inference import (
+            OpenAIChatCompletionChunk,
+            OpenAIChatCompletionToolCall,
+            OpenAIChatCompletionToolCallFunction,
+            OpenAIChoiceDelta,
+            OpenAIChunkChoice,
+        )
+        from llama_stack.models.llama.datatypes import StopReason
+        from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
+
+        response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
+        created = int(time.time())
+        generated_text = ""
+
+        # Yield chunks as tokens are generated
+        for result_batch in generator:
+            for result in result_batch:
+                if result.ignore_token or result.source != "output":
+                    continue
+
+                generated_text += result.text
+
+                # Yield delta chunk with the new text
+                chunk = OpenAIChatCompletionChunk(
+                    id=response_id,
+                    object="chat.completion.chunk",
+                    created=created,
+                    model=params.model,
+                    choices=[
+                        OpenAIChunkChoice(
+                            index=0,
+                            delta=OpenAIChoiceDelta(
+                                role="assistant",
+                                content=result.text,
+                            ),
+                            finish_reason="",
+                            logprobs=None,
+                        )
+                    ],
+                )
+                yield chunk
+
+        # After generation completes, decode the full message to extract tool calls
+        decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
+
+        # If tool calls are present, yield a final chunk with tool_calls
+        if decoded_message.tool_calls:
+            openai_tool_calls = [
+                OpenAIChatCompletionToolCall(
+                    # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
+                    id=f"call_{uuid.uuid4().hex[:24]}",
+                    type="function",
+                    function=OpenAIChatCompletionToolCallFunction(
+                        name=str(tc.tool_name),
+                        arguments=tc.arguments,
+                    ),
+                )
+                for tc in decoded_message.tool_calls
+            ]
+
+            # Yield chunk with tool_calls
+            chunk = OpenAIChatCompletionChunk(
+                id=response_id,
+                object="chat.completion.chunk",
+                created=created,
+                model=params.model,
+                choices=[
+                    OpenAIChunkChoice(
+                        index=0,
+                        delta=OpenAIChoiceDelta(
+                            role="assistant",
+                            tool_calls=openai_tool_calls,
+                        ),
+                        finish_reason="",
+                        logprobs=None,
+                    )
+                ],
+            )
+            yield chunk
+
+            finish_reason = "tool_calls"
+        else:
+            finish_reason = "stop"
+
+        # Yield final chunk with finish_reason
+        final_chunk = OpenAIChatCompletionChunk(
+            id=response_id,
+            object="chat.completion.chunk",
+            created=created,
+            model=params.model,
+            choices=[
+                OpenAIChunkChoice(
+                    index=0,
+                    delta=OpenAIChoiceDelta(),
+                    finish_reason=finish_reason,
+                    logprobs=None,
+                )
+            ],
+        )
+        yield final_chunk
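A hedged end-to-end sketch: once the meta-reference provider implements openai_chat_completion, a standard OpenAI client can stream from it. The base_url below assumes Llama Stack's OpenAI-compatible prefix and default port, and the model id is a placeholder; adjust all of these to your deployment.

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

stream = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    messages=[{"role": "user", "content": "Hi how are you?"}],
    max_tokens=64,
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)
```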
@@ -4,17 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from collections.abc import Callable, Generator
-from copy import deepcopy
+from collections.abc import Callable
 from functools import partial
 from typing import Any

 from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
 from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    ChatCompletionRequestWithRawContent,
-    CompletionRequestWithRawContent,
-)

 from .parallel_utils import ModelParallelProcessGroup
@@ -23,12 +18,14 @@ class ModelRunner:
     def __init__(self, llama):
         self.llama = llama

-    # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
     def __call__(self, task: Any):
-        if task[0] == "chat_completion":
-            return self.llama.chat_completion(task[1])
+        task_type = task[0]
+        if task_type == "chat_completion":
+            # task[1] is [params, raw_messages]
+            params, raw_messages = task[1]
+            return self.llama.chat_completion(params, raw_messages)
         else:
-            raise ValueError(f"Unexpected task type {task[0]}")
+            raise ValueError(f"Unexpected task type {task_type}")


 def init_model_cb(
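A sketch of the task shape the ModelRunner now expects from the model-parallel process group: a ("chat_completion", [params, raw_messages]) tuple, matching the provider-side call shown earlier. `group`, `params`, and `raw_messages` are assumed to already exist in the caller.

```python
# Submit a chat-completion task to the process group and drain the results.
task = ("chat_completion", [params, raw_messages])
for result_batch in group.run_inference(task):
    for result in result_batch:
        if not result.ignore_token and result.source == "output":
            print(result.text, end="", flush=True)
```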
@@ -78,19 +75,3 @@ class LlamaModelParallelGenerator:

     def __exit__(self, exc_type, exc_value, exc_traceback):
         self.group.stop()
-
-    def completion(
-        self,
-        request_batch: list[CompletionRequestWithRawContent],
-    ) -> Generator:
-        req_obj = deepcopy(request_batch)
-        gen = self.group.run_inference(("completion", req_obj))
-        yield from gen
-
-    def chat_completion(
-        self,
-        request_batch: list[ChatCompletionRequestWithRawContent],
-    ) -> Generator:
-        req_obj = deepcopy(request_batch)
-        gen = self.group.run_inference(("chat_completion", req_obj))
-        yield from gen
@@ -33,10 +33,6 @@ from torch.distributed.launcher.api import LaunchConfig, elastic_launch

 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import GenerationResult
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    ChatCompletionRequestWithRawContent,
-    CompletionRequestWithRawContent,
-)

 log = get_logger(name=__name__, category="inference")
@@ -69,10 +65,7 @@ class CancelSentinel(BaseModel):

 class TaskRequest(BaseModel):
     type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
-    task: tuple[
-        str,
-        list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
-    ]
+    task: tuple[str, list]


 class TaskResponse(BaseModel):
@@ -328,10 +321,7 @@ class ModelParallelProcessGroup:

     def run_inference(
         self,
-        req: tuple[
-            str,
-            list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
-        ],
+        req: tuple[str, list],
     ) -> Generator:
         assert not self.running, "inference already running"
@@ -22,9 +22,6 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
-from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
-)

 from .config import SentenceTransformersInferenceConfig
@@ -32,7 +29,6 @@ log = get_logger(name=__name__, category="inference")


 class SentenceTransformersInferenceImpl(
-    OpenAIChatCompletionToLlamaStackMixin,
     SentenceTransformerEmbeddingMixin,
     InferenceProvider,
     ModelsProtocolPrivate,
@@ -11,9 +11,7 @@ from collections.abc import AsyncIterator
 import litellm

 from llama_stack.apis.inference import (
-    ChatCompletionRequest,
     InferenceProvider,
-    JsonSchemaResponseFormat,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAIChatCompletionRequestWithExtraBody,
@@ -23,15 +21,11 @@ from llama_stack.apis.inference import (
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
-    ToolChoice,
 )
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
 from llama_stack.providers.utils.inference.openai_compat import (
-    convert_message_to_openai_dict_new,
-    convert_tooldef_to_openai_tool,
-    get_sampling_options,
     prepare_openai_completion_params,
 )
@ -127,51 +121,6 @@ class LiteLLMOpenAIMixin(
|
||||||
|
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
input_dict: dict[str, Any] = {}
|
|
||||||
|
|
||||||
input_dict["messages"] = [
|
|
||||||
await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
|
|
||||||
]
|
|
||||||
if fmt := request.response_format:
|
|
||||||
if not isinstance(fmt, JsonSchemaResponseFormat):
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert to dict for manipulation
|
|
||||||
fmt_dict = dict(fmt.json_schema)
|
|
||||||
name = fmt_dict["title"]
|
|
||||||
del fmt_dict["title"]
|
|
||||||
fmt_dict["additionalProperties"] = False
|
|
||||||
|
|
||||||
# Apply additionalProperties: False recursively to all objects
|
|
||||||
fmt_dict = self._add_additional_properties_recursive(fmt_dict)
|
|
||||||
|
|
||||||
input_dict["response_format"] = {
|
|
||||||
"type": "json_schema",
|
|
||||||
"json_schema": {
|
|
||||||
"name": name,
|
|
||||||
"schema": fmt_dict,
|
|
||||||
"strict": self.json_schema_strict,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if request.tools:
|
|
||||||
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
|
|
||||||
if request.tool_config and (tool_choice := request.tool_config.tool_choice):
|
|
||||||
input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice
|
|
||||||
|
|
||||||
return {
|
|
||||||
"model": request.model,
|
|
||||||
"api_key": self.get_api_key(),
|
|
||||||
"api_base": self.api_base,
|
|
||||||
**input_dict,
|
|
||||||
"stream": request.stream,
|
|
||||||
**get_sampling_options(request.sampling_params),
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_api_key(self) -> str:
|
def get_api_key(self) -> str:
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
key_field = self.provider_data_api_key_field
|
key_field = self.provider_data_api_key_field
|
||||||
|
|
|
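The `_get_params` helper removed above relied on `self._add_additional_properties_recursive` to force `additionalProperties: False` onto every object schema before handing a JSON-schema response format to litellm. As a rough illustration of that idea only (the mixin's real helper is not shown in this diff and may differ in detail), a recursive pass over a JSON schema could look like this:

from typing import Any


def add_additional_properties_recursive(schema: dict[str, Any]) -> dict[str, Any]:
    """Illustrative sketch only: set additionalProperties=False on every object schema.

    Not the actual llama_stack implementation; it mirrors what the removed
    _get_params expected from self._add_additional_properties_recursive.
    """
    if schema.get("type") == "object":
        schema.setdefault("additionalProperties", False)
    # Recurse into nested property definitions.
    for prop in schema.get("properties", {}).values():
        if isinstance(prop, dict):
            add_additional_properties_recursive(prop)
    # Recurse into array item schemas.
    items = schema.get("items")
    if isinstance(items, dict):
        add_additional_properties_recursive(items)
    # Recurse into $defs / definitions blocks if present.
    for defs_key in ("$defs", "definitions"):
        for sub in schema.get(defs_key, {}).values():
            if isinstance(sub, dict):
                add_additional_properties_recursive(sub)
    return schema


if __name__ == "__main__":
    example = {
        "type": "object",
        "properties": {"address": {"type": "object", "properties": {"city": {"type": "string"}}}},
    }
    print(add_additional_properties_recursive(example))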
(File diff suppressed because it is too large.)
@@ -21,19 +21,18 @@ from llama_stack.apis.common.content_types import (
     TextContentItem,
 )
 from llama_stack.apis.inference import (
-    ChatCompletionRequest,
     CompletionRequest,
-    Message,
+    OpenAIAssistantMessageParam,
     OpenAIChatCompletionContentPartImageParam,
     OpenAIChatCompletionContentPartTextParam,
     OpenAIFile,
+    OpenAIMessageParam,
+    OpenAISystemMessageParam,
+    OpenAIToolMessageParam,
+    OpenAIUserMessageParam,
     ResponseFormat,
     ResponseFormatType,
-    SystemMessage,
-    SystemMessageBehavior,
     ToolChoice,
-    ToolDefinition,
-    UserMessage,
 )
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import (
@@ -42,33 +41,19 @@ from llama_stack.models.llama.datatypes import (
     RawMediaItem,
     RawMessage,
     RawTextItem,
-    Role,
     StopReason,
+    ToolCall,
+    ToolDefinition,
     ToolPromptFormat,
 )
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
-from llama_stack.models.llama.llama3.prompt_templates import (
-    BuiltinToolGenerator,
-    FunctionTagCustomToolGenerator,
-    JsonCustomToolGenerator,
-    PythonListCustomToolGenerator,
-    SystemDefaultGenerator,
-)
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
-from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
-    PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
-)
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
-from llama_stack.providers.utils.inference import supported_inference_models
 
 log = get_logger(name=__name__, category="providers::utils")
 
 
-class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
-    messages: list[RawMessage]
-
-
 class CompletionRequestWithRawContent(CompletionRequest):
     content: RawContent
 
@@ -103,28 +88,6 @@ def interleaved_content_as_str(
     return _process(content)
 
 
-async def convert_request_to_raw(
-    request: ChatCompletionRequest | CompletionRequest,
-) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent:
-    if isinstance(request, ChatCompletionRequest):
-        messages = []
-        for m in request.messages:
-            content = await interleaved_content_convert_to_raw(m.content)
-            d = m.model_dump()
-            d["content"] = content
-            messages.append(RawMessage(**d))
-
-        d = request.model_dump()
-        d["messages"] = messages
-        request = ChatCompletionRequestWithRawContent(**d)
-    else:
-        d = request.model_dump()
-        d["content"] = await interleaved_content_convert_to_raw(request.content)
-        request = CompletionRequestWithRawContent(**d)
-
-    return request
-
-
 async def interleaved_content_convert_to_raw(
     content: InterleavedContent,
 ) -> RawContent:
@@ -171,6 +134,36 @@ async def interleaved_content_convert_to_raw(
     return await _localize_single(content)
 
 
+async def convert_openai_message_to_raw_message(message: OpenAIMessageParam) -> RawMessage:
+    """Convert OpenAI message format to RawMessage format used by Llama formatters."""
+    if isinstance(message, OpenAIUserMessageParam):
+        content = await interleaved_content_convert_to_raw(message.content)  # type: ignore[arg-type]
+        return RawMessage(role="user", content=content)
+    elif isinstance(message, OpenAISystemMessageParam):
+        content = await interleaved_content_convert_to_raw(message.content)  # type: ignore[arg-type]
+        return RawMessage(role="system", content=content)
+    elif isinstance(message, OpenAIAssistantMessageParam):
+        content = await interleaved_content_convert_to_raw(message.content or "")  # type: ignore[arg-type]
+        tool_calls = []
+        if message.tool_calls:
+            for tc in message.tool_calls:
+                if tc.function:
+                    tool_calls.append(
+                        ToolCall(
+                            call_id=tc.id or "",
+                            tool_name=tc.function.name or "",
+                            arguments=tc.function.arguments or "{}",
+                        )
+                    )
+        return RawMessage(role="assistant", content=content, tool_calls=tool_calls)
+    elif isinstance(message, OpenAIToolMessageParam):
+        content = await interleaved_content_convert_to_raw(message.content)  # type: ignore[arg-type]
+        return RawMessage(role="tool", content=content)
+    else:
+        # Handle OpenAIDeveloperMessageParam if needed
+        raise ValueError(f"Unsupported message type: {type(message)}")
+
+
 def content_has_media(content: InterleavedContent):
     def _has_media_content(c):
         return isinstance(c, ImageContentItem)
@@ -181,17 +174,6 @@ def content_has_media(content: InterleavedContent):
     return _has_media_content(content)
 
 
-def messages_have_media(messages: list[Message]):
-    return any(content_has_media(m.content) for m in messages)
-
-
-def request_has_media(request: ChatCompletionRequest | CompletionRequest):
-    if isinstance(request, ChatCompletionRequest):
-        return messages_have_media(request.messages)
-    else:
-        return content_has_media(request.content)
-
-
 async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
     if uri.startswith("http"):
         async with httpx.AsyncClient() as client:
@@ -253,79 +235,6 @@ def augment_content_with_response_format_prompt(response_format, content):
     return content
 
 
-async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
-    messages = chat_completion_request_to_messages(request, llama_model)
-    request.messages = messages
-    request = await convert_request_to_raw(request)
-
-    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
-    model_input = formatter.encode_dialog_prompt(
-        request.messages,
-        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
-    )
-    return formatter.tokenizer.decode(model_input.tokens)
-
-
-async def chat_completion_request_to_model_input_info(
-    request: ChatCompletionRequest, llama_model: str
-) -> tuple[str, int]:
-    messages = chat_completion_request_to_messages(request, llama_model)
-    request.messages = messages
-    request = await convert_request_to_raw(request)
-
-    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
-    model_input = formatter.encode_dialog_prompt(
-        request.messages,
-        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
-    )
-    return (
-        formatter.tokenizer.decode(model_input.tokens),
-        len(model_input.tokens),
-    )
-
-
-def chat_completion_request_to_messages(
-    request: ChatCompletionRequest,
-    llama_model: str,
-) -> list[Message]:
-    """Reads chat completion request and augments the messages to handle tools.
-    For eg. for llama_3_1, add system message with the appropriate tools or
-    add user messsage for custom tools, etc.
-    """
-    assert llama_model is not None, "llama_model is required"
-    model = resolve_model(llama_model)
-    if model is None:
-        log.error(f"Could not resolve model {llama_model}")
-        return request.messages
-
-    allowed_models = supported_inference_models()
-    descriptors = [m.descriptor() for m in allowed_models]
-    if model.descriptor() not in descriptors:
-        log.error(f"Unsupported inference model? {model.descriptor()}")
-        return request.messages
-
-    if model.model_family == ModelFamily.llama3_1 or (
-        model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
-    ):
-        # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
-        messages = augment_messages_for_tools_llama_3_1(request)
-    elif model.model_family in (
-        ModelFamily.llama3_2,
-        ModelFamily.llama3_3,
-    ):
-        # llama3.2, llama3.3 follow the same tool prompt format
-        messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
-    elif model.model_family == ModelFamily.llama4:
-        messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
-    else:
-        messages = request.messages
-
-    if fmt_prompt := response_format_prompt(request.response_format):
-        messages.append(UserMessage(content=fmt_prompt))
-
-    return messages
-
-
 def response_format_prompt(fmt: ResponseFormat | None):
     if not fmt:
         return None
@@ -338,128 +247,6 @@ def response_format_prompt(fmt: ResponseFormat | None):
     raise ValueError(f"Unknown response format {fmt.type}")
 
 
-def augment_messages_for_tools_llama_3_1(
-    request: ChatCompletionRequest,
-) -> list[Message]:
-    existing_messages = request.messages
-    existing_system_message = None
-    if existing_messages[0].role == Role.system.value:
-        existing_system_message = existing_messages.pop(0)
-
-    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
-
-    messages = []
-
-    default_gen = SystemDefaultGenerator()
-    default_template = default_gen.gen()
-
-    sys_content = ""
-
-    tool_template = None
-    if request.tools:
-        tool_gen = BuiltinToolGenerator()
-        tool_template = tool_gen.gen(request.tools)
-
-        sys_content += tool_template.render()
-        sys_content += "\n"
-
-    sys_content += default_template.render()
-
-    if existing_system_message:
-        # TODO: this fn is needed in many places
-        def _process(c):
-            if isinstance(c, str):
-                return c
-            else:
-                return "<media>"
-
-        sys_content += "\n"
-
-        if isinstance(existing_system_message.content, str):
-            sys_content += _process(existing_system_message.content)
-        elif isinstance(existing_system_message.content, list):
-            sys_content += "\n".join([_process(c) for c in existing_system_message.content])
-
-    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
-    if tool_choice_prompt:
-        sys_content += "\n" + tool_choice_prompt
-
-    messages.append(SystemMessage(content=sys_content))
-
-    has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools)
-    if has_custom_tools:
-        fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
-        if fmt == ToolPromptFormat.json:
-            tool_gen = JsonCustomToolGenerator()
-        elif fmt == ToolPromptFormat.function_tag:
-            tool_gen = FunctionTagCustomToolGenerator()
-        else:
-            raise ValueError(f"Non supported ToolPromptFormat {fmt}")
-
-        custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
-        custom_template = tool_gen.gen(custom_tools)
-        messages.append(UserMessage(content=custom_template.render()))
-
-    # Add back existing messages from the request
-    messages += existing_messages
-
-    return messages
-
-
-def augment_messages_for_tools_llama(
-    request: ChatCompletionRequest,
-    custom_tool_prompt_generator,
-) -> list[Message]:
-    existing_messages = request.messages
-    existing_system_message = None
-    if existing_messages[0].role == Role.system.value:
-        existing_system_message = existing_messages.pop(0)
-
-    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
-
-    sys_content = ""
-    custom_tools, builtin_tools = [], []
-    for t in request.tools:
-        if isinstance(t.tool_name, str):
-            custom_tools.append(t)
-        else:
-            builtin_tools.append(t)
-
-    if builtin_tools:
-        tool_gen = BuiltinToolGenerator()
-        tool_template = tool_gen.gen(builtin_tools)
-
-        sys_content += tool_template.render()
-        sys_content += "\n"
-
-    custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
-    if custom_tools:
-        fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
-        if fmt != ToolPromptFormat.python_list:
-            raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
-
-        system_prompt = None
-        if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
-            system_prompt = existing_system_message.content
-
-        tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
-
-        sys_content += tool_template.render()
-        sys_content += "\n"
-
-    if existing_system_message and (
-        request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
-    ):
-        sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
-
-    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
-    if tool_choice_prompt:
-        sys_content += "\n" + tool_choice_prompt
-
-    messages = [SystemMessage(content=sys_content.strip("\n")), *existing_messages]
-    return messages
-
-
 def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str:
     if tool_choice == ToolChoice.auto:
         return ""
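The new `convert_openai_message_to_raw_message` coroutine added above replaces the deleted request-level conversion helpers: callers now convert one OpenAI-style message at a time into a `RawMessage` for the Llama chat formatters. A minimal usage sketch, mirroring the new unit test further down in this diff (the import paths are the ones introduced here):

import asyncio

from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_openai_message_to_raw_message,
)


async def main() -> None:
    # OpenAI-style user message in, RawMessage with role="user" out.
    msg = OpenAIUserMessageParam(role="user", content="Hello world")
    raw = await convert_openai_message_to_raw_message(msg)
    print(raw.role, raw.content)


if __name__ == "__main__":
    asyncio.run(main())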
(deleted file, 303 lines removed)
@@ -1,303 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-
-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    CompletionMessage,
-    StopReason,
-    SystemMessage,
-    SystemMessageBehavior,
-    ToolCall,
-    ToolConfig,
-    UserMessage,
-)
-from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
-    ToolDefinition,
-    ToolPromptFormat,
-)
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_messages,
-    chat_completion_request_to_prompt,
-    interleaved_content_as_str,
-)
-
-MODEL = "Llama3.1-8B-Instruct"
-MODEL3_2 = "Llama3.2-3B-Instruct"
-
-
-async def test_system_default():
-    content = "Hello !"
-    request = ChatCompletionRequest(
-        model=MODEL,
-        messages=[
-            UserMessage(content=content),
-        ],
-    )
-    messages = chat_completion_request_to_messages(request, MODEL)
-    assert len(messages) == 2
-    assert messages[-1].content == content
-    assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
-
-
-async def test_system_builtin_only():
-    content = "Hello !"
-    request = ChatCompletionRequest(
-        model=MODEL,
-        messages=[
-            UserMessage(content=content),
-        ],
-        tools=[
-            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-            ToolDefinition(tool_name=BuiltinTool.brave_search),
-        ],
-    )
-    messages = chat_completion_request_to_messages(request, MODEL)
-    assert len(messages) == 2
-    assert messages[-1].content == content
-    assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
-    assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
-
-
-async def test_system_custom_only():
-    content = "Hello !"
-    request = ChatCompletionRequest(
-        model=MODEL,
-        messages=[
-            UserMessage(content=content),
-        ],
-        tools=[
-            ToolDefinition(
-                tool_name="custom1",
-                description="custom1 tool",
-                input_schema={
-                    "type": "object",
-                    "properties": {
-                        "param1": {
-                            "type": "str",
-                            "description": "param1 description",
-                        },
-                    },
-                    "required": ["param1"],
-                },
-            )
-        ],
-        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
-    )
-    messages = chat_completion_request_to_messages(request, MODEL)
-    assert len(messages) == 3
-    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
-
-    assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
-    assert messages[-1].content == content
-
-
-async def test_system_custom_and_builtin():
-    content = "Hello !"
-    request = ChatCompletionRequest(
-        model=MODEL,
-        messages=[
-            UserMessage(content=content),
-        ],
-        tools=[
-            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-            ToolDefinition(tool_name=BuiltinTool.brave_search),
-            ToolDefinition(
-                tool_name="custom1",
-                description="custom1 tool",
-                input_schema={
-                    "type": "object",
-                    "properties": {
-                        "param1": {
-                            "type": "str",
-                            "description": "param1 description",
-                        },
-                    },
-                    "required": ["param1"],
-                },
-            ),
-        ],
-    )
-    messages = chat_completion_request_to_messages(request, MODEL)
-    assert len(messages) == 3
-
-    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
-    assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
-
-    assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
-    assert messages[-1].content == content
-
-
-async def test_completion_message_encoding():
-    request = ChatCompletionRequest(
-        model=MODEL3_2,
-        messages=[
-            UserMessage(content="hello"),
-            CompletionMessage(
-                content="",
-                stop_reason=StopReason.end_of_turn,
-                tool_calls=[
-                    ToolCall(
-                        tool_name="custom1",
-                        arguments='{"param1": "value1"}',  # arguments must be a JSON string
-                        call_id="123",
-                    )
-                ],
-            ),
-        ],
-        tools=[
-            ToolDefinition(
-                tool_name="custom1",
-                description="custom1 tool",
-                input_schema={
-                    "type": "object",
-                    "properties": {
-                        "param1": {
-                            "type": "str",
-                            "description": "param1 description",
-                        },
-                    },
-                    "required": ["param1"],
-                },
-            ),
-        ],
-        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
-    )
-    prompt = await chat_completion_request_to_prompt(request, request.model)
-    assert '[custom1(param1="value1")]' in prompt
-
-    request.model = MODEL
-    request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
-    prompt = await chat_completion_request_to_prompt(request, request.model)
-    assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
-
-
-async def test_user_provided_system_message():
-    content = "Hello !"
-    system_prompt = "You are a pirate"
-    request = ChatCompletionRequest(
-        model=MODEL,
-        messages=[
-            SystemMessage(content=system_prompt),
-            UserMessage(content=content),
-        ],
-        tools=[
-            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-        ],
-    )
-    messages = chat_completion_request_to_messages(request, MODEL)
-    assert len(messages) == 2
-    assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
-
-    assert messages[-1].content == content
-
-
-async def test_replace_system_message_behavior_builtin_tools():
-    content = "Hello !"
-    system_prompt = "You are a pirate"
-    request = ChatCompletionRequest(
-        model=MODEL,
-        messages=[
-            SystemMessage(content=system_prompt),
-            UserMessage(content=content),
-        ],
-        tools=[
-            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-        ],
-        tool_config=ToolConfig(
-            tool_choice="auto",
-            tool_prompt_format=ToolPromptFormat.python_list,
-            system_message_behavior=SystemMessageBehavior.replace,
-        ),
-    )
-    messages = chat_completion_request_to_messages(request, MODEL3_2)
-    assert len(messages) == 2
-    assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
-    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
-    assert messages[-1].content == content
-
-
-async def test_replace_system_message_behavior_custom_tools():
-    content = "Hello !"
-    system_prompt = "You are a pirate"
-    request = ChatCompletionRequest(
-        model=MODEL,
-        messages=[
-            SystemMessage(content=system_prompt),
-            UserMessage(content=content),
-        ],
-        tools=[
-            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-            ToolDefinition(
-                tool_name="custom1",
-                description="custom1 tool",
-                input_schema={
-                    "type": "object",
-                    "properties": {
-                        "param1": {
-                            "type": "str",
-                            "description": "param1 description",
-                        },
-                    },
-                    "required": ["param1"],
-                },
-            ),
-        ],
-        tool_config=ToolConfig(
-            tool_choice="auto",
-            tool_prompt_format=ToolPromptFormat.python_list,
-            system_message_behavior=SystemMessageBehavior.replace,
-        ),
-    )
-    messages = chat_completion_request_to_messages(request, MODEL3_2)
-
-    assert len(messages) == 2
-    assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
-    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
-    assert messages[-1].content == content
-
-
-async def test_replace_system_message_behavior_custom_tools_with_template():
-    content = "Hello !"
-    system_prompt = "You are a pirate {{ function_description }}"
-    request = ChatCompletionRequest(
-        model=MODEL,
-        messages=[
-            SystemMessage(content=system_prompt),
-            UserMessage(content=content),
-        ],
-        tools=[
-            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
-            ToolDefinition(
-                tool_name="custom1",
-                description="custom1 tool",
-                input_schema={
-                    "type": "object",
-                    "properties": {
-                        "param1": {
-                            "type": "str",
-                            "description": "param1 description",
-                        },
-                    },
-                    "required": ["param1"],
-                },
-            ),
-        ],
-        tool_config=ToolConfig(
-            tool_choice="auto",
-            tool_prompt_format=ToolPromptFormat.python_list,
-            system_message_behavior=SystemMessageBehavior.replace,
-        ),
-    )
-    messages = chat_completion_request_to_messages(request, MODEL3_2)
-
-    assert len(messages) == 2
-    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
-    assert "You are a pirate" in interleaved_content_as_str(messages[0].content)
-    # function description is present in the system prompt
-    assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content)
-    assert messages[-1].content == content
tests/unit/providers/inline/inference/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.

tests/unit/providers/inline/inference/test_meta_reference.py (new file, 44 lines)
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from unittest.mock import Mock
+
+import pytest
+
+from llama_stack.providers.inline.inference.meta_reference.model_parallel import (
+    ModelRunner,
+)
+
+
+class TestModelRunner:
+    """Test ModelRunner task dispatching for model-parallel inference."""
+
+    def test_chat_completion_task_dispatch(self):
+        """Verify ModelRunner correctly dispatches chat_completion tasks."""
+        # Create a mock generator
+        mock_generator = Mock()
+        mock_generator.chat_completion = Mock(return_value=iter([]))
+
+        runner = ModelRunner(mock_generator)
+
+        # Create a chat_completion task
+        fake_params = {"model": "test"}
+        fake_messages = [{"role": "user", "content": "test"}]
+        task = ("chat_completion", [fake_params, fake_messages])
+
+        # Execute task
+        runner(task)
+
+        # Verify chat_completion was called with correct arguments
+        mock_generator.chat_completion.assert_called_once_with(fake_params, fake_messages)
+
+    def test_invalid_task_type_raises_error(self):
+        """Verify ModelRunner rejects invalid task types."""
+        mock_generator = Mock()
+        runner = ModelRunner(mock_generator)
+
+        with pytest.raises(ValueError, match="Unexpected task type"):
+            runner(("invalid_task", []))
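The new `test_meta_reference.py` pins down the task-dispatch contract of `ModelRunner`: a `("chat_completion", [params, messages])` task must be routed to `generator.chat_completion(params, messages)`, and any other task type must raise `ValueError`. The sketch below is not the actual `ModelRunner` implementation, only a minimal stand-in that satisfies the contract these tests assert:

from typing import Any


class SketchModelRunner:
    """Illustrative stand-in for the dispatch behavior exercised by the tests above."""

    def __init__(self, generator: Any) -> None:
        self.generator = generator

    def __call__(self, task: tuple[str, list[Any]]):
        task_type, args = task
        if task_type == "chat_completion":
            # Route the task payload straight to the generator.
            return self.generator.chat_completion(*args)
        raise ValueError(f"Unexpected task type: {task_type}")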
@@ -10,11 +10,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 
-from llama_stack.apis.inference import CompletionMessage, UserMessage
+from llama_stack.apis.inference import (
+    OpenAIAssistantMessageParam,
+    OpenAIUserMessageParam,
+)
 from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
 from llama_stack.apis.shields import Shield
-from llama_stack.models.llama.datatypes import StopReason
 from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
 from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
 
@@ -136,11 +138,9 @@ async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):
 
     # Run the shield
     messages = [
-        UserMessage(role="user", content="Hello, how are you?"),
+        OpenAIUserMessageParam(content="Hello, how are you?"),
-        CompletionMessage(
-            role="assistant",
+        OpenAIAssistantMessageParam(
             content="I'm doing well, thank you for asking!",
-            stop_reason=StopReason.end_of_message,
             tool_calls=[],
         ),
     ]
@@ -191,13 +191,10 @@ async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):
     # Mock Guardrails API response
     mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
 
-    # Run the shield
     messages = [
-        UserMessage(role="user", content="Hello, how are you?"),
+        OpenAIUserMessageParam(content="Hello, how are you?"),
-        CompletionMessage(
-            role="assistant",
+        OpenAIAssistantMessageParam(
             content="I'm doing well, thank you for asking!",
-            stop_reason=StopReason.end_of_message,
             tool_calls=[],
         ),
     ]
@@ -243,7 +240,7 @@ async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post):
     adapter.shield_store.get_shield.return_value = None
 
     messages = [
-        UserMessage(role="user", content="Hello, how are you?"),
+        OpenAIUserMessageParam(content="Hello, how are you?"),
     ]
 
     with pytest.raises(ValueError):
@@ -274,11 +271,9 @@ async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):
 
     # Running the shield should raise an exception
     messages = [
-        UserMessage(role="user", content="Hello, how are you?"),
+        OpenAIUserMessageParam(content="Hello, how are you?"),
-        CompletionMessage(
-            role="assistant",
+        OpenAIAssistantMessageParam(
             content="I'm doing well, thank you for asking!",
-            stop_reason=StopReason.end_of_message,
             tool_calls=[],
         ),
     ]
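The safety tests above migrate from the legacy `UserMessage`/`CompletionMessage` types to the OpenAI-style message params; the explicit `role=...` and `stop_reason=...` arguments disappear because the new param classes carry a fixed role and no stop reason. For reference, the migrated message list looks like this (same classes and arguments as in the diff):

from llama_stack.apis.inference import (
    OpenAIAssistantMessageParam,
    OpenAIUserMessageParam,
)

# Conversation passed to run_shield in the updated tests.
messages = [
    OpenAIUserMessageParam(content="Hello, how are you?"),
    OpenAIAssistantMessageParam(
        content="I'm doing well, thank you for asking!",
        tool_calls=[],
    ),
]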
(deleted file, 220 lines removed)
@@ -1,220 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import pytest
-from pydantic import ValidationError
-
-from llama_stack.apis.common.content_types import TextContentItem
-from llama_stack.apis.inference import (
-    CompletionMessage,
-    OpenAIAssistantMessageParam,
-    OpenAIChatCompletionContentPartImageParam,
-    OpenAIChatCompletionContentPartTextParam,
-    OpenAIDeveloperMessageParam,
-    OpenAIImageURL,
-    OpenAISystemMessageParam,
-    OpenAIToolMessageParam,
-    OpenAIUserMessageParam,
-    SystemMessage,
-    UserMessage,
-)
-from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
-from llama_stack.providers.utils.inference.openai_compat import (
-    convert_message_to_openai_dict,
-    convert_message_to_openai_dict_new,
-    openai_messages_to_messages,
-)
-
-
-async def test_convert_message_to_openai_dict():
-    message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
-    assert await convert_message_to_openai_dict(message) == {
-        "role": "user",
-        "content": [{"type": "text", "text": "Hello, world!"}],
-    }
-
-
-# Test convert_message_to_openai_dict with a tool call
-async def test_convert_message_to_openai_dict_with_tool_call():
-    message = CompletionMessage(
-        content="",
-        tool_calls=[ToolCall(call_id="123", tool_name="test_tool", arguments='{"foo": "bar"}')],
-        stop_reason=StopReason.end_of_turn,
-    )
-
-    openai_dict = await convert_message_to_openai_dict(message)
-
-    assert openai_dict == {
-        "role": "assistant",
-        "content": [{"type": "text", "text": ""}],
-        "tool_calls": [
-            {"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}}
-        ],
-    }
-
-
-async def test_convert_message_to_openai_dict_with_builtin_tool_call():
-    message = CompletionMessage(
-        content="",
-        tool_calls=[
-            ToolCall(
-                call_id="123",
-                tool_name=BuiltinTool.brave_search,
-                arguments='{"foo": "bar"}',
-            )
-        ],
-        stop_reason=StopReason.end_of_turn,
-    )
-
-    openai_dict = await convert_message_to_openai_dict(message)
-
-    assert openai_dict == {
-        "role": "assistant",
-        "content": [{"type": "text", "text": ""}],
-        "tool_calls": [
-            {"id": "123", "type": "function", "function": {"name": "brave_search", "arguments": '{"foo": "bar"}'}}
-        ],
-    }
-
-
-async def test_openai_messages_to_messages_with_content_str():
-    openai_messages = [
-        OpenAISystemMessageParam(content="system message"),
-        OpenAIUserMessageParam(content="user message"),
-        OpenAIAssistantMessageParam(content="assistant message"),
-    ]
-
-    llama_messages = openai_messages_to_messages(openai_messages)
-    assert len(llama_messages) == 3
-    assert isinstance(llama_messages[0], SystemMessage)
-    assert isinstance(llama_messages[1], UserMessage)
-    assert isinstance(llama_messages[2], CompletionMessage)
-    assert llama_messages[0].content == "system message"
-    assert llama_messages[1].content == "user message"
-    assert llama_messages[2].content == "assistant message"
-
-
-async def test_openai_messages_to_messages_with_content_list():
-    openai_messages = [
-        OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),
-        OpenAIUserMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="user message")]),
-        OpenAIAssistantMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="assistant message")]),
-    ]
-
-    llama_messages = openai_messages_to_messages(openai_messages)
-    assert len(llama_messages) == 3
-    assert isinstance(llama_messages[0], SystemMessage)
-    assert isinstance(llama_messages[1], UserMessage)
-    assert isinstance(llama_messages[2], CompletionMessage)
-    assert llama_messages[0].content[0].text == "system message"
-    assert llama_messages[1].content[0].text == "user message"
-    assert llama_messages[2].content[0].text == "assistant message"
-
-
-@pytest.mark.parametrize(
-    "message_class,kwargs",
-    [
-        (OpenAISystemMessageParam, {}),
-        (OpenAIAssistantMessageParam, {}),
-        (OpenAIDeveloperMessageParam, {}),
-        (OpenAIUserMessageParam, {}),
-        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
-    ],
-)
-def test_message_accepts_text_string(message_class, kwargs):
-    """Test that messages accept string text content."""
-    msg = message_class(content="Test message", **kwargs)
-    assert msg.content == "Test message"
-
-
-@pytest.mark.parametrize(
-    "message_class,kwargs",
-    [
-        (OpenAISystemMessageParam, {}),
-        (OpenAIAssistantMessageParam, {}),
-        (OpenAIDeveloperMessageParam, {}),
-        (OpenAIUserMessageParam, {}),
-        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
-    ],
-)
-def test_message_accepts_text_list(message_class, kwargs):
-    """Test that messages accept list of text content parts."""
-    content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
-    msg = message_class(content=content_list, **kwargs)
-    assert len(msg.content) == 1
-    assert msg.content[0].text == "Test message"
-
-
-@pytest.mark.parametrize(
-    "message_class,kwargs",
-    [
-        (OpenAISystemMessageParam, {}),
-        (OpenAIAssistantMessageParam, {}),
-        (OpenAIDeveloperMessageParam, {}),
-        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
-    ],
-)
-def test_message_rejects_images(message_class, kwargs):
-    """Test that system, assistant, developer, and tool messages reject image content."""
-    with pytest.raises(ValidationError):
-        message_class(
-            content=[
-                OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
-            ],
-            **kwargs,
-        )
-
-
-def test_user_message_accepts_images():
-    """Test that user messages accept image content (unlike other message types)."""
-    # List with images should work
-    msg = OpenAIUserMessageParam(
-        content=[
-            OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
-            OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
-        ]
-    )
-    assert len(msg.content) == 2
-    assert msg.content[0].text == "Describe this image:"
-    assert msg.content[1].image_url.url == "http://example.com/image.jpg"
-
-
-async def test_convert_message_to_openai_dict_new_user_message():
-    """Test convert_message_to_openai_dict_new with UserMessage."""
-    message = UserMessage(content="Hello, world!", role="user")
-    result = await convert_message_to_openai_dict_new(message)
-
-    assert result["role"] == "user"
-    assert result["content"] == "Hello, world!"
-
-
-async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
-    """Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
-    message = CompletionMessage(
-        content="I'll help you find the weather.",
-        tool_calls=[
-            ToolCall(
-                call_id="call_123",
-                tool_name="get_weather",
-                arguments='{"city": "Sligo"}',
-            )
-        ],
-        stop_reason=StopReason.end_of_turn,
-    )
-    result = await convert_message_to_openai_dict_new(message)
-
-    # This would have failed with "Cannot instantiate typing.Union" before the fix
-    assert result["role"] == "assistant"
-    assert result["content"] == "I'll help you find the weather."
-    assert "tool_calls" in result
-    assert result["tool_calls"] is not None
-    assert len(result["tool_calls"]) == 1
-
-    tool_call = result["tool_calls"][0]
-    assert tool_call.id == "call_123"
-    assert tool_call.type == "function"
-    assert tool_call.function.name == "get_weather"
-    assert tool_call.function.arguments == '{"city": "Sligo"}'
tests/unit/providers/utils/inference/test_prompt_adapter.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.inference import (
+    OpenAIAssistantMessageParam,
+    OpenAIUserMessageParam,
+)
+from llama_stack.models.llama.datatypes import RawTextItem
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_openai_message_to_raw_message,
+)
+
+
+class TestConvertOpenAIMessageToRawMessage:
+    """Test conversion of OpenAI message types to RawMessage format."""
+
+    async def test_user_message_conversion(self):
+        msg = OpenAIUserMessageParam(role="user", content="Hello world")
+        raw_msg = await convert_openai_message_to_raw_message(msg)
+
+        assert raw_msg.role == "user"
+        assert isinstance(raw_msg.content, RawTextItem)
+        assert raw_msg.content.text == "Hello world"
+
+    async def test_assistant_message_conversion(self):
+        msg = OpenAIAssistantMessageParam(role="assistant", content="Hi there!")
+        raw_msg = await convert_openai_message_to_raw_message(msg)
+
+        assert raw_msg.role == "assistant"
+        assert isinstance(raw_msg.content, RawTextItem)
+        assert raw_msg.content.text == "Hi there!"
+        assert raw_msg.tool_calls == []