llama-stack-mirror/llama_stack/apis/inference/inference.py
Update discriminator to have the correct mapping (#881)

See https://swagger.io/docs/specification/v3_0/data-models/inheritance-and-polymorphism/#discriminator

When specifying discriminators, a mapping must be specified unless the value of the
discriminator is the subtype itself (which in our case it is not).

The changes in the YAML are self-explanatory.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import (
    Any,
    AsyncIterator,
    Dict,
    List,
    Literal,
    Optional,
    Protocol,
    runtime_checkable,
    Union,
)

from llama_models.llama3.api.datatypes import (
    BuiltinTool,
    SamplingParams,
    StopReason,
    ToolCall,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated

from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.models import Model
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol


class LogProbConfig(BaseModel):
    top_k: Optional[int] = 0


@json_schema_type
class QuantizationType(Enum):
    bf16 = "bf16"
    fp8 = "fp8"
    int4 = "int4"


@json_schema_type
class Fp8QuantizationConfig(BaseModel):
    type: Literal["fp8"] = "fp8"


@json_schema_type
class Bf16QuantizationConfig(BaseModel):
    type: Literal["bf16"] = "bf16"


@json_schema_type
class Int4QuantizationConfig(BaseModel):
    type: Literal["int4"] = "int4"
    scheme: Optional[str] = "int4_weight_int8_dynamic_activation"


QuantizationConfig = Annotated[
    Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
    Field(discriminator="type"),
]
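
# Illustrative only (not part of this module): pydantic v2 resolves this
# discriminated union via the "type" field, e.g.
#
#   from pydantic import TypeAdapter
#
#   cfg = TypeAdapter(QuantizationConfig).validate_python({"type": "fp8"})
#   assert isinstance(cfg, Fp8QuantizationConfig)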


@json_schema_type
class UserMessage(BaseModel):
    role: Literal["user"] = "user"
    content: InterleavedContent
    context: Optional[InterleavedContent] = None


@json_schema_type
class SystemMessage(BaseModel):
    role: Literal["system"] = "system"
    content: InterleavedContent


@json_schema_type
class ToolResponseMessage(BaseModel):
    role: Literal["tool"] = "tool"
    # It would have been nice to re-use the ToolResponse type here, but giving every
    # message a uniform `content` field keeps the Message union simpler.
    call_id: str
    tool_name: Union[BuiltinTool, str]
    content: InterleavedContent


@json_schema_type
class CompletionMessage(BaseModel):
    role: Literal["assistant"] = "assistant"
    content: InterleavedContent
    stop_reason: StopReason
    tool_calls: List[ToolCall] = Field(default_factory=list)


Message = register_schema(
    Annotated[
        Union[
            UserMessage,
            SystemMessage,
            ToolResponseMessage,
            CompletionMessage,
        ],
        Field(discriminator="role"),
    ],
    name="Message",
)
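
# Illustrative only: since "role" is the discriminator, a plain dict validates to the
# matching subclass of the Message union (assuming pydantic >= 2):
#
#   from pydantic import TypeAdapter
#
#   msg = TypeAdapter(Message).validate_python({"role": "user", "content": "hi"})
#   assert isinstance(msg, UserMessage)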


@json_schema_type
class ToolResponse(BaseModel):
    call_id: str
    tool_name: Union[BuiltinTool, str]
    content: InterleavedContent

    @field_validator("tool_name", mode="before")
    @classmethod
    def validate_field(cls, v):
        if isinstance(v, str):
            try:
                return BuiltinTool(v)
            except ValueError:
                return v
        return v
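
# Illustrative examples of the coercion performed by the validator above (the
# builtin name assumes BuiltinTool defines `code_interpreter`):
#
#   ToolResponse(call_id="1", tool_name="code_interpreter", content="done")
#   # -> tool_name == BuiltinTool.code_interpreter
#   ToolResponse(call_id="2", tool_name="my_custom_tool", content="done")
#   # -> tool_name stays the plain string "my_custom_tool"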


@json_schema_type
class ToolChoice(Enum):
    auto = "auto"
    required = "required"


@json_schema_type
class TokenLogProbs(BaseModel):
    logprobs_by_token: Dict[str, float]


@json_schema_type
class ChatCompletionResponseEventType(Enum):
    start = "start"
    complete = "complete"
    progress = "progress"


@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
    """Chat completion response event."""

    event_type: ChatCompletionResponseEventType
    delta: ContentDelta
    logprobs: Optional[List[TokenLogProbs]] = None
    stop_reason: Optional[StopReason] = None


@json_schema_type
class ResponseFormatType(Enum):
    json_schema = "json_schema"
    grammar = "grammar"


@json_schema_type
class JsonSchemaResponseFormat(BaseModel):
    type: Literal[ResponseFormatType.json_schema.value] = (
        ResponseFormatType.json_schema.value
    )
    json_schema: Dict[str, Any]


@json_schema_type
class GrammarResponseFormat(BaseModel):
    type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
    bnf: Dict[str, Any]


ResponseFormat = register_schema(
    Annotated[
        Union[JsonSchemaResponseFormat, GrammarResponseFormat],
        Field(discriminator="type"),
    ],
    name="ResponseFormat",
)
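
# Illustrative only: the two ways to constrain decoding, selected by "type":
#
#   fmt = JsonSchemaResponseFormat(
#       json_schema={"type": "object", "properties": {"name": {"type": "string"}}}
#   )
#   # GrammarResponseFormat(bnf={...}) is the grammar-constrained counterpart.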


@json_schema_type
class CompletionRequest(BaseModel):
    model: str
    content: InterleavedContent
    sampling_params: Optional[SamplingParams] = SamplingParams()
    response_format: Optional[ResponseFormat] = None
    stream: Optional[bool] = False
    logprobs: Optional[LogProbConfig] = None


@json_schema_type
class CompletionResponse(BaseModel):
    """Completion response."""

    content: str
    stop_reason: StopReason
    logprobs: Optional[List[TokenLogProbs]] = None


@json_schema_type
class CompletionResponseStreamChunk(BaseModel):
    """A chunk of a streamed completion response."""

    delta: str
    stop_reason: Optional[StopReason] = None
    logprobs: Optional[List[TokenLogProbs]] = None


@json_schema_type
class BatchCompletionRequest(BaseModel):
    model: str
    content_batch: List[InterleavedContent]
    sampling_params: Optional[SamplingParams] = SamplingParams()
    response_format: Optional[ResponseFormat] = None
    logprobs: Optional[LogProbConfig] = None


@json_schema_type
class BatchCompletionResponse(BaseModel):
    """Batch completion response."""

    batch: List[CompletionResponse]


@json_schema_type
class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    sampling_params: Optional[SamplingParams] = SamplingParams()

    # zero-shot tool definitions as input to the model
    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)

    response_format: Optional[ResponseFormat] = None
    stream: Optional[bool] = False
    logprobs: Optional[LogProbConfig] = None


@json_schema_type
class ChatCompletionResponseStreamChunk(BaseModel):
    """A chunk of a streamed (SSE) chat completion response."""

    event: ChatCompletionResponseEvent


@json_schema_type
class ChatCompletionResponse(BaseModel):
    """Chat completion response."""

    completion_message: CompletionMessage
    logprobs: Optional[List[TokenLogProbs]] = None


@json_schema_type
class BatchChatCompletionRequest(BaseModel):
    model: str
    messages_batch: List[List[Message]]
    sampling_params: Optional[SamplingParams] = SamplingParams()

    # zero-shot tool definitions as input to the model
    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)

    logprobs: Optional[LogProbConfig] = None


@json_schema_type
class BatchChatCompletionResponse(BaseModel):
    batch: List[ChatCompletionResponse]


@json_schema_type
class EmbeddingsResponse(BaseModel):
    embeddings: List[List[float]]


class ModelStore(Protocol):
    def get_model(self, identifier: str) -> Model: ...


@runtime_checkable
@trace_protocol
class Inference(Protocol):
    model_store: ModelStore

    @webmethod(route="/inference/completion", method="POST")
    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...

    @webmethod(route="/inference/chat-completion", method="POST")
    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        # zero-shot tool definitions as input to the model
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]: ...

    @webmethod(route="/inference/embeddings", method="POST")
    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse: ...
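
# Illustrative only (`inference_impl` is an assumed provider implementing this
# protocol, not something defined here): a non-streaming call looks like
#
#   response = await inference_impl.chat_completion(
#       model_id="my-registered-model",
#       messages=[UserMessage(content="Hello!")],
#       stream=False,
#   )
#   print(response.completion_message.content)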