# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import (
    Any,
    AsyncIterator,
    Dict,
    List,
    Literal,
    Optional,
    Protocol,
    runtime_checkable,
    Union,
)

from llama_models.llama3.api.datatypes import (
    BuiltinTool,
    SamplingParams,
    StopReason,
    ToolCall,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.models import Model
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol


class LogProbConfig(BaseModel):
    top_k: Optional[int] = 0


@json_schema_type
class QuantizationType(Enum):
    bf16 = "bf16"
    fp8 = "fp8"
    int4 = "int4"


@json_schema_type
class Fp8QuantizationConfig(BaseModel):
    type: Literal["fp8"] = "fp8"


@json_schema_type
class Bf16QuantizationConfig(BaseModel):
    type: Literal["bf16"] = "bf16"


@json_schema_type
class Int4QuantizationConfig(BaseModel):
    type: Literal["int4"] = "int4"
    scheme: Optional[str] = "int4_weight_int8_dynamic_activation"


QuantizationConfig = Annotated[
    Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
    Field(discriminator="type"),
]


@json_schema_type
class UserMessage(BaseModel):
    role: Literal["user"] = "user"
    content: InterleavedContent
    context: Optional[InterleavedContent] = None


@json_schema_type
class SystemMessage(BaseModel):
    role: Literal["system"] = "system"
    content: InterleavedContent


@json_schema_type
class ToolResponseMessage(BaseModel):
    role: Literal["ipython"] = "ipython"
    # it was nice to re-use the ToolResponse type, but having all messages
    # have a `content` type makes things nicer too
    call_id: str
    tool_name: Union[BuiltinTool, str]
    content: InterleavedContent


@json_schema_type
class CompletionMessage(BaseModel):
    role: Literal["assistant"] = "assistant"
    content: InterleavedContent
    stop_reason: StopReason
    tool_calls: List[ToolCall] = Field(default_factory=list)


Message = register_schema(
    Annotated[
        Union[
            UserMessage,
            SystemMessage,
            ToolResponseMessage,
            CompletionMessage,
        ],
        Field(discriminator="role"),
    ],
    name="Message",
)


@json_schema_type
class ToolResponse(BaseModel):
    call_id: str
    tool_name: Union[BuiltinTool, str]
    content: InterleavedContent

    @field_validator("tool_name", mode="before")
    @classmethod
    def validate_field(cls, v):
        if isinstance(v, str):
            try:
                return BuiltinTool(v)
            except ValueError:
                return v
        return v


@json_schema_type
class ToolChoice(Enum):
    auto = "auto"
    required = "required"


@json_schema_type
class TokenLogProbs(BaseModel):
    logprobs_by_token: Dict[str, float]


@json_schema_type
class ChatCompletionResponseEventType(Enum):
    start = "start"
    complete = "complete"
    progress = "progress"


@json_schema_type
class ToolCallParseStatus(Enum):
    started = "started"
    in_progress = "in_progress"
    failure = "failure"
    success = "success"


@json_schema_type
class ToolCallDelta(BaseModel):
    content: Union[str, ToolCall]
    parse_status: ToolCallParseStatus


@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
    """Chat completion response event."""

    event_type: ChatCompletionResponseEventType
    delta: Union[str, ToolCallDelta]
    logprobs: Optional[List[TokenLogProbs]] = None
    stop_reason: Optional[StopReason] = None


class ResponseFormatType(Enum):
    json_schema = "json_schema"
    grammar = "grammar"
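

# Usage sketch (illustrative only, not part of the API surface): the `Message`
# union discriminates on `role`, so a plain dict validates to the matching
# concrete message class. This assumes `register_schema` returns the annotated
# union unchanged and that `InterleavedContent` accepts a plain string;
# `TypeAdapter` is standard pydantic v2. The payload is a made-up example.
def _example_parse_message() -> None:
    from pydantic import TypeAdapter

    msg = TypeAdapter(Message).validate_python({"role": "user", "content": "hello"})
    # the "user" discriminator selects UserMessage from the union
    assert isinstance(msg, UserMessage)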
"grammar" class JsonSchemaResponseFormat(BaseModel): type: Literal[ResponseFormatType.json_schema.value] = ( ResponseFormatType.json_schema.value ) json_schema: Dict[str, Any] class GrammarResponseFormat(BaseModel): type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value bnf: Dict[str, Any] ResponseFormat = register_schema( Annotated[ Union[JsonSchemaResponseFormat, GrammarResponseFormat], Field(discriminator="type"), ], name="ResponseFormat", ) @json_schema_type class CompletionRequest(BaseModel): model: str content: InterleavedContent sampling_params: Optional[SamplingParams] = SamplingParams() response_format: Optional[ResponseFormat] = None stream: Optional[bool] = False logprobs: Optional[LogProbConfig] = None @json_schema_type class CompletionResponse(BaseModel): """Completion response.""" content: str stop_reason: StopReason logprobs: Optional[List[TokenLogProbs]] = None @json_schema_type class CompletionResponseStreamChunk(BaseModel): """streamed completion response.""" delta: str stop_reason: Optional[StopReason] = None logprobs: Optional[List[TokenLogProbs]] = None @json_schema_type class BatchCompletionRequest(BaseModel): model: str content_batch: List[InterleavedContent] sampling_params: Optional[SamplingParams] = SamplingParams() response_format: Optional[ResponseFormat] = None logprobs: Optional[LogProbConfig] = None @json_schema_type class BatchCompletionResponse(BaseModel): """Batch completion response.""" batch: List[CompletionResponse] @json_schema_type class ChatCompletionRequest(BaseModel): model: str messages: List[Message] sampling_params: Optional[SamplingParams] = SamplingParams() # zero-shot tool definitions as input to the model tools: Optional[List[ToolDefinition]] = Field(default_factory=list) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json ) response_format: Optional[ResponseFormat] = None stream: Optional[bool] = False logprobs: Optional[LogProbConfig] = None @json_schema_type class ChatCompletionResponseStreamChunk(BaseModel): """SSE-stream of these events.""" event: ChatCompletionResponseEvent @json_schema_type class ChatCompletionResponse(BaseModel): """Chat completion response.""" completion_message: CompletionMessage logprobs: Optional[List[TokenLogProbs]] = None @json_schema_type class BatchChatCompletionRequest(BaseModel): model: str messages_batch: List[List[Message]] sampling_params: Optional[SamplingParams] = SamplingParams() # zero-shot tool definitions as input to the model tools: Optional[List[ToolDefinition]] = Field(default_factory=list) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json ) logprobs: Optional[LogProbConfig] = None @json_schema_type class BatchChatCompletionResponse(BaseModel): batch: List[ChatCompletionResponse] @json_schema_type class EmbeddingsResponse(BaseModel): embeddings: List[List[float]] class ModelStore(Protocol): def get_model(self, identifier: str) -> Model: ... 


@runtime_checkable
@trace_protocol
class Inference(Protocol):
    model_store: ModelStore

    @webmethod(route="/inference/completion")
    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...

    @webmethod(route="/inference/chat-completion")
    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        # zero-shot tool definitions as input to the model
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]: ...

    @webmethod(route="/inference/embeddings")
    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse: ...
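

# Consumption sketch, assuming an implementation returns an async iterator
# when `stream=True` (per the return annotation above). `client` is any object
# satisfying the `Inference` protocol; the model id is made up.
async def _example_stream_chat(client: Inference) -> str:
    chunks: List[str] = []
    response = await client.chat_completion(
        model_id="example-model",  # hypothetical identifier
        messages=[UserMessage(content="Say hi.")],
        stream=True,
    )
    # narrow the Union by assumption: streaming requests yield chunk events
    async for chunk in response:  # type: ignore[union-attr]
        delta = chunk.event.delta
        if isinstance(delta, str):  # tool-call deltas arrive as ToolCallDelta
            chunks.append(delta)
    return "".join(chunks)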