forked from phoenix-oss/llama-stack-mirror
# What does this PR do?

We currently set a default value of `json` for the tool prompt format, which conflicts with Llama 3.2/3.3 models since they use the `python_list` format. This PR changes the default to `None`; the code then infers the right default based on the model.

Addresses: #695

## Tests

```
❯ LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v tests/client-sdk/inference/test_inference.py -k "test_text_chat_completion"
❯ pytest llama_stack/providers/tests/inference/test_prompt_adapter.py
```
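For context, a minimal sketch of the kind of model-based inference this PR describes. The helper name and the string-matching on model ids here are hypothetical; the actual change resolves the format from model metadata inside llama-stack:

```python
from llama_models.llama3.api.datatypes import ToolPromptFormat


def infer_tool_prompt_format(model_id: str) -> ToolPromptFormat:
    """Hypothetical helper: choose a tool prompt format per model family."""
    # Llama 3.2/3.3 models expect `python_list`; earlier families use `json`.
    if any(family in model_id for family in ("3.2", "3.3")):
        return ToolPromptFormat.python_list
    return ToolPromptFormat.json
```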
341 lines · 8.6 KiB · Python
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import (
    Any,
    AsyncIterator,
    Dict,
    List,
    Literal,
    Optional,
    Protocol,
    Union,
    runtime_checkable,
)

from llama_models.llama3.api.datatypes import (
    BuiltinTool,
    SamplingParams,
    StopReason,
    ToolCall,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.models import Model
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol


class LogProbConfig(BaseModel):
    top_k: Optional[int] = 0


@json_schema_type
class QuantizationType(Enum):
    bf16 = "bf16"
    fp8 = "fp8"
    int4 = "int4"


@json_schema_type
class Fp8QuantizationConfig(BaseModel):
    type: Literal["fp8"] = "fp8"


@json_schema_type
class Bf16QuantizationConfig(BaseModel):
    type: Literal["bf16"] = "bf16"


@json_schema_type
class Int4QuantizationConfig(BaseModel):
    type: Literal["int4"] = "int4"
    scheme: Optional[str] = "int4_weight_int8_dynamic_activation"


QuantizationConfig = Annotated[
    Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
    Field(discriminator="type"),
]
```
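`QuantizationConfig` is a discriminated union on the `type` field, so pydantic can route a plain dict to the matching config class. A minimal sketch, assuming pydantic v2 (consistent with the `field_validator` import above):

```python
from pydantic import TypeAdapter

adapter = TypeAdapter(QuantizationConfig)

# The "type" discriminator selects the concrete config class.
fp8 = adapter.validate_python({"type": "fp8"})
assert isinstance(fp8, Fp8QuantizationConfig)

int4 = adapter.validate_python({"type": "int4"})
assert int4.scheme == "int4_weight_int8_dynamic_activation"
```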
```python
@json_schema_type
class UserMessage(BaseModel):
    role: Literal["user"] = "user"
    content: InterleavedContent
    context: Optional[InterleavedContent] = None


@json_schema_type
class SystemMessage(BaseModel):
    role: Literal["system"] = "system"
    content: InterleavedContent


@json_schema_type
class ToolResponseMessage(BaseModel):
    role: Literal["ipython"] = "ipython"
    # it was nice to re-use the ToolResponse type, but having all messages
    # have a `content` type makes things nicer too
    call_id: str
    tool_name: Union[BuiltinTool, str]
    content: InterleavedContent


@json_schema_type
class CompletionMessage(BaseModel):
    role: Literal["assistant"] = "assistant"
    content: InterleavedContent
    stop_reason: StopReason
    tool_calls: List[ToolCall] = Field(default_factory=list)


Message = register_schema(
    Annotated[
        Union[
            UserMessage,
            SystemMessage,
            ToolResponseMessage,
            CompletionMessage,
        ],
        Field(discriminator="role"),
    ],
    name="Message",
)
```
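`Message` uses `role` as its discriminator, so serialized messages round-trip to the right class. A short sketch (the dialog content is illustrative; `InterleavedContent` accepts plain strings):

```python
from pydantic import TypeAdapter

dialog = [
    SystemMessage(content="You are a helpful assistant."),
    UserMessage(content="What is the capital of France?"),
]

# Deserializing a dict routes on "role".
msg = TypeAdapter(Message).validate_python(
    {"role": "user", "content": "What is the capital of France?"}
)
assert isinstance(msg, UserMessage)
```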
```python
@json_schema_type
class ToolResponse(BaseModel):
    call_id: str
    tool_name: Union[BuiltinTool, str]
    content: InterleavedContent

    @field_validator("tool_name", mode="before")
    @classmethod
    def validate_field(cls, v):
        if isinstance(v, str):
            try:
                return BuiltinTool(v)
            except ValueError:
                return v
        return v
```
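The `validate_field` hook coerces strings that name a builtin tool into the `BuiltinTool` enum, while unknown names pass through as plain strings for custom tools. For example (the custom tool name and contents are illustrative):

```python
# "brave_search" names a BuiltinTool member, so it is coerced to the enum.
resp = ToolResponse(call_id="call-1", tool_name="brave_search", content="results...")
assert resp.tool_name is BuiltinTool.brave_search

# Unrecognized names stay as plain strings (custom tool).
custom = ToolResponse(call_id="call-2", tool_name="lookup_db", content="rows...")
assert custom.tool_name == "lookup_db"
```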
```python
@json_schema_type
class ToolChoice(Enum):
    auto = "auto"
    required = "required"


@json_schema_type
class TokenLogProbs(BaseModel):
    logprobs_by_token: Dict[str, float]


@json_schema_type
class ChatCompletionResponseEventType(Enum):
    start = "start"
    complete = "complete"
    progress = "progress"


@json_schema_type
class ToolCallParseStatus(Enum):
    started = "started"
    in_progress = "in_progress"
    failure = "failure"
    success = "success"


@json_schema_type
class ToolCallDelta(BaseModel):
    content: Union[str, ToolCall]
    parse_status: ToolCallParseStatus


@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
    """Chat completion response event."""

    event_type: ChatCompletionResponseEventType
    delta: Union[str, ToolCallDelta]
    logprobs: Optional[List[TokenLogProbs]] = None
    stop_reason: Optional[StopReason] = None


class ResponseFormatType(Enum):
    json_schema = "json_schema"
    grammar = "grammar"


class JsonSchemaResponseFormat(BaseModel):
    type: Literal[ResponseFormatType.json_schema.value] = (
        ResponseFormatType.json_schema.value
    )
    json_schema: Dict[str, Any]


class GrammarResponseFormat(BaseModel):
    type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
    bnf: Dict[str, Any]


ResponseFormat = register_schema(
    Annotated[
        Union[JsonSchemaResponseFormat, GrammarResponseFormat],
        Field(discriminator="type"),
    ],
    name="ResponseFormat",
)
```
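A `JsonSchemaResponseFormat` pins `type` to `"json_schema"`, which is what lets the `ResponseFormat` union discriminate on deserialization. A small sketch with an illustrative schema:

```python
fmt = JsonSchemaResponseFormat(
    json_schema={
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    }
)
# The fixed discriminator value survives serialization.
assert fmt.type == "json_schema"
```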
```python
@json_schema_type
class CompletionRequest(BaseModel):
    model: str
    content: InterleavedContent
    sampling_params: Optional[SamplingParams] = SamplingParams()
    response_format: Optional[ResponseFormat] = None

    stream: Optional[bool] = False
    logprobs: Optional[LogProbConfig] = None


@json_schema_type
class CompletionResponse(BaseModel):
    """Completion response."""

    content: str
    stop_reason: StopReason
    logprobs: Optional[List[TokenLogProbs]] = None


@json_schema_type
class CompletionResponseStreamChunk(BaseModel):
    """Streamed completion response."""

    delta: str
    stop_reason: Optional[StopReason] = None
    logprobs: Optional[List[TokenLogProbs]] = None
```
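With `stream=True`, a completion arrives as a sequence of `CompletionResponseStreamChunk`s whose `delta` strings concatenate into the full text. A minimal consumer sketch, assuming `client` implements the `Inference` protocol defined below and returns an async iterator when streaming:

```python
async def collect_completion(client, model_id: str, prompt: str) -> str:
    """Assemble a streamed completion into one string (illustrative)."""
    pieces = []
    stream = await client.completion(model_id=model_id, content=prompt, stream=True)
    async for chunk in stream:
        pieces.append(chunk.delta)
        if chunk.stop_reason is not None:  # final chunk carries the stop reason
            break
    return "".join(pieces)
```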
```python
@json_schema_type
class BatchCompletionRequest(BaseModel):
    model: str
    content_batch: List[InterleavedContent]
    sampling_params: Optional[SamplingParams] = SamplingParams()
    response_format: Optional[ResponseFormat] = None
    logprobs: Optional[LogProbConfig] = None


@json_schema_type
class BatchCompletionResponse(BaseModel):
    """Batch completion response."""

    batch: List[CompletionResponse]


@json_schema_type
class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    sampling_params: Optional[SamplingParams] = SamplingParams()

    # zero-shot tool definitions as input to the model
    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
    response_format: Optional[ResponseFormat] = None

    stream: Optional[bool] = False
    logprobs: Optional[LogProbConfig] = None
```
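This is the request type the PR changes: `tool_prompt_format` now defaults to `None`, so callers can pass tools without pinning a format and let the provider infer one from the model. A sketch with a hypothetical weather tool (`ToolParamDefinition` comes from `llama_models.llama3.api.datatypes`; the model id and tool are illustrative):

```python
from llama_models.llama3.api.datatypes import ToolParamDefinition

request = ChatCompletionRequest(
    model="meta-llama/Llama-3.2-3B-Instruct",  # illustrative model id
    messages=[UserMessage(content="What's the weather in Paris?")],
    tools=[
        ToolDefinition(
            tool_name="get_weather",  # hypothetical custom tool
            description="Get the current weather for a city",
            parameters={
                "city": ToolParamDefinition(
                    param_type="string",
                    description="City name",
                    required=True,
                )
            },
        )
    ],
    # tool_prompt_format omitted: None now means "infer from the model".
)
assert request.tool_prompt_format is None
```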
```python
@json_schema_type
class ChatCompletionResponseStreamChunk(BaseModel):
    """SSE-stream of these events."""

    event: ChatCompletionResponseEvent


@json_schema_type
class ChatCompletionResponse(BaseModel):
    """Chat completion response."""

    completion_message: CompletionMessage
    logprobs: Optional[List[TokenLogProbs]] = None
```
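Each streamed chat chunk wraps one `ChatCompletionResponseEvent`; `progress` events carry either a text delta or a `ToolCallDelta`. A hedged consumer sketch over such a stream:

```python
async def print_chat_stream(stream) -> None:
    """Illustrative consumer for a chat-completion event stream."""
    async for chunk in stream:
        event = chunk.event
        if event.event_type == ChatCompletionResponseEventType.progress:
            if isinstance(event.delta, str):
                print(event.delta, end="")  # plain text delta
            elif event.delta.parse_status == ToolCallParseStatus.success:
                print(f"\n[tool call] {event.delta.content}")
        elif event.event_type == ChatCompletionResponseEventType.complete:
            print(f"\n[stop reason: {event.stop_reason}]")
```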
```python
@json_schema_type
class BatchChatCompletionRequest(BaseModel):
    model: str
    messages_batch: List[List[Message]]
    sampling_params: Optional[SamplingParams] = SamplingParams()

    # zero-shot tool definitions as input to the model
    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
    logprobs: Optional[LogProbConfig] = None


@json_schema_type
class BatchChatCompletionResponse(BaseModel):
    batch: List[ChatCompletionResponse]


@json_schema_type
class EmbeddingsResponse(BaseModel):
    embeddings: List[List[float]]


class ModelStore(Protocol):
    def get_model(self, identifier: str) -> Model: ...


@runtime_checkable
@trace_protocol
class Inference(Protocol):
    model_store: ModelStore

    @webmethod(route="/inference/completion")
    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...

    @webmethod(route="/inference/chat-completion")
    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        # zero-shot tool definitions as input to the model
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]: ...

    @webmethod(route="/inference/embeddings")
    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse: ...
```
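Putting it together: a caller only supplies `model_id` and messages, and with the new `None` default the provider resolves the tool prompt format server-side. A usage sketch, where `impl` stands for any concrete provider implementing `Inference`:

```python
import asyncio


async def main(impl: Inference) -> None:
    # Non-streaming call returns a ChatCompletionResponse.
    response = await impl.chat_completion(
        model_id="meta-llama/Llama-3.3-70B-Instruct",  # illustrative id
        messages=[UserMessage(content="Say hello in French.")],
    )
    print(response.completion_message.content)


# asyncio.run(main(provider))  # wire up a concrete provider here
```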