forked from phoenix-oss/llama-stack-mirror
chore: enable pyupgrade fixes (#1806)
# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth of note can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py` where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
ffe3d0b2cd
commit
9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
|
@ -4,21 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.models import Model
|
||||
|
@ -47,8 +44,8 @@ class GreedySamplingStrategy(BaseModel):
|
|||
@json_schema_type
|
||||
class TopPSamplingStrategy(BaseModel):
|
||||
type: Literal["top_p"] = "top_p"
|
||||
temperature: Optional[float] = Field(..., gt=0.0)
|
||||
top_p: Optional[float] = 0.95
|
||||
temperature: float | None = Field(..., gt=0.0)
|
||||
top_p: float | None = 0.95
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -58,7 +55,7 @@ class TopKSamplingStrategy(BaseModel):
|
|||
|
||||
|
||||
SamplingStrategy = Annotated[
|
||||
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
|
||||
GreedySamplingStrategy | TopPSamplingStrategy | TopKSamplingStrategy,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(SamplingStrategy, name="SamplingStrategy")
|
||||
|
@ -79,9 +76,9 @@ class SamplingParams(BaseModel):
|
|||
|
||||
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
|
||||
|
||||
max_tokens: Optional[int] = 0
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
stop: Optional[List[str]] = None
|
||||
max_tokens: int | None = 0
|
||||
repetition_penalty: float | None = 1.0
|
||||
stop: list[str] | None = None
|
||||
|
||||
|
||||
class LogProbConfig(BaseModel):
|
||||
|
@ -90,7 +87,7 @@ class LogProbConfig(BaseModel):
|
|||
:param top_k: How many tokens (for each position) to return log probabilities for.
|
||||
"""
|
||||
|
||||
top_k: Optional[int] = 0
|
||||
top_k: int | None = 0
|
||||
|
||||
|
||||
class QuantizationType(Enum):
|
||||
|
@ -125,11 +122,11 @@ class Int4QuantizationConfig(BaseModel):
|
|||
"""
|
||||
|
||||
type: Literal["int4_mixed"] = "int4_mixed"
|
||||
scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
|
||||
scheme: str | None = "int4_weight_int8_dynamic_activation"
|
||||
|
||||
|
||||
QuantizationConfig = Annotated[
|
||||
Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
|
||||
Bf16QuantizationConfig | Fp8QuantizationConfig | Int4QuantizationConfig,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
@ -145,7 +142,7 @@ class UserMessage(BaseModel):
|
|||
|
||||
role: Literal["user"] = "user"
|
||||
content: InterleavedContent
|
||||
context: Optional[InterleavedContent] = None
|
||||
context: InterleavedContent | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -190,16 +187,11 @@ class CompletionMessage(BaseModel):
|
|||
role: Literal["assistant"] = "assistant"
|
||||
content: InterleavedContent
|
||||
stop_reason: StopReason
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
|
||||
tool_calls: list[ToolCall] | None = Field(default_factory=list)
|
||||
|
||||
|
||||
Message = Annotated[
|
||||
Union[
|
||||
UserMessage,
|
||||
SystemMessage,
|
||||
ToolResponseMessage,
|
||||
CompletionMessage,
|
||||
],
|
||||
UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
|
||||
Field(discriminator="role"),
|
||||
]
|
||||
register_schema(Message, name="Message")
|
||||
|
@ -208,9 +200,9 @@ register_schema(Message, name="Message")
|
|||
@json_schema_type
|
||||
class ToolResponse(BaseModel):
|
||||
call_id: str
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
tool_name: BuiltinTool | str
|
||||
content: InterleavedContent
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
|
@ -243,7 +235,7 @@ class TokenLogProbs(BaseModel):
|
|||
:param logprobs_by_token: Dictionary mapping tokens to their log probabilities
|
||||
"""
|
||||
|
||||
logprobs_by_token: Dict[str, float]
|
||||
logprobs_by_token: dict[str, float]
|
||||
|
||||
|
||||
class ChatCompletionResponseEventType(Enum):
|
||||
|
@ -271,8 +263,8 @@ class ChatCompletionResponseEvent(BaseModel):
|
|||
|
||||
event_type: ChatCompletionResponseEventType
|
||||
delta: ContentDelta
|
||||
logprobs: Optional[List[TokenLogProbs]] = None
|
||||
stop_reason: Optional[StopReason] = None
|
||||
logprobs: list[TokenLogProbs] | None = None
|
||||
stop_reason: StopReason | None = None
|
||||
|
||||
|
||||
class ResponseFormatType(Enum):
|
||||
|
@ -295,7 +287,7 @@ class JsonSchemaResponseFormat(BaseModel):
|
|||
"""
|
||||
|
||||
type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
|
||||
json_schema: Dict[str, Any]
|
||||
json_schema: dict[str, Any]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -307,11 +299,11 @@ class GrammarResponseFormat(BaseModel):
|
|||
"""
|
||||
|
||||
type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
|
||||
bnf: Dict[str, Any]
|
||||
bnf: dict[str, Any]
|
||||
|
||||
|
||||
ResponseFormat = Annotated[
|
||||
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
|
||||
JsonSchemaResponseFormat | GrammarResponseFormat,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ResponseFormat, name="ResponseFormat")
|
||||
|
@ -321,10 +313,10 @@ register_schema(ResponseFormat, name="ResponseFormat")
|
|||
class CompletionRequest(BaseModel):
|
||||
model: str
|
||||
content: InterleavedContent
|
||||
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
stream: Optional[bool] = False
|
||||
logprobs: Optional[LogProbConfig] = None
|
||||
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
|
||||
response_format: ResponseFormat | None = None
|
||||
stream: bool | None = False
|
||||
logprobs: LogProbConfig | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -338,7 +330,7 @@ class CompletionResponse(MetricResponseMixin):
|
|||
|
||||
content: str
|
||||
stop_reason: StopReason
|
||||
logprobs: Optional[List[TokenLogProbs]] = None
|
||||
logprobs: list[TokenLogProbs] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -351,8 +343,8 @@ class CompletionResponseStreamChunk(MetricResponseMixin):
|
|||
"""
|
||||
|
||||
delta: str
|
||||
stop_reason: Optional[StopReason] = None
|
||||
logprobs: Optional[List[TokenLogProbs]] = None
|
||||
stop_reason: StopReason | None = None
|
||||
logprobs: list[TokenLogProbs] | None = None
|
||||
|
||||
|
||||
class SystemMessageBehavior(Enum):
|
||||
|
@ -383,9 +375,9 @@ class ToolConfig(BaseModel):
|
|||
'{{function_definitions}}' to indicate where the function definitions should be inserted.
|
||||
"""
|
||||
|
||||
tool_choice: Optional[ToolChoice | str] = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
|
||||
system_message_behavior: Optional[SystemMessageBehavior] = Field(default=SystemMessageBehavior.append)
|
||||
tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: ToolPromptFormat | None = Field(default=None)
|
||||
system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if isinstance(self.tool_choice, str):
|
||||
|
@ -399,15 +391,15 @@ class ToolConfig(BaseModel):
|
|||
@json_schema_type
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[Message]
|
||||
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
|
||||
messages: list[Message]
|
||||
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
|
||||
|
||||
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
|
||||
tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
|
||||
tools: list[ToolDefinition] | None = Field(default_factory=list)
|
||||
tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
|
||||
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
stream: Optional[bool] = False
|
||||
logprobs: Optional[LogProbConfig] = None
|
||||
response_format: ResponseFormat | None = None
|
||||
stream: bool | None = False
|
||||
logprobs: LogProbConfig | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -429,7 +421,7 @@ class ChatCompletionResponse(MetricResponseMixin):
|
|||
"""
|
||||
|
||||
completion_message: CompletionMessage
|
||||
logprobs: Optional[List[TokenLogProbs]] = None
|
||||
logprobs: list[TokenLogProbs] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -439,7 +431,7 @@ class EmbeddingsResponse(BaseModel):
|
|||
:param embeddings: List of embedding vectors, one per input content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
|
||||
"""
|
||||
|
||||
embeddings: List[List[float]]
|
||||
embeddings: list[list[float]]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -451,7 +443,7 @@ class OpenAIChatCompletionContentPartTextParam(BaseModel):
|
|||
@json_schema_type
|
||||
class OpenAIImageURL(BaseModel):
|
||||
url: str
|
||||
detail: Optional[str] = None
|
||||
detail: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -461,16 +453,13 @@ class OpenAIChatCompletionContentPartImageParam(BaseModel):
|
|||
|
||||
|
||||
OpenAIChatCompletionContentPartParam = Annotated[
|
||||
Union[
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
],
|
||||
OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
|
||||
|
||||
|
||||
OpenAIChatCompletionMessageContent = Union[str, List[OpenAIChatCompletionContentPartParam]]
|
||||
OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -484,7 +473,7 @@ class OpenAIUserMessageParam(BaseModel):
|
|||
|
||||
role: Literal["user"] = "user"
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
name: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -498,21 +487,21 @@ class OpenAISystemMessageParam(BaseModel):
|
|||
|
||||
role: Literal["system"] = "system"
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
name: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionToolCallFunction(BaseModel):
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
name: str | None = None
|
||||
arguments: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionToolCall(BaseModel):
|
||||
index: Optional[int] = None
|
||||
id: Optional[str] = None
|
||||
index: int | None = None
|
||||
id: str | None = None
|
||||
type: Literal["function"] = "function"
|
||||
function: Optional[OpenAIChatCompletionToolCallFunction] = None
|
||||
function: OpenAIChatCompletionToolCallFunction | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -526,9 +515,9 @@ class OpenAIAssistantMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: Optional[OpenAIChatCompletionMessageContent] = None
|
||||
name: Optional[str] = None
|
||||
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
|
||||
content: OpenAIChatCompletionMessageContent | None = None
|
||||
name: str | None = None
|
||||
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -556,17 +545,15 @@ class OpenAIDeveloperMessageParam(BaseModel):
|
|||
|
||||
role: Literal["developer"] = "developer"
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
name: str | None = None
|
||||
|
||||
|
||||
OpenAIMessageParam = Annotated[
|
||||
Union[
|
||||
OpenAIUserMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIDeveloperMessageParam,
|
||||
],
|
||||
OpenAIUserMessageParam
|
||||
| OpenAISystemMessageParam
|
||||
| OpenAIAssistantMessageParam
|
||||
| OpenAIToolMessageParam
|
||||
| OpenAIDeveloperMessageParam,
|
||||
Field(discriminator="role"),
|
||||
]
|
||||
register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
|
||||
|
@ -580,14 +567,14 @@ class OpenAIResponseFormatText(BaseModel):
|
|||
@json_schema_type
|
||||
class OpenAIJSONSchema(TypedDict, total=False):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
strict: Optional[bool] = None
|
||||
description: str | None = None
|
||||
strict: bool | None = None
|
||||
|
||||
# Pydantic BaseModel cannot be used with a schema param, since it already
|
||||
# has one. And, we don't want to alias here because then have to handle
|
||||
# that alias when converting to OpenAI params. So, to support schema,
|
||||
# we use a TypedDict.
|
||||
schema: Optional[Dict[str, Any]] = None
|
||||
schema: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -602,11 +589,7 @@ class OpenAIResponseFormatJSONObject(BaseModel):
|
|||
|
||||
|
||||
OpenAIResponseFormatParam = Annotated[
|
||||
Union[
|
||||
OpenAIResponseFormatText,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
],
|
||||
OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
|
||||
|
@ -622,7 +605,7 @@ class OpenAITopLogProb(BaseModel):
|
|||
"""
|
||||
|
||||
token: str
|
||||
bytes: Optional[List[int]] = None
|
||||
bytes: list[int] | None = None
|
||||
logprob: float
|
||||
|
||||
|
||||
|
@ -637,9 +620,9 @@ class OpenAITokenLogProb(BaseModel):
|
|||
"""
|
||||
|
||||
token: str
|
||||
bytes: Optional[List[int]] = None
|
||||
bytes: list[int] | None = None
|
||||
logprob: float
|
||||
top_logprobs: List[OpenAITopLogProb]
|
||||
top_logprobs: list[OpenAITopLogProb]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -650,8 +633,8 @@ class OpenAIChoiceLogprobs(BaseModel):
|
|||
:param refusal: (Optional) The log probabilities for the tokens in the message
|
||||
"""
|
||||
|
||||
content: Optional[List[OpenAITokenLogProb]] = None
|
||||
refusal: Optional[List[OpenAITokenLogProb]] = None
|
||||
content: list[OpenAITokenLogProb] | None = None
|
||||
refusal: list[OpenAITokenLogProb] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -664,10 +647,10 @@ class OpenAIChoiceDelta(BaseModel):
|
|||
:param tool_calls: (Optional) The tool calls of the delta
|
||||
"""
|
||||
|
||||
content: Optional[str] = None
|
||||
refusal: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
|
||||
content: str | None = None
|
||||
refusal: str | None = None
|
||||
role: str | None = None
|
||||
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -683,7 +666,7 @@ class OpenAIChunkChoice(BaseModel):
|
|||
delta: OpenAIChoiceDelta
|
||||
finish_reason: str
|
||||
index: int
|
||||
logprobs: Optional[OpenAIChoiceLogprobs] = None
|
||||
logprobs: OpenAIChoiceLogprobs | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -699,7 +682,7 @@ class OpenAIChoice(BaseModel):
|
|||
message: OpenAIMessageParam
|
||||
finish_reason: str
|
||||
index: int
|
||||
logprobs: Optional[OpenAIChoiceLogprobs] = None
|
||||
logprobs: OpenAIChoiceLogprobs | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -714,7 +697,7 @@ class OpenAIChatCompletion(BaseModel):
|
|||
"""
|
||||
|
||||
id: str
|
||||
choices: List[OpenAIChoice]
|
||||
choices: list[OpenAIChoice]
|
||||
object: Literal["chat.completion"] = "chat.completion"
|
||||
created: int
|
||||
model: str
|
||||
|
@ -732,7 +715,7 @@ class OpenAIChatCompletionChunk(BaseModel):
|
|||
"""
|
||||
|
||||
id: str
|
||||
choices: List[OpenAIChunkChoice]
|
||||
choices: list[OpenAIChunkChoice]
|
||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
created: int
|
||||
model: str
|
||||
|
@ -748,10 +731,10 @@ class OpenAICompletionLogprobs(BaseModel):
|
|||
:top_logprobs: (Optional) The top log probabilities for the tokens
|
||||
"""
|
||||
|
||||
text_offset: Optional[List[int]] = None
|
||||
token_logprobs: Optional[List[float]] = None
|
||||
tokens: Optional[List[str]] = None
|
||||
top_logprobs: Optional[List[Dict[str, float]]] = None
|
||||
text_offset: list[int] | None = None
|
||||
token_logprobs: list[float] | None = None
|
||||
tokens: list[str] | None = None
|
||||
top_logprobs: list[dict[str, float]] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -767,7 +750,7 @@ class OpenAICompletionChoice(BaseModel):
|
|||
finish_reason: str
|
||||
text: str
|
||||
index: int
|
||||
logprobs: Optional[OpenAIChoiceLogprobs] = None
|
||||
logprobs: OpenAIChoiceLogprobs | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -782,7 +765,7 @@ class OpenAICompletion(BaseModel):
|
|||
"""
|
||||
|
||||
id: str
|
||||
choices: List[OpenAICompletionChoice]
|
||||
choices: list[OpenAICompletionChoice]
|
||||
created: int
|
||||
model: str
|
||||
object: Literal["text_completion"] = "text_completion"
|
||||
|
@ -818,12 +801,12 @@ class EmbeddingTaskType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class BatchCompletionResponse(BaseModel):
|
||||
batch: List[CompletionResponse]
|
||||
batch: list[CompletionResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BatchChatCompletionResponse(BaseModel):
|
||||
batch: List[ChatCompletionResponse]
|
||||
batch: list[ChatCompletionResponse]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
@ -843,11 +826,11 @@ class Inference(Protocol):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
|
||||
"""Generate a completion for the given content using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
|
@ -865,10 +848,10 @@ class Inference(Protocol):
|
|||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: List[InterleavedContent],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchCompletionResponse:
|
||||
raise NotImplementedError("Batch completion is not implemented")
|
||||
|
||||
|
@ -876,16 +859,16 @@ class Inference(Protocol):
|
|||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
"""Generate a chat completion for the given messages using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
|
@ -916,12 +899,12 @@ class Inference(Protocol):
|
|||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: List[List[Message]],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
messages_batch: list[list[Message]],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchChatCompletionResponse:
|
||||
raise NotImplementedError("Batch chat completion is not implemented")
|
||||
|
||||
|
@ -929,10 +912,10 @@ class Inference(Protocol):
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
"""Generate embeddings for content pieces using the specified model.
|
||||
|
||||
|
@ -950,25 +933,25 @@ class Inference(Protocol):
|
|||
self,
|
||||
# Standard OpenAI completion parameters
|
||||
model: str,
|
||||
prompt: Union[str, List[str], List[int], List[List[int]]],
|
||||
best_of: Optional[int] = None,
|
||||
echo: Optional[bool] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
# vLLM-specific parameters
|
||||
guided_choice: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
) -> OpenAICompletion:
|
||||
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
|
||||
|
||||
|
@ -996,29 +979,29 @@ class Inference(Protocol):
|
|||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[OpenAIMessageParam],
|
||||
frequency_penalty: Optional[float] = None,
|
||||
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
functions: Optional[List[Dict[str, Any]]] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue