Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
chore: enable pyupgrade fixes (#1806)
# What does this PR do?

The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for the latest Python releases.)

Note to reviewers: almost all changes here were automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth noting can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where reflection code was updated to deal with the "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in: parent ffe3d0b2cd, commit 9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
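The rewrites throughout the diff below follow two mechanical patterns: PEP 585 builtin generics (`List[str]` → `list[str]`, `Dict[str, Any]` → `dict[str, Any]`) and PEP 604 union syntax (`Optional[X]` → `X | None`, `Union[X, Y]` → `X | Y`). A minimal before/after sketch of what pyupgrade does to a typical model in this codebase (`TurnSummary` is a hypothetical example, not a file in this PR):

```python
# Before: Python 3.8-era annotations, as on the left side of the diff below.
from typing import Dict, List, Optional, Union

from pydantic import BaseModel


class TurnSummary(BaseModel):  # hypothetical example model
    turn_id: str
    steps: List[str]
    metadata: Optional[Dict[str, str]] = None
    output: Union[str, List[str]]


# After pyupgrade: PEP 585 builtin generics, PEP 604 union syntax.
class TurnSummary(BaseModel):  # same model, modernized annotations
    turn_id: str
    steps: list[str]
    metadata: dict[str, str] | None = None
    output: str | list[str]
```

The schema reflection adjustment mentioned in the description is needed because the two union spellings have different runtime origins. A minimal sketch of the distinction, assuming the actual fix in `llama_stack/strong_typing/schema.py` differs in detail:

```python
# PEP 604 unions (X | Y) evaluate to types.UnionType, while Union[X, Y]
# keeps typing.Union as its origin; reflection must accept both spellings.
import types
import typing


def is_union(tp: object) -> bool:
    origin = typing.get_origin(tp)
    return origin is typing.Union or origin is types.UnionType


assert is_union(typing.Optional[int])  # origin: typing.Union
assert is_union(int | None)            # origin: types.UnionType
```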
@@ -4,20 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from collections.abc import AsyncIterator
 from datetime import datetime
 from enum import Enum
-from typing import (
-    Annotated,
-    Any,
-    AsyncIterator,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Protocol,
-    Union,
-    runtime_checkable,
-)
+from typing import Annotated, Any, Literal, Protocol, runtime_checkable
 
 from pydantic import BaseModel, ConfigDict, Field
 
@@ -79,8 +69,8 @@ class StepCommon(BaseModel):
 
     turn_id: str
     step_id: str
-    started_at: Optional[datetime] = None
-    completed_at: Optional[datetime] = None
+    started_at: datetime | None = None
+    completed_at: datetime | None = None
 
 
 class StepType(Enum):
@@ -120,8 +110,8 @@ class ToolExecutionStep(StepCommon):
     """
 
     step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
-    tool_calls: List[ToolCall]
-    tool_responses: List[ToolResponse]
+    tool_calls: list[ToolCall]
+    tool_responses: list[ToolResponse]
 
 
 @json_schema_type
@@ -132,7 +122,7 @@ class ShieldCallStep(StepCommon):
     """
 
     step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
-    violation: Optional[SafetyViolation]
+    violation: SafetyViolation | None
 
 
 @json_schema_type
@@ -150,12 +140,7 @@ class MemoryRetrievalStep(StepCommon):
 
 
 Step = Annotated[
-    Union[
-        InferenceStep,
-        ToolExecutionStep,
-        ShieldCallStep,
-        MemoryRetrievalStep,
-    ],
+    InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
     Field(discriminator="step_type"),
 ]
 
@@ -166,18 +151,13 @@ class Turn(BaseModel):
 
     turn_id: str
     session_id: str
-    input_messages: List[
-        Union[
-            UserMessage,
-            ToolResponseMessage,
-        ]
-    ]
-    steps: List[Step]
+    input_messages: list[UserMessage | ToolResponseMessage]
+    steps: list[Step]
     output_message: CompletionMessage
-    output_attachments: Optional[List[Attachment]] = Field(default_factory=list)
+    output_attachments: list[Attachment] | None = Field(default_factory=list)
 
     started_at: datetime
-    completed_at: Optional[datetime] = None
+    completed_at: datetime | None = None
 
 
 @json_schema_type
@@ -186,34 +166,31 @@ class Session(BaseModel):
 
     session_id: str
     session_name: str
-    turns: List[Turn]
+    turns: list[Turn]
     started_at: datetime
 
 
 class AgentToolGroupWithArgs(BaseModel):
     name: str
-    args: Dict[str, Any]
+    args: dict[str, Any]
 
 
-AgentToolGroup = Union[
-    str,
-    AgentToolGroupWithArgs,
-]
+AgentToolGroup = str | AgentToolGroupWithArgs
 register_schema(AgentToolGroup, name="AgentTool")
 
 
 class AgentConfigCommon(BaseModel):
-    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
+    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
 
-    input_shields: Optional[List[str]] = Field(default_factory=list)
-    output_shields: Optional[List[str]] = Field(default_factory=list)
-    toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
-    client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
-    tool_choice: Optional[ToolChoice] = Field(default=None, deprecated="use tool_config instead")
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
-    tool_config: Optional[ToolConfig] = Field(default=None)
+    input_shields: list[str] | None = Field(default_factory=list)
+    output_shields: list[str] | None = Field(default_factory=list)
+    toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
+    client_tools: list[ToolDef] | None = Field(default_factory=list)
+    tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
+    tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
+    tool_config: ToolConfig | None = Field(default=None)
 
-    max_infer_iters: Optional[int] = 10
+    max_infer_iters: int | None = 10
 
     def model_post_init(self, __context):
         if self.tool_config:
@@ -243,9 +220,9 @@ class AgentConfig(AgentConfigCommon):
 
     model: str
    instructions: str
-    name: Optional[str] = None
-    enable_session_persistence: Optional[bool] = False
-    response_format: Optional[ResponseFormat] = None
+    name: str | None = None
+    enable_session_persistence: bool | None = False
+    response_format: ResponseFormat | None = None
 
 
 @json_schema_type
@@ -257,16 +234,16 @@ class Agent(BaseModel):
 
 @json_schema_type
 class ListAgentsResponse(BaseModel):
-    data: List[Agent]
+    data: list[Agent]
 
 
 @json_schema_type
 class ListAgentSessionsResponse(BaseModel):
-    data: List[Session]
+    data: list[Session]
 
 
 class AgentConfigOverridablePerTurn(AgentConfigCommon):
-    instructions: Optional[str] = None
+    instructions: str | None = None
 
 
 class AgentTurnResponseEventType(Enum):
@@ -284,7 +261,7 @@ class AgentTurnResponseStepStartPayload(BaseModel):
     event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
     step_type: StepType
     step_id: str
-    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
+    metadata: dict[str, Any] | None = Field(default_factory=dict)
 
 
 @json_schema_type
@@ -327,14 +304,12 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
 
 
 AgentTurnResponseEventPayload = Annotated[
-    Union[
-        AgentTurnResponseStepStartPayload,
-        AgentTurnResponseStepProgressPayload,
-        AgentTurnResponseStepCompletePayload,
-        AgentTurnResponseTurnStartPayload,
-        AgentTurnResponseTurnCompletePayload,
-        AgentTurnResponseTurnAwaitingInputPayload,
-    ],
+    AgentTurnResponseStepStartPayload
+    | AgentTurnResponseStepProgressPayload
+    | AgentTurnResponseStepCompletePayload
+    | AgentTurnResponseTurnStartPayload
+    | AgentTurnResponseTurnCompletePayload
+    | AgentTurnResponseTurnAwaitingInputPayload,
     Field(discriminator="event_type"),
 ]
 register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
@@ -363,18 +338,13 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     # TODO: figure out how we can simplify this and make why
     # ToolResponseMessage needs to be here (it is function call
     # execution from outside the system)
-    messages: List[
-        Union[
-            UserMessage,
-            ToolResponseMessage,
-        ]
-    ]
+    messages: list[UserMessage | ToolResponseMessage]
 
-    documents: Optional[List[Document]] = None
-    toolgroups: Optional[List[AgentToolGroup]] = None
+    documents: list[Document] | None = None
+    toolgroups: list[AgentToolGroup] | None = None
 
-    stream: Optional[bool] = False
-    tool_config: Optional[ToolConfig] = None
+    stream: bool | None = False
+    tool_config: ToolConfig | None = None
 
 
 @json_schema_type
@@ -382,8 +352,8 @@ class AgentTurnResumeRequest(BaseModel):
     agent_id: str
     session_id: str
     turn_id: str
-    tool_responses: List[ToolResponse]
-    stream: Optional[bool] = False
+    tool_responses: list[ToolResponse]
+    stream: bool | None = False
 
 
 @json_schema_type
@@ -429,17 +399,12 @@ class Agents(Protocol):
         self,
         agent_id: str,
         session_id: str,
-        messages: List[
-            Union[
-                UserMessage,
-                ToolResponseMessage,
-            ]
-        ],
-        stream: Optional[bool] = False,
-        documents: Optional[List[Document]] = None,
-        toolgroups: Optional[List[AgentToolGroup]] = None,
-        tool_config: Optional[ToolConfig] = None,
-    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
+        messages: list[UserMessage | ToolResponseMessage],
+        stream: bool | None = False,
+        documents: list[Document] | None = None,
+        toolgroups: list[AgentToolGroup] | None = None,
+        tool_config: ToolConfig | None = None,
+    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
         """Create a new turn for an agent.
 
         :param agent_id: The ID of the agent to create the turn for.
@@ -463,9 +428,9 @@ class Agents(Protocol):
         agent_id: str,
         session_id: str,
         turn_id: str,
-        tool_responses: List[ToolResponse],
-        stream: Optional[bool] = False,
-    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
+        tool_responses: list[ToolResponse],
+        stream: bool | None = False,
+    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
         """Resume an agent turn with executed tool call responses.
 
         When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
@@ -538,7 +503,7 @@ class Agents(Protocol):
         self,
         session_id: str,
         agent_id: str,
-        turn_ids: Optional[List[str]] = None,
+        turn_ids: list[str] | None = None,
     ) -> Session:
         """Retrieve an agent session by its ID.
 
@@ -623,14 +588,14 @@ class Agents(Protocol):
     @webmethod(route="/openai/v1/responses", method="POST")
     async def create_openai_response(
         self,
-        input: Union[str, List[OpenAIResponseInputMessage]],
+        input: str | list[OpenAIResponseInputMessage],
         model: str,
-        previous_response_id: Optional[str] = None,
-        store: Optional[bool] = True,
-        stream: Optional[bool] = False,
-        temperature: Optional[float] = None,
-        tools: Optional[List[OpenAIResponseInputTool]] = None,
-    ) -> Union[OpenAIResponseObject, AsyncIterator[OpenAIResponseObjectStream]]:
+        previous_response_id: str | None = None,
+        store: bool | None = True,
+        stream: bool | None = False,
+        temperature: float | None = None,
+        tools: list[OpenAIResponseInputTool] | None = None,
+    ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
         """Create a new OpenAI response.
 
         :param input: Input message(s) to create the response.
@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List, Literal, Optional, Union
+from typing import Annotated, Literal
 
 from pydantic import BaseModel, Field
-from typing_extensions import Annotated
 
 from llama_stack.schema_utils import json_schema_type, register_schema
 
@@ -25,7 +24,7 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
 
 
 OpenAIResponseOutputMessageContent = Annotated[
-    Union[OpenAIResponseOutputMessageContentOutputText,],
+    OpenAIResponseOutputMessageContentOutputText,
    Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
@@ -34,7 +33,7 @@ register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMe
 @json_schema_type
 class OpenAIResponseOutputMessage(BaseModel):
     id: str
-    content: List[OpenAIResponseOutputMessageContent]
+    content: list[OpenAIResponseOutputMessageContent]
     role: Literal["assistant"] = "assistant"
     status: str
     type: Literal["message"] = "message"
@@ -48,10 +47,7 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
 
 
 OpenAIResponseOutput = Annotated[
-    Union[
-        OpenAIResponseOutputMessage,
-        OpenAIResponseOutputMessageWebSearchToolCall,
-    ],
+    OpenAIResponseOutputMessage | OpenAIResponseOutputMessageWebSearchToolCall,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
@@ -60,18 +56,18 @@ register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
 @json_schema_type
 class OpenAIResponseObject(BaseModel):
     created_at: int
-    error: Optional[OpenAIResponseError] = None
+    error: OpenAIResponseError | None = None
     id: str
     model: str
     object: Literal["response"] = "response"
-    output: List[OpenAIResponseOutput]
+    output: list[OpenAIResponseOutput]
     parallel_tool_calls: bool = False
-    previous_response_id: Optional[str] = None
+    previous_response_id: str | None = None
     status: str
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
-    truncation: Optional[str] = None
-    user: Optional[str] = None
+    temperature: float | None = None
+    top_p: float | None = None
+    truncation: str | None = None
+    user: str | None = None
 
 
 @json_schema_type
@@ -87,10 +83,7 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
 
 
 OpenAIResponseObjectStream = Annotated[
-    Union[
-        OpenAIResponseObjectStreamResponseCreated,
-        OpenAIResponseObjectStreamResponseCompleted,
-    ],
+    OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
@@ -107,12 +100,12 @@ class OpenAIResponseInputMessageContentImage(BaseModel):
     detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
     type: Literal["input_image"] = "input_image"
     # TODO: handle file_id
-    image_url: Optional[str] = None
+    image_url: str | None = None
 
 
 # TODO: handle file content types
 OpenAIResponseInputMessageContent = Annotated[
-    Union[OpenAIResponseInputMessageContentText, OpenAIResponseInputMessageContentImage],
+    OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
@@ -120,21 +113,21 @@ register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMess
 
 @json_schema_type
 class OpenAIResponseInputMessage(BaseModel):
-    content: Union[str, List[OpenAIResponseInputMessageContent]]
+    content: str | list[OpenAIResponseInputMessageContent]
     role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
-    type: Optional[Literal["message"]] = "message"
+    type: Literal["message"] | None = "message"
 
 
 @json_schema_type
 class OpenAIResponseInputToolWebSearch(BaseModel):
     type: Literal["web_search"] | Literal["web_search_preview_2025_03_11"] = "web_search"
     # TODO: actually use search_context_size somewhere...
-    search_context_size: Optional[str] = Field(default="medium", pattern="^low|medium|high$")
+    search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
     # TODO: add user_location
 
 
 OpenAIResponseInputTool = Annotated[
-    Union[OpenAIResponseInputToolWebSearch,],
+    OpenAIResponseInputToolWebSearch,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List, Optional, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable
 
 from llama_stack.apis.common.job_types import Job
 from llama_stack.apis.inference import (
@@ -34,22 +34,22 @@ class BatchInference(Protocol):
     async def completion(
         self,
         model: str,
-        content_batch: List[InterleavedContent],
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        content_batch: list[InterleavedContent],
+        sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> Job: ...
 
     @webmethod(route="/batch-inference/chat-completion", method="POST")
     async def chat_completion(
         self,
         model: str,
-        messages_batch: List[List[Message]],
-        sampling_params: Optional[SamplingParams] = None,
+        messages_batch: list[list[Message]],
+        sampling_params: SamplingParams | None = None,
         # zero-shot tool definitions as input to the model
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_choice: ToolChoice | None = ToolChoice.auto,
+        tool_prompt_format: ToolPromptFormat | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> Job: ...
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
+from typing import Any, Literal, Protocol, runtime_checkable
 
 from pydantic import BaseModel, Field
 
@@ -13,8 +13,8 @@ from llama_stack.schema_utils import json_schema_type, webmethod
 
 class CommonBenchmarkFields(BaseModel):
     dataset_id: str
-    scoring_functions: List[str]
-    metadata: Dict[str, Any] = Field(
+    scoring_functions: list[str]
+    metadata: dict[str, Any] = Field(
         default_factory=dict,
         description="Metadata for this evaluation task",
     )
@@ -35,12 +35,12 @@ class Benchmark(CommonBenchmarkFields, Resource):
 
 class BenchmarkInput(CommonBenchmarkFields, BaseModel):
     benchmark_id: str
-    provider_id: Optional[str] = None
-    provider_benchmark_id: Optional[str] = None
+    provider_id: str | None = None
+    provider_benchmark_id: str | None = None
 
 
 class ListBenchmarksResponse(BaseModel):
-    data: List[Benchmark]
+    data: list[Benchmark]
 
 
 @runtime_checkable
@@ -59,8 +59,8 @@ class Benchmarks(Protocol):
         self,
         benchmark_id: str,
         dataset_id: str,
-        scoring_functions: List[str],
-        provider_benchmark_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        scoring_functions: list[str],
+        provider_benchmark_id: str | None = None,
+        provider_id: str | None = None,
+        metadata: dict[str, Any] | None = None,
     ) -> None: ...
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Annotated, List, Literal, Optional, Union
+from typing import Annotated, Literal
 
 from pydantic import BaseModel, Field, model_validator
 
@@ -26,9 +26,9 @@ class _URLOrData(BaseModel):
     :param data: base64 encoded image data as string
     """
 
-    url: Optional[URL] = None
+    url: URL | None = None
     # data is a base64 encoded string, hint with contentEncoding=base64
-    data: Optional[str] = Field(contentEncoding="base64", default=None)
+    data: str | None = Field(contentEncoding="base64", default=None)
 
     @model_validator(mode="before")
     @classmethod
@@ -64,13 +64,13 @@ class TextContentItem(BaseModel):
 
 # other modalities can be added here
 InterleavedContentItem = Annotated[
-    Union[ImageContentItem, TextContentItem],
+    ImageContentItem | TextContentItem,
     Field(discriminator="type"),
 ]
 register_schema(InterleavedContentItem, name="InterleavedContentItem")
 
 # accept a single "str" as a special case since it is common
-InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
+InterleavedContent = str | InterleavedContentItem | list[InterleavedContentItem]
 register_schema(InterleavedContent, name="InterleavedContent")
 
 
@@ -100,13 +100,13 @@ class ToolCallDelta(BaseModel):
     # you either send an in-progress tool call so the client can stream a long
     # code generation or you send the final parsed tool call at the end of the
     # stream
-    tool_call: Union[str, ToolCall]
+    tool_call: str | ToolCall
     parse_status: ToolCallParseStatus
 
 
 # streaming completions send a stream of ContentDeltas
 ContentDelta = Annotated[
-    Union[TextDelta, ImageDelta, ToolCallDelta],
+    TextDelta | ImageDelta | ToolCallDelta,
     Field(discriminator="type"),
 ]
 register_schema(ContentDelta, name="ContentDelta")
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Any
 
 from pydantic import BaseModel
 
@@ -25,6 +25,6 @@ class RestAPIMethod(Enum):
 class RestAPIExecutionConfig(BaseModel):
     url: URL
     method: RestAPIMethod
-    params: Optional[Dict[str, Any]] = None
-    headers: Optional[Dict[str, Any]] = None
-    body: Optional[Dict[str, Any]] = None
+    params: dict[str, Any] | None = None
+    headers: dict[str, Any] | None = None
+    body: dict[str, Any] | None = None
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List
+from typing import Any
 
 from pydantic import BaseModel
 
@@ -19,5 +19,5 @@ class PaginatedResponse(BaseModel):
     :param has_more: Whether there are more items available after this set
     """
 
-    data: List[Dict[str, Any]]
+    data: list[dict[str, Any]]
     has_more: bool
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 from datetime import datetime
-from typing import Optional
 
 from pydantic import BaseModel
 
@@ -27,4 +26,4 @@ class Checkpoint(BaseModel):
     epoch: int
     post_training_job_id: str
     path: str
-    training_metrics: Optional[PostTrainingMetric] = None
+    training_metrics: PostTrainingMetric | None = None
@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Literal, Union
+from typing import Annotated, Literal
 
 from pydantic import BaseModel, Field
-from typing_extensions import Annotated
 
 from llama_stack.schema_utils import json_schema_type, register_schema
 
@@ -73,18 +72,16 @@ class DialogType(BaseModel):
 
 
 ParamType = Annotated[
-    Union[
-        StringType,
-        NumberType,
-        BooleanType,
-        ArrayType,
-        ObjectType,
-        JsonType,
-        UnionType,
-        ChatCompletionInputType,
-        CompletionInputType,
-        AgentTurnInputType,
-    ],
+    StringType
+    | NumberType
+    | BooleanType
+    | ArrayType
+    | ObjectType
+    | JsonType
+    | UnionType
+    | ChatCompletionInputType
+    | CompletionInputType
+    | AgentTurnInputType,
     Field(discriminator="type"),
 ]
 register_schema(ParamType, name="ParamType")
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from typing import Any, Protocol, runtime_checkable
 
 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.datasets import Dataset
@@ -24,8 +24,8 @@ class DatasetIO(Protocol):
     async def iterrows(
         self,
         dataset_id: str,
-        start_index: Optional[int] = None,
-        limit: Optional[int] = None,
+        start_index: int | None = None,
+        limit: int | None = None,
     ) -> PaginatedResponse:
         """Get a paginated list of rows from a dataset.
 
@@ -44,4 +44,4 @@ class DatasetIO(Protocol):
         ...
 
     @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
-    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
+    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: ...
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Annotated, Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol
 
 from pydantic import BaseModel, Field
 
@@ -81,11 +81,11 @@ class RowsDataSource(BaseModel):
     """
 
     type: Literal["rows"] = "rows"
-    rows: List[Dict[str, Any]]
+    rows: list[dict[str, Any]]
 
 
 DataSource = Annotated[
-    Union[URIDataSource, RowsDataSource],
+    URIDataSource | RowsDataSource,
     Field(discriminator="type"),
 ]
 register_schema(DataSource, name="DataSource")
@@ -98,7 +98,7 @@ class CommonDatasetFields(BaseModel):
 
     purpose: DatasetPurpose
     source: DataSource
-    metadata: Dict[str, Any] = Field(
+    metadata: dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this dataset",
     )
@@ -122,7 +122,7 @@ class DatasetInput(CommonDatasetFields, BaseModel):
 
 
 class ListDatasetsResponse(BaseModel):
-    data: List[Dataset]
+    data: list[Dataset]
 
 
 class Datasets(Protocol):
@@ -131,8 +131,8 @@ class Datasets(Protocol):
         self,
         purpose: DatasetPurpose,
         source: DataSource,
-        metadata: Optional[Dict[str, Any]] = None,
-        dataset_id: Optional[str] = None,
+        metadata: dict[str, Any] | None = None,
+        dataset_id: str | None = None,
     ) -> Dataset:
         """
         Register a new dataset.
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Optional
 
 from pydantic import BaseModel
 
@@ -54,4 +53,4 @@ class Error(BaseModel):
     status: int
     title: str
     detail: str
-    instance: Optional[str] = None
+    instance: str | None = None
@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol
 
 from pydantic import BaseModel, Field
-from typing_extensions import Annotated
 
 from llama_stack.apis.agents import AgentConfig
 from llama_stack.apis.common.job_types import Job
@@ -29,7 +28,7 @@ class ModelCandidate(BaseModel):
     type: Literal["model"] = "model"
     model: str
     sampling_params: SamplingParams
-    system_message: Optional[SystemMessage] = None
+    system_message: SystemMessage | None = None
 
 
 @json_schema_type
@@ -43,7 +42,7 @@ class AgentCandidate(BaseModel):
     config: AgentConfig
 
 
-EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
+EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
 register_schema(EvalCandidate, name="EvalCandidate")
 
 
@@ -57,11 +56,11 @@ class BenchmarkConfig(BaseModel):
     """
 
     eval_candidate: EvalCandidate
-    scoring_params: Dict[str, ScoringFnParams] = Field(
+    scoring_params: dict[str, ScoringFnParams] = Field(
         description="Map between scoring function id and parameters for each scoring function you want to run",
         default_factory=dict,
     )
-    num_examples: Optional[int] = Field(
+    num_examples: int | None = Field(
         description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
         default=None,
     )
@@ -76,9 +75,9 @@ class EvaluateResponse(BaseModel):
     :param scores: The scores from the evaluation.
     """
 
-    generations: List[Dict[str, Any]]
+    generations: list[dict[str, Any]]
     # each key in the dict is a scoring function name
-    scores: Dict[str, ScoringResult]
+    scores: dict[str, ScoringResult]
 
 
 class Eval(Protocol):
@@ -101,8 +100,8 @@ class Eval(Protocol):
     async def evaluate_rows(
         self,
         benchmark_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
+        input_rows: list[dict[str, Any]],
+        scoring_functions: list[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
         """Evaluate a list of rows on a benchmark.
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List, Optional, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable
 
 from pydantic import BaseModel
 
@@ -42,7 +42,7 @@ class ListBucketResponse(BaseModel):
     :param data: List of FileResponse entries
     """
 
-    data: List[BucketResponse]
+    data: list[BucketResponse]
 
 
 @json_schema_type
@@ -74,7 +74,7 @@ class ListFileResponse(BaseModel):
     :param data: List of FileResponse entries
     """
 
-    data: List[FileResponse]
+    data: list[FileResponse]
 
 
 @runtime_checkable
@@ -102,7 +102,7 @@ class Files(Protocol):
     async def upload_content_to_session(
         self,
         upload_id: str,
-    ) -> Optional[FileResponse]:
+    ) -> FileResponse | None:
         """
         Upload file content to an existing upload session.
         On the server, request body will have the raw bytes that are uploaded.
@ -4,21 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.models import Model
|
||||
|
@ -47,8 +44,8 @@ class GreedySamplingStrategy(BaseModel):
|
|||
@json_schema_type
|
||||
class TopPSamplingStrategy(BaseModel):
|
||||
type: Literal["top_p"] = "top_p"
|
||||
temperature: Optional[float] = Field(..., gt=0.0)
|
||||
top_p: Optional[float] = 0.95
|
||||
temperature: float | None = Field(..., gt=0.0)
|
||||
top_p: float | None = 0.95
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -58,7 +55,7 @@ class TopKSamplingStrategy(BaseModel):
|
|||
|
||||
|
||||
SamplingStrategy = Annotated[
|
||||
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
|
||||
GreedySamplingStrategy | TopPSamplingStrategy | TopKSamplingStrategy,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(SamplingStrategy, name="SamplingStrategy")
|
||||
|
@ -79,9 +76,9 @@ class SamplingParams(BaseModel):
|
|||
|
||||
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
|
||||
|
||||
max_tokens: Optional[int] = 0
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
stop: Optional[List[str]] = None
|
||||
max_tokens: int | None = 0
|
||||
repetition_penalty: float | None = 1.0
|
||||
stop: list[str] | None = None
|
||||
|
||||
|
||||
class LogProbConfig(BaseModel):
|
||||
|
@ -90,7 +87,7 @@ class LogProbConfig(BaseModel):
|
|||
:param top_k: How many tokens (for each position) to return log probabilities for.
|
||||
"""
|
||||
|
||||
top_k: Optional[int] = 0
|
||||
top_k: int | None = 0
|
||||
|
||||
|
||||
class QuantizationType(Enum):
|
||||
|
@ -125,11 +122,11 @@ class Int4QuantizationConfig(BaseModel):
|
|||
"""
|
||||
|
||||
type: Literal["int4_mixed"] = "int4_mixed"
|
||||
scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
|
||||
scheme: str | None = "int4_weight_int8_dynamic_activation"
|
||||
|
||||
|
||||
QuantizationConfig = Annotated[
|
||||
Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
|
||||
Bf16QuantizationConfig | Fp8QuantizationConfig | Int4QuantizationConfig,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
@ -145,7 +142,7 @@ class UserMessage(BaseModel):
|
|||
|
||||
role: Literal["user"] = "user"
|
||||
content: InterleavedContent
|
||||
context: Optional[InterleavedContent] = None
|
||||
context: InterleavedContent | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -190,16 +187,11 @@ class CompletionMessage(BaseModel):
|
|||
role: Literal["assistant"] = "assistant"
|
||||
content: InterleavedContent
|
||||
stop_reason: StopReason
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
|
||||
tool_calls: list[ToolCall] | None = Field(default_factory=list)
|
||||
|
||||
|
||||
Message = Annotated[
|
||||
Union[
|
||||
UserMessage,
|
||||
SystemMessage,
|
||||
ToolResponseMessage,
|
||||
CompletionMessage,
|
||||
],
|
||||
UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
|
||||
Field(discriminator="role"),
|
||||
]
|
||||
register_schema(Message, name="Message")
|
||||
|
@ -208,9 +200,9 @@ register_schema(Message, name="Message")
|
|||
@json_schema_type
|
||||
class ToolResponse(BaseModel):
|
||||
call_id: str
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
tool_name: BuiltinTool | str
|
||||
content: InterleavedContent
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
|
@ -243,7 +235,7 @@ class TokenLogProbs(BaseModel):
|
|||
:param logprobs_by_token: Dictionary mapping tokens to their log probabilities
|
||||
"""
|
||||
|
||||
logprobs_by_token: Dict[str, float]
|
||||
logprobs_by_token: dict[str, float]
|
||||
|
||||
|
||||
class ChatCompletionResponseEventType(Enum):
|
||||
|
@ -271,8 +263,8 @@ class ChatCompletionResponseEvent(BaseModel):
|
|||
|
||||
event_type: ChatCompletionResponseEventType
|
||||
delta: ContentDelta
|
||||
logprobs: Optional[List[TokenLogProbs]] = None
|
||||
stop_reason: Optional[StopReason] = None
|
||||
logprobs: list[TokenLogProbs] | None = None
|
||||
stop_reason: StopReason | None = None
|
||||
|
||||
|
||||
class ResponseFormatType(Enum):
|
||||
|
@ -295,7 +287,7 @@ class JsonSchemaResponseFormat(BaseModel):
|
|||
"""
|
||||
|
||||
type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
|
||||
json_schema: Dict[str, Any]
|
||||
json_schema: dict[str, Any]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -307,11 +299,11 @@ class GrammarResponseFormat(BaseModel):
|
|||
"""
|
||||
|
||||
type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
|
||||
bnf: Dict[str, Any]
|
||||
bnf: dict[str, Any]
|
||||
|
||||
|
||||
ResponseFormat = Annotated[
|
||||
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
|
||||
JsonSchemaResponseFormat | GrammarResponseFormat,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ResponseFormat, name="ResponseFormat")
|
||||
|
@ -321,10 +313,10 @@ register_schema(ResponseFormat, name="ResponseFormat")
|
|||
class CompletionRequest(BaseModel):
|
||||
model: str
|
||||
content: InterleavedContent
|
||||
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
stream: Optional[bool] = False
|
||||
logprobs: Optional[LogProbConfig] = None
|
||||
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
|
||||
response_format: ResponseFormat | None = None
|
||||
stream: bool | None = False
|
||||
logprobs: LogProbConfig | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -338,7 +330,7 @@ class CompletionResponse(MetricResponseMixin):
|
|||
|
||||
content: str
|
||||
stop_reason: StopReason
|
||||
logprobs: Optional[List[TokenLogProbs]] = None
|
||||
logprobs: list[TokenLogProbs] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -351,8 +343,8 @@ class CompletionResponseStreamChunk(MetricResponseMixin):
|
|||
"""
|
||||
|
||||
delta: str
|
||||
stop_reason: Optional[StopReason] = None
|
||||
logprobs: Optional[List[TokenLogProbs]] = None
|
||||
stop_reason: StopReason | None = None
|
||||
logprobs: list[TokenLogProbs] | None = None
|
||||
|
||||
|
||||
class SystemMessageBehavior(Enum):
|
||||
|
@ -383,9 +375,9 @@ class ToolConfig(BaseModel):
|
|||
'{{function_definitions}}' to indicate where the function definitions should be inserted.
|
||||
"""
|
||||
|
||||
tool_choice: Optional[ToolChoice | str] = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
|
||||
system_message_behavior: Optional[SystemMessageBehavior] = Field(default=SystemMessageBehavior.append)
|
||||
tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: ToolPromptFormat | None = Field(default=None)
|
||||
system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if isinstance(self.tool_choice, str):
|
||||
|
@ -399,15 +391,15 @@ class ToolConfig(BaseModel):
|
|||
@json_schema_type
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[Message]
|
||||
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
|
||||
messages: list[Message]
|
||||
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
|
||||
|
||||
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
|
||||
tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
|
||||
tools: list[ToolDefinition] | None = Field(default_factory=list)
|
||||
tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
|
||||
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
stream: Optional[bool] = False
|
||||
logprobs: Optional[LogProbConfig] = None
|
||||
response_format: ResponseFormat | None = None
|
||||
stream: bool | None = False
|
||||
logprobs: LogProbConfig | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -429,7 +421,7 @@ class ChatCompletionResponse(MetricResponseMixin):
|
|||
"""
|
||||
|
||||
completion_message: CompletionMessage
|
||||
logprobs: Optional[List[TokenLogProbs]] = None
|
||||
logprobs: list[TokenLogProbs] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -439,7 +431,7 @@ class EmbeddingsResponse(BaseModel):
|
|||
:param embeddings: List of embedding vectors, one per input content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
|
||||
"""
|
||||
|
||||
embeddings: List[List[float]]
|
||||
embeddings: list[list[float]]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -451,7 +443,7 @@ class OpenAIChatCompletionContentPartTextParam(BaseModel):
|
|||
@json_schema_type
|
||||
class OpenAIImageURL(BaseModel):
|
||||
url: str
|
||||
detail: Optional[str] = None
|
||||
detail: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -461,16 +453,13 @@ class OpenAIChatCompletionContentPartImageParam(BaseModel):
|
|||
|
||||
|
||||
OpenAIChatCompletionContentPartParam = Annotated[
|
||||
Union[
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
],
|
||||
OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
|
||||
|
||||
|
||||
OpenAIChatCompletionMessageContent = Union[str, List[OpenAIChatCompletionContentPartParam]]
|
||||
OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -484,7 +473,7 @@ class OpenAIUserMessageParam(BaseModel):
|
|||
|
||||
role: Literal["user"] = "user"
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
name: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -498,21 +487,21 @@ class OpenAISystemMessageParam(BaseModel):
|
|||
|
||||
role: Literal["system"] = "system"
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
name: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionToolCallFunction(BaseModel):
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
name: str | None = None
|
||||
arguments: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionToolCall(BaseModel):
|
||||
index: Optional[int] = None
|
||||
id: Optional[str] = None
|
||||
index: int | None = None
|
||||
id: str | None = None
|
||||
type: Literal["function"] = "function"
|
||||
function: Optional[OpenAIChatCompletionToolCallFunction] = None
|
||||
function: OpenAIChatCompletionToolCallFunction | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -526,9 +515,9 @@ class OpenAIAssistantMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: Optional[OpenAIChatCompletionMessageContent] = None
|
||||
name: Optional[str] = None
|
||||
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
|
||||
content: OpenAIChatCompletionMessageContent | None = None
|
||||
name: str | None = None
|
||||
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -556,17 +545,15 @@ class OpenAIDeveloperMessageParam(BaseModel):
|
|||
|
||||
role: Literal["developer"] = "developer"
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
name: str | None = None
|
||||
|
||||
|
||||
OpenAIMessageParam = Annotated[
|
||||
Union[
|
||||
OpenAIUserMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIDeveloperMessageParam,
|
||||
],
|
||||
OpenAIUserMessageParam
|
||||
| OpenAISystemMessageParam
|
||||
| OpenAIAssistantMessageParam
|
||||
| OpenAIToolMessageParam
|
||||
| OpenAIDeveloperMessageParam,
|
||||
Field(discriminator="role"),
|
||||
]
|
||||
register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
|
||||
|
@ -580,14 +567,14 @@ class OpenAIResponseFormatText(BaseModel):
|
|||
@json_schema_type
|
||||
class OpenAIJSONSchema(TypedDict, total=False):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
strict: Optional[bool] = None
|
||||
description: str | None = None
|
||||
strict: bool | None = None
|
||||
|
||||
# Pydantic BaseModel cannot be used with a schema param, since it already
|
||||
# has one. And, we don't want to alias here because then have to handle
|
||||
# that alias when converting to OpenAI params. So, to support schema,
|
||||
# we use a TypedDict.
|
||||
schema: Optional[Dict[str, Any]] = None
|
||||
schema: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -602,11 +589,7 @@ class OpenAIResponseFormatJSONObject(BaseModel):
|
|||
|
||||
|
||||
OpenAIResponseFormatParam = Annotated[
|
||||
Union[
|
||||
OpenAIResponseFormatText,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
],
|
||||
OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
|
||||
|
@ -622,7 +605,7 @@ class OpenAITopLogProb(BaseModel):
|
|||
"""
|
||||
|
||||
token: str
|
||||
bytes: Optional[List[int]] = None
|
||||
bytes: list[int] | None = None
|
||||
logprob: float
|
||||
|
||||
|
||||
|
@ -637,9 +620,9 @@ class OpenAITokenLogProb(BaseModel):
|
|||
"""
|
||||
|
||||
token: str
|
||||
bytes: Optional[List[int]] = None
|
||||
bytes: list[int] | None = None
|
||||
logprob: float
|
||||
top_logprobs: List[OpenAITopLogProb]
|
||||
top_logprobs: list[OpenAITopLogProb]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -650,8 +633,8 @@ class OpenAIChoiceLogprobs(BaseModel):
|
|||
:param refusal: (Optional) The log probabilities for the tokens in the message
|
||||
"""
|
||||
|
||||
content: Optional[List[OpenAITokenLogProb]] = None
|
||||
refusal: Optional[List[OpenAITokenLogProb]] = None
|
||||
content: list[OpenAITokenLogProb] | None = None
|
||||
refusal: list[OpenAITokenLogProb] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -664,10 +647,10 @@ class OpenAIChoiceDelta(BaseModel):
|
|||
:param tool_calls: (Optional) The tool calls of the delta
|
||||
"""
|
||||
|
||||
content: Optional[str] = None
|
||||
refusal: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
|
||||
content: str | None = None
|
||||
refusal: str | None = None
|
||||
role: str | None = None
|
||||
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -683,7 +666,7 @@ class OpenAIChunkChoice(BaseModel):
|
|||
delta: OpenAIChoiceDelta
|
||||
finish_reason: str
|
||||
index: int
|
||||
logprobs: Optional[OpenAIChoiceLogprobs] = None
|
||||
logprobs: OpenAIChoiceLogprobs | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -699,7 +682,7 @@ class OpenAIChoice(BaseModel):
|
|||
message: OpenAIMessageParam
|
||||
finish_reason: str
|
||||
index: int
|
||||
logprobs: Optional[OpenAIChoiceLogprobs] = None
|
||||
logprobs: OpenAIChoiceLogprobs | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -714,7 +697,7 @@ class OpenAIChatCompletion(BaseModel):
|
|||
"""
|
||||
|
||||
id: str
|
||||
choices: List[OpenAIChoice]
|
||||
choices: list[OpenAIChoice]
|
||||
object: Literal["chat.completion"] = "chat.completion"
|
||||
created: int
|
||||
model: str
|
||||
|
@ -732,7 +715,7 @@ class OpenAIChatCompletionChunk(BaseModel):
|
|||
"""
|
||||
|
||||
id: str
|
||||
choices: List[OpenAIChunkChoice]
|
||||
choices: list[OpenAIChunkChoice]
|
||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
created: int
|
||||
model: str
|
||||
|
@ -748,10 +731,10 @@ class OpenAICompletionLogprobs(BaseModel):
|
|||
:top_logprobs: (Optional) The top log probabilities for the tokens
|
||||
"""
|
||||
|
||||
text_offset: Optional[List[int]] = None
|
||||
token_logprobs: Optional[List[float]] = None
|
||||
tokens: Optional[List[str]] = None
|
||||
top_logprobs: Optional[List[Dict[str, float]]] = None
|
||||
text_offset: list[int] | None = None
|
||||
token_logprobs: list[float] | None = None
|
||||
tokens: list[str] | None = None
|
||||
top_logprobs: list[dict[str, float]] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -767,7 +750,7 @@ class OpenAICompletionChoice(BaseModel):
|
|||
finish_reason: str
|
||||
text: str
|
||||
index: int
|
||||
logprobs: Optional[OpenAIChoiceLogprobs] = None
|
||||
logprobs: OpenAIChoiceLogprobs | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -782,7 +765,7 @@ class OpenAICompletion(BaseModel):
|
|||
"""
|
||||
|
||||
id: str
|
||||
choices: List[OpenAICompletionChoice]
|
||||
choices: list[OpenAICompletionChoice]
|
||||
created: int
|
||||
model: str
|
||||
object: Literal["text_completion"] = "text_completion"
|
||||
|
@ -818,12 +801,12 @@ class EmbeddingTaskType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class BatchCompletionResponse(BaseModel):
|
||||
batch: List[CompletionResponse]
|
||||
batch: list[CompletionResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BatchChatCompletionResponse(BaseModel):
|
||||
batch: List[ChatCompletionResponse]
|
||||
batch: list[ChatCompletionResponse]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
@ -843,11 +826,11 @@ class Inference(Protocol):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
|
||||
"""Generate a completion for the given content using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
|
@ -865,10 +848,10 @@ class Inference(Protocol):
|
|||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: List[InterleavedContent],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchCompletionResponse:
|
||||
raise NotImplementedError("Batch completion is not implemented")
|
||||
|
||||
|
@ -876,16 +859,16 @@ class Inference(Protocol):
|
|||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
-        stream: Optional[bool] = False,
-        logprobs: Optional[LogProbConfig] = None,
-        tool_config: Optional[ToolConfig] = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+        messages: list[Message],
+        sampling_params: SamplingParams | None = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_choice: ToolChoice | None = ToolChoice.auto,
+        tool_prompt_format: ToolPromptFormat | None = None,
+        response_format: ResponseFormat | None = None,
+        stream: bool | None = False,
+        logprobs: LogProbConfig | None = None,
+        tool_config: ToolConfig | None = None,
+    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
         """Generate a chat completion for the given messages using the specified model.

         :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.

@@ -916,12 +899,12 @@ class Inference(Protocol):
     async def batch_chat_completion(
         self,
         model_id: str,
-        messages_batch: List[List[Message]],
-        sampling_params: Optional[SamplingParams] = None,
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_config: Optional[ToolConfig] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        messages_batch: list[list[Message]],
+        sampling_params: SamplingParams | None = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_config: ToolConfig | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> BatchChatCompletionResponse:
         raise NotImplementedError("Batch chat completion is not implemented")

@@ -929,10 +912,10 @@ class Inference(Protocol):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[str] | List[InterleavedContentItem],
-        text_truncation: Optional[TextTruncation] = TextTruncation.none,
-        output_dimension: Optional[int] = None,
-        task_type: Optional[EmbeddingTaskType] = None,
+        contents: list[str] | list[InterleavedContentItem],
+        text_truncation: TextTruncation | None = TextTruncation.none,
+        output_dimension: int | None = None,
+        task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         """Generate embeddings for content pieces using the specified model.

@@ -950,25 +933,25 @@ class Inference(Protocol):
         self,
         # Standard OpenAI completion parameters
         model: str,
-        prompt: Union[str, List[str], List[int], List[List[int]]],
-        best_of: Optional[int] = None,
-        echo: Optional[bool] = None,
-        frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        presence_penalty: Optional[float] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
+        prompt: str | list[str] | list[int] | list[list[int]],
+        best_of: int | None = None,
+        echo: bool | None = None,
+        frequency_penalty: float | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        presence_penalty: float | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
         # vLLM-specific parameters
-        guided_choice: Optional[List[str]] = None,
-        prompt_logprobs: Optional[int] = None,
+        guided_choice: list[str] | None = None,
+        prompt_logprobs: int | None = None,
     ) -> OpenAICompletion:
         """Generate an OpenAI-compatible completion for the given prompt using the specified model.

@@ -996,29 +979,29 @@ class Inference(Protocol):
     async def openai_chat_completion(
         self,
         model: str,
-        messages: List[OpenAIMessageParam],
-        frequency_penalty: Optional[float] = None,
-        function_call: Optional[Union[str, Dict[str, Any]]] = None,
-        functions: Optional[List[Dict[str, Any]]] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_completion_tokens: Optional[int] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        parallel_tool_calls: Optional[bool] = None,
-        presence_penalty: Optional[float] = None,
-        response_format: Optional[OpenAIResponseFormatParam] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
-        tools: Optional[List[Dict[str, Any]]] = None,
-        top_logprobs: Optional[int] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
-    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+        messages: list[OpenAIMessageParam],
+        frequency_penalty: float | None = None,
+        function_call: str | dict[str, Any] | None = None,
+        functions: list[dict[str, Any]] | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        parallel_tool_calls: bool | None = None,
+        presence_penalty: float | None = None,
+        response_format: OpenAIResponseFormatParam | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[dict[str, Any]] | None = None,
+        top_logprobs: int | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         """Generate an OpenAI-compatible chat completion for the given messages using the specified model.

         :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
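The same mechanical rewrite repeats across every hunk above: pyupgrade replaces typing.Optional[X] with X | None, typing.Union[X, Y] with X | Y, and the List/Dict aliases with the builtin list/dict generics. A minimal sketch of the equivalence, with hypothetical function names that are not part of the diff (pyupgrade was presumably run with a flag like --py310-plus, since PEP 604 unions require Python 3.10):

    # Hypothetical before/after pair; both signatures mean the same thing
    # to the runtime and to type checkers on Python >= 3.10.
    from typing import Dict, List, Optional, Union


    def old_style(ids: List[int], opts: Optional[Dict[str, str]] = None) -> Union[int, str]:
        return len(ids)


    def new_style(ids: list[int], opts: dict[str, str] | None = None) -> int | str:
        return len(ids)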
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable

 from pydantic import BaseModel

@@ -16,7 +16,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
 class RouteInfo(BaseModel):
     route: str
     method: str
-    provider_types: List[str]
+    provider_types: list[str]


 @json_schema_type
@@ -30,7 +30,7 @@ class VersionInfo(BaseModel):


 class ListRoutesResponse(BaseModel):
-    data: List[RouteInfo]
+    data: list[RouteInfo]


 @runtime_checkable
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
+from typing import Any, Literal, Protocol, runtime_checkable

 from pydantic import BaseModel, ConfigDict, Field

@@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod


 class CommonModelFields(BaseModel):
-    metadata: Dict[str, Any] = Field(
+    metadata: dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this model",
     )

@@ -46,14 +46,14 @@ class Model(CommonModelFields, Resource):

 class ModelInput(CommonModelFields):
     model_id: str
-    provider_id: Optional[str] = None
-    provider_model_id: Optional[str] = None
-    model_type: Optional[ModelType] = ModelType.llm
+    provider_id: str | None = None
+    provider_model_id: str | None = None
+    model_type: ModelType | None = ModelType.llm
     model_config = ConfigDict(protected_namespaces=())


 class ListModelsResponse(BaseModel):
-    data: List[Model]
+    data: list[Model]


 @json_schema_type
@@ -73,7 +73,7 @@ class OpenAIModel(BaseModel):


 class OpenAIListModelsResponse(BaseModel):
-    data: List[OpenAIModel]
+    data: list[OpenAIModel]


 @runtime_checkable
@@ -95,10 +95,10 @@ class Models(Protocol):
     async def register_model(
         self,
         model_id: str,
-        provider_model_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-        model_type: Optional[ModelType] = None,
+        provider_model_id: str | None = None,
+        provider_id: str | None = None,
+        metadata: dict[str, Any] | None = None,
+        model_type: ModelType | None = None,
     ) -> Model: ...

     @webmethod(route="/models/{model_id:path}", method="DELETE")
@@ -6,10 +6,9 @@

 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.job_types import JobStatus

@@ -36,9 +35,9 @@ class DataConfig(BaseModel):
     batch_size: int
     shuffle: bool
     data_format: DatasetFormat
-    validation_dataset_id: Optional[str] = None
-    packed: Optional[bool] = False
-    train_on_input: Optional[bool] = False
+    validation_dataset_id: str | None = None
+    packed: bool | None = False
+    train_on_input: bool | None = False


 @json_schema_type
@@ -51,10 +50,10 @@ class OptimizerConfig(BaseModel):

 @json_schema_type
 class EfficiencyConfig(BaseModel):
-    enable_activation_checkpointing: Optional[bool] = False
-    enable_activation_offloading: Optional[bool] = False
-    memory_efficient_fsdp_wrap: Optional[bool] = False
-    fsdp_cpu_offload: Optional[bool] = False
+    enable_activation_checkpointing: bool | None = False
+    enable_activation_offloading: bool | None = False
+    memory_efficient_fsdp_wrap: bool | None = False
+    fsdp_cpu_offload: bool | None = False


 @json_schema_type
@@ -62,23 +61,23 @@ class TrainingConfig(BaseModel):
     n_epochs: int
     max_steps_per_epoch: int = 1
     gradient_accumulation_steps: int = 1
-    max_validation_steps: Optional[int] = 1
-    data_config: Optional[DataConfig] = None
-    optimizer_config: Optional[OptimizerConfig] = None
-    efficiency_config: Optional[EfficiencyConfig] = None
-    dtype: Optional[str] = "bf16"
+    max_validation_steps: int | None = 1
+    data_config: DataConfig | None = None
+    optimizer_config: OptimizerConfig | None = None
+    efficiency_config: EfficiencyConfig | None = None
+    dtype: str | None = "bf16"


 @json_schema_type
 class LoraFinetuningConfig(BaseModel):
     type: Literal["LoRA"] = "LoRA"
-    lora_attn_modules: List[str]
+    lora_attn_modules: list[str]
     apply_lora_to_mlp: bool
     apply_lora_to_output: bool
     rank: int
     alpha: int
-    use_dora: Optional[bool] = False
-    quantize_base: Optional[bool] = False
+    use_dora: bool | None = False
+    quantize_base: bool | None = False


 @json_schema_type
@@ -88,7 +87,7 @@ class QATFinetuningConfig(BaseModel):
     group_size: int


-AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
+AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
 register_schema(AlgorithmConfig, name="AlgorithmConfig")

@@ -97,7 +96,7 @@ class PostTrainingJobLogStream(BaseModel):
     """Stream of logs from a finetuning job."""

     job_uuid: str
-    log_lines: List[str]
+    log_lines: list[str]


 @json_schema_type
@@ -131,8 +130,8 @@ class PostTrainingRLHFRequest(BaseModel):
     training_config: TrainingConfig

     # TODO: define these
-    hyperparam_search_config: Dict[str, Any]
-    logger_config: Dict[str, Any]
+    hyperparam_search_config: dict[str, Any]
+    logger_config: dict[str, Any]


 class PostTrainingJob(BaseModel):

@@ -146,17 +145,17 @@ class PostTrainingJobStatusResponse(BaseModel):
     job_uuid: str
     status: JobStatus

-    scheduled_at: Optional[datetime] = None
-    started_at: Optional[datetime] = None
-    completed_at: Optional[datetime] = None
+    scheduled_at: datetime | None = None
+    started_at: datetime | None = None
+    completed_at: datetime | None = None

-    resources_allocated: Optional[Dict[str, Any]] = None
+    resources_allocated: dict[str, Any] | None = None

-    checkpoints: List[Checkpoint] = Field(default_factory=list)
+    checkpoints: list[Checkpoint] = Field(default_factory=list)


 class ListPostTrainingJobsResponse(BaseModel):
-    data: List[PostTrainingJob]
+    data: list[PostTrainingJob]


 @json_schema_type
@@ -164,7 +163,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
     """Artifacts of a finetuning job."""

     job_uuid: str
-    checkpoints: List[Checkpoint] = Field(default_factory=list)
+    checkpoints: list[Checkpoint] = Field(default_factory=list)

     # TODO(ashwin): metrics, evals

@@ -175,14 +174,14 @@ class PostTraining(Protocol):
         self,
         job_uuid: str,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
-        model: Optional[str] = Field(
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
+        model: str | None = Field(
             default=None,
             description="Model descriptor for training if not in provider config`",
         ),
-        checkpoint_dir: Optional[str] = None,
-        algorithm_config: Optional[AlgorithmConfig] = None,
+        checkpoint_dir: str | None = None,
+        algorithm_config: AlgorithmConfig | None = None,
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize", method="POST")
@@ -192,8 +191,8 @@ class PostTraining(Protocol):
         finetuned_model: str,
         algorithm_config: DPOAlignmentConfig,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/jobs", method="GET")
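The AlgorithmConfig change above is the more interesting case: a PEP 604 union dropped into Pydantic's discriminated-union idiom. A self-contained sketch of that pattern, using stand-in models rather than the ones from the diff:

    # Stand-in models illustrating Annotated[... | ..., Field(discriminator=...)];
    # requires Pydantic v2.
    from typing import Annotated, Literal

    from pydantic import BaseModel, Field, TypeAdapter


    class LoraConfig(BaseModel):
        type: Literal["LoRA"] = "LoRA"
        rank: int


    class QatConfig(BaseModel):
        type: Literal["QAT"] = "QAT"
        group_size: int


    AlgoConfig = Annotated[LoraConfig | QatConfig, Field(discriminator="type")]

    # Validation dispatches on the "type" field to pick the right model.
    cfg = TypeAdapter(AlgoConfig).validate_python({"type": "LoRA", "rank": 8})
    assert isinstance(cfg, LoraConfig)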
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Protocol, runtime_checkable
+from typing import Any, Protocol, runtime_checkable

 from pydantic import BaseModel

@@ -17,12 +17,12 @@ class ProviderInfo(BaseModel):
     api: str
     provider_id: str
     provider_type: str
-    config: Dict[str, Any]
+    config: dict[str, Any]
     health: HealthResponse


 class ListProvidersResponse(BaseModel):
-    data: List[ProviderInfo]
+    data: list[ProviderInfo]


 @runtime_checkable
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from typing import Any, Protocol, runtime_checkable

 from pydantic import BaseModel, Field

@@ -27,16 +27,16 @@ class SafetyViolation(BaseModel):
     violation_level: ViolationLevel

     # what message should you convey to the user
-    user_message: Optional[str] = None
+    user_message: str | None = None

     # additional metadata (including specific violation codes) more for
     # debugging, telemetry
-    metadata: Dict[str, Any] = Field(default_factory=dict)
+    metadata: dict[str, Any] = Field(default_factory=dict)


 @json_schema_type
 class RunShieldResponse(BaseModel):
-    violation: Optional[SafetyViolation] = None
+    violation: SafetyViolation | None = None


 class ShieldStore(Protocol):
@@ -52,6 +52,6 @@ class Safety(Protocol):
     async def run_shield(
         self,
         shield_id: str,
-        messages: List[Message],
-        params: Dict[str, Any] = None,
+        messages: list[Message],
+        params: dict[str, Any] = None,
     ) -> RunShieldResponse: ...
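One line in run_shield is worth flagging: params: dict[str, Any] = None keeps a None default without | None in the annotation. That inconsistency predates this commit, and pyupgrade faithfully preserves it; a strict type checker would instead expect something like the following (a suggested fix, not part of the diff):

    # Suggested signature; the diff itself leaves the annotation as-is.
    async def run_shield(
        self,
        shield_id: str,
        messages: list[Message],
        params: dict[str, Any] | None = None,
    ) -> RunShieldResponse: ...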
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from typing import Any, Protocol, runtime_checkable

 from pydantic import BaseModel

@@ -12,7 +12,7 @@ from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
 from llama_stack.schema_utils import json_schema_type, webmethod

 # mapping of metric to value
-ScoringResultRow = Dict[str, Any]
+ScoringResultRow = dict[str, Any]


 @json_schema_type
@@ -24,15 +24,15 @@ class ScoringResult(BaseModel):
     :param aggregated_results: Map of metric name to aggregated value
     """

-    score_rows: List[ScoringResultRow]
+    score_rows: list[ScoringResultRow]
     # aggregated metrics to value
-    aggregated_results: Dict[str, Any]
+    aggregated_results: dict[str, Any]


 @json_schema_type
 class ScoreBatchResponse(BaseModel):
-    dataset_id: Optional[str] = None
-    results: Dict[str, ScoringResult]
+    dataset_id: str | None = None
+    results: dict[str, ScoringResult]


 @json_schema_type
@@ -44,7 +44,7 @@ class ScoreResponse(BaseModel):
     """

     # each key in the dict is a scoring function name
-    results: Dict[str, ScoringResult]
+    results: dict[str, ScoringResult]


 class ScoringFunctionStore(Protocol):
@@ -59,15 +59,15 @@ class Scoring(Protocol):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: Dict[str, Optional[ScoringFnParams]],
+        scoring_functions: dict[str, ScoringFnParams | None],
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse: ...

     @webmethod(route="/scoring/score", method="POST")
     async def score(
         self,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: Dict[str, Optional[ScoringFnParams]],
+        input_rows: list[dict[str, Any]],
+        scoring_functions: dict[str, ScoringFnParams | None],
     ) -> ScoreResponse:
         """Score a list of rows.
@@ -6,18 +6,14 @@

 from enum import Enum
 from typing import (
+    Annotated,
     Any,
-    Dict,
-    List,
     Literal,
-    Optional,
     Protocol,
-    Union,
     runtime_checkable,
 )

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.apis.common.type_system import ParamType
 from llama_stack.apis.resource import Resource, ResourceType

@@ -46,12 +42,12 @@ class AggregationFunctionType(Enum):
 class LLMAsJudgeScoringFnParams(BaseModel):
     type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
     judge_model: str
-    prompt_template: Optional[str] = None
-    judge_score_regexes: Optional[List[str]] = Field(
+    prompt_template: str | None = None
+    judge_score_regexes: list[str] | None = Field(
         description="Regexes to extract the answer from generated response",
         default_factory=list,
     )
-    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+    aggregation_functions: list[AggregationFunctionType] | None = Field(
         description="Aggregation functions to apply to the scores of each row",
         default_factory=list,
     )

@@ -60,11 +56,11 @@ class LLMAsJudgeScoringFnParams(BaseModel):
 @json_schema_type
 class RegexParserScoringFnParams(BaseModel):
     type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
-    parsing_regexes: Optional[List[str]] = Field(
+    parsing_regexes: list[str] | None = Field(
         description="Regex to extract the answer from generated response",
         default_factory=list,
     )
-    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+    aggregation_functions: list[AggregationFunctionType] | None = Field(
         description="Aggregation functions to apply to the scores of each row",
         default_factory=list,
     )

@@ -73,33 +69,29 @@ class RegexParserScoringFnParams(BaseModel):
 @json_schema_type
 class BasicScoringFnParams(BaseModel):
     type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
-    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+    aggregation_functions: list[AggregationFunctionType] | None = Field(
         description="Aggregation functions to apply to the scores of each row",
         default_factory=list,
     )


 ScoringFnParams = Annotated[
-    Union[
-        LLMAsJudgeScoringFnParams,
-        RegexParserScoringFnParams,
-        BasicScoringFnParams,
-    ],
+    LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams,
     Field(discriminator="type"),
 ]
 register_schema(ScoringFnParams, name="ScoringFnParams")


 class CommonScoringFnFields(BaseModel):
-    description: Optional[str] = None
-    metadata: Dict[str, Any] = Field(
+    description: str | None = None
+    metadata: dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this definition",
     )
     return_type: ParamType = Field(
         description="The return type of the deterministic function",
     )
-    params: Optional[ScoringFnParams] = Field(
+    params: ScoringFnParams | None = Field(
         description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
         default=None,
     )

@@ -120,12 +112,12 @@ class ScoringFn(CommonScoringFnFields, Resource):

 class ScoringFnInput(CommonScoringFnFields, BaseModel):
     scoring_fn_id: str
-    provider_id: Optional[str] = None
-    provider_scoring_fn_id: Optional[str] = None
+    provider_id: str | None = None
+    provider_scoring_fn_id: str | None = None


 class ListScoringFunctionsResponse(BaseModel):
-    data: List[ScoringFn]
+    data: list[ScoringFn]


 @runtime_checkable
@@ -142,7 +134,7 @@ class ScoringFunctions(Protocol):
         scoring_fn_id: str,
         description: str,
         return_type: ParamType,
-        provider_scoring_fn_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        params: Optional[ScoringFnParams] = None,
+        provider_scoring_fn_id: str | None = None,
+        provider_id: str | None = None,
+        params: ScoringFnParams | None = None,
     ) -> None: ...
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
+from typing import Any, Literal, Protocol, runtime_checkable

 from pydantic import BaseModel

@@ -14,7 +14,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod


 class CommonShieldFields(BaseModel):
-    params: Optional[Dict[str, Any]] = None
+    params: dict[str, Any] | None = None


 @json_schema_type
@@ -34,12 +34,12 @@ class Shield(CommonShieldFields, Resource):

 class ShieldInput(CommonShieldFields):
     shield_id: str
-    provider_id: Optional[str] = None
-    provider_shield_id: Optional[str] = None
+    provider_id: str | None = None
+    provider_shield_id: str | None = None


 class ListShieldsResponse(BaseModel):
-    data: List[Shield]
+    data: list[Shield]


 @runtime_checkable
@@ -55,7 +55,7 @@ class Shields(Protocol):
     async def register_shield(
         self,
         shield_id: str,
-        provider_shield_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        params: Optional[Dict[str, Any]] = None,
+        provider_shield_id: str | None = None,
+        provider_id: str | None = None,
+        params: dict[str, Any] | None = None,
     ) -> Shield: ...
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol, Union
+from typing import Any, Protocol

 from pydantic import BaseModel

@@ -28,24 +28,24 @@ class FilteringFunction(Enum):
 class SyntheticDataGenerationRequest(BaseModel):
     """Request to generate synthetic data. A small batch of prompts and a filtering function"""

-    dialogs: List[Message]
+    dialogs: list[Message]
     filtering_function: FilteringFunction = FilteringFunction.none
-    model: Optional[str] = None
+    model: str | None = None


 @json_schema_type
 class SyntheticDataGenerationResponse(BaseModel):
     """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""

-    synthetic_data: List[Dict[str, Any]]
-    statistics: Optional[Dict[str, Any]] = None
+    synthetic_data: list[dict[str, Any]]
+    statistics: dict[str, Any] | None = None


 class SyntheticDataGeneration(Protocol):
     @webmethod(route="/synthetic-data-generation/generate")
     def synthetic_data_generate(
         self,
-        dialogs: List[Message],
+        dialogs: list[Message],
         filtering_function: FilteringFunction = FilteringFunction.none,
-        model: Optional[str] = None,
-    ) -> Union[SyntheticDataGenerationResponse]: ...
+        model: str | None = None,
+    ) -> SyntheticDataGenerationResponse: ...
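The return-type change in synthetic_data_generate fixes an oddity as well as a style issue: Union[SyntheticDataGenerationResponse] is a one-member union, which typing collapses to the bare type anyway, so naming the type directly is strictly clearer. The collapse is easy to verify:

    # A one-member Union is the same object as the bare type.
    from typing import Union

    assert Union[int] is int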
@@ -7,18 +7,14 @@

 from datetime import datetime
 from enum import Enum
 from typing import (
+    Annotated,
     Any,
-    Dict,
-    List,
     Literal,
-    Optional,
     Protocol,
-    Union,
     runtime_checkable,
 )

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated

 from llama_stack.models.llama.datatypes import Primitive
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

@@ -37,11 +33,11 @@ class SpanStatus(Enum):
 class Span(BaseModel):
     span_id: str
     trace_id: str
-    parent_span_id: Optional[str] = None
+    parent_span_id: str | None = None
     name: str
     start_time: datetime
-    end_time: Optional[datetime] = None
-    attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
+    end_time: datetime | None = None
+    attributes: dict[str, Any] | None = Field(default_factory=dict)

     def set_attribute(self, key: str, value: Any):
         if self.attributes is None:

@@ -54,7 +50,7 @@ class Trace(BaseModel):
     trace_id: str
     root_span_id: str
     start_time: datetime
-    end_time: Optional[datetime] = None
+    end_time: datetime | None = None


 @json_schema_type
@@ -78,7 +74,7 @@ class EventCommon(BaseModel):
     trace_id: str
     span_id: str
     timestamp: datetime
-    attributes: Optional[Dict[str, Primitive]] = Field(default_factory=dict)
+    attributes: dict[str, Primitive] | None = Field(default_factory=dict)


 @json_schema_type
@@ -92,15 +88,15 @@ class UnstructuredLogEvent(EventCommon):
 class MetricEvent(EventCommon):
     type: Literal[EventType.METRIC.value] = EventType.METRIC.value
     metric: str  # this would be an enum
-    value: Union[int, float]
+    value: int | float
     unit: str


 @json_schema_type
 class MetricInResponse(BaseModel):
     metric: str
-    value: Union[int, float]
-    unit: Optional[str] = None
+    value: int | float
+    unit: str | None = None


 # This is a short term solution to allow inference API to return metrics
@@ -124,7 +120,7 @@ class MetricInResponse(BaseModel):


 class MetricResponseMixin(BaseModel):
-    metrics: Optional[List[MetricInResponse]] = None
+    metrics: list[MetricInResponse] | None = None


 @json_schema_type
@@ -137,7 +133,7 @@ class StructuredLogType(Enum):
 class SpanStartPayload(BaseModel):
     type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value
     name: str
-    parent_span_id: Optional[str] = None
+    parent_span_id: str | None = None


 @json_schema_type
@@ -147,10 +143,7 @@ class SpanEndPayload(BaseModel):


 StructuredLogPayload = Annotated[
-    Union[
-        SpanStartPayload,
-        SpanEndPayload,
-    ],
+    SpanStartPayload | SpanEndPayload,
     Field(discriminator="type"),
 ]
 register_schema(StructuredLogPayload, name="StructuredLogPayload")

@@ -163,11 +156,7 @@ class StructuredLogEvent(EventCommon):


 Event = Annotated[
-    Union[
-        UnstructuredLogEvent,
-        MetricEvent,
-        StructuredLogEvent,
-    ],
+    UnstructuredLogEvent | MetricEvent | StructuredLogEvent,
     Field(discriminator="type"),
 ]
 register_schema(Event, name="Event")

@@ -184,7 +173,7 @@ class EvalTrace(BaseModel):

 @json_schema_type
 class SpanWithStatus(Span):
-    status: Optional[SpanStatus] = None
+    status: SpanStatus | None = None


 @json_schema_type
@@ -203,15 +192,15 @@ class QueryCondition(BaseModel):


 class QueryTracesResponse(BaseModel):
-    data: List[Trace]
+    data: list[Trace]


 class QuerySpansResponse(BaseModel):
-    data: List[Span]
+    data: list[Span]


 class QuerySpanTreeResponse(BaseModel):
-    data: Dict[str, SpanWithStatus]
+    data: dict[str, SpanWithStatus]


 @runtime_checkable
@@ -222,10 +211,10 @@ class Telemetry(Protocol):
     @webmethod(route="/telemetry/traces", method="POST")
     async def query_traces(
         self,
-        attribute_filters: Optional[List[QueryCondition]] = None,
-        limit: Optional[int] = 100,
-        offset: Optional[int] = 0,
-        order_by: Optional[List[str]] = None,
+        attribute_filters: list[QueryCondition] | None = None,
+        limit: int | None = 100,
+        offset: int | None = 0,
+        order_by: list[str] | None = None,
     ) -> QueryTracesResponse: ...

     @webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
@@ -238,23 +227,23 @@ class Telemetry(Protocol):
     async def get_span_tree(
         self,
         span_id: str,
-        attributes_to_return: Optional[List[str]] = None,
-        max_depth: Optional[int] = None,
+        attributes_to_return: list[str] | None = None,
+        max_depth: int | None = None,
     ) -> QuerySpanTreeResponse: ...

     @webmethod(route="/telemetry/spans", method="POST")
     async def query_spans(
         self,
-        attribute_filters: List[QueryCondition],
-        attributes_to_return: List[str],
-        max_depth: Optional[int] = None,
+        attribute_filters: list[QueryCondition],
+        attributes_to_return: list[str],
+        max_depth: int | None = None,
     ) -> QuerySpansResponse: ...

     @webmethod(route="/telemetry/spans/export", method="POST")
     async def save_spans_to_dataset(
         self,
-        attribute_filters: List[QueryCondition],
-        attributes_to_save: List[str],
+        attribute_filters: list[QueryCondition],
+        attributes_to_save: list[str],
         dataset_id: str,
-        max_depth: Optional[int] = None,
+        max_depth: int | None = None,
     ) -> None: ...
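Annotations like value: int | float in the telemetry events rely on PEP 604 union objects, which are real runtime values on Python 3.10+ and can even be passed to isinstance. A quick sketch:

    # PEP 604 unions exist at runtime on Python 3.10+.
    import types

    IntOrFloat = int | float

    assert isinstance(IntOrFloat, types.UnionType)
    assert isinstance(3.14, IntOrFloat)  # isinstance accepts union objects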
@@ -5,10 +5,10 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Annotated, Any, Literal

 from pydantic import BaseModel, Field
-from typing_extensions import Annotated, Protocol, runtime_checkable
+from typing_extensions import Protocol, runtime_checkable

 from llama_stack.apis.common.content_types import URL, InterleavedContent
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol

@@ -29,13 +29,13 @@ class RAGDocument(BaseModel):
     document_id: str
     content: InterleavedContent | URL
     mime_type: str | None = None
-    metadata: Dict[str, Any] = Field(default_factory=dict)
+    metadata: dict[str, Any] = Field(default_factory=dict)


 @json_schema_type
 class RAGQueryResult(BaseModel):
-    content: Optional[InterleavedContent] = None
-    metadata: Dict[str, Any] = Field(default_factory=dict)
+    content: InterleavedContent | None = None
+    metadata: dict[str, Any] = Field(default_factory=dict)


 @json_schema_type
@@ -59,10 +59,7 @@ class LLMRAGQueryGeneratorConfig(BaseModel):


 RAGQueryGeneratorConfig = Annotated[
-    Union[
-        DefaultRAGQueryGeneratorConfig,
-        LLMRAGQueryGeneratorConfig,
-    ],
+    DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
     Field(discriminator="type"),
 ]
 register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")

@@ -83,7 +80,7 @@ class RAGToolRuntime(Protocol):
     @webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
     async def insert(
         self,
-        documents: List[RAGDocument],
+        documents: list[RAGDocument],
         vector_db_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:

@@ -94,8 +91,8 @@ class RAGToolRuntime(Protocol):
     async def query(
         self,
         content: InterleavedContent,
-        vector_db_ids: List[str],
-        query_config: Optional[RAGQueryConfig] = None,
+        vector_db_ids: list[str],
+        query_config: RAGQueryConfig | None = None,
     ) -> RAGQueryResult:
         """Query the RAG system for context; typically invoked by the agent"""
         ...
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Literal

 from pydantic import BaseModel, Field
 from typing_extensions import Protocol, runtime_checkable

@@ -24,7 +24,7 @@ class ToolParameter(BaseModel):
     parameter_type: str
     description: str
     required: bool = Field(default=True)
-    default: Optional[Any] = None
+    default: Any | None = None


 @json_schema_type
@@ -40,39 +40,39 @@ class Tool(Resource):
     toolgroup_id: str
     tool_host: ToolHost
     description: str
-    parameters: List[ToolParameter]
-    metadata: Optional[Dict[str, Any]] = None
+    parameters: list[ToolParameter]
+    metadata: dict[str, Any] | None = None


 @json_schema_type
 class ToolDef(BaseModel):
     name: str
-    description: Optional[str] = None
-    parameters: Optional[List[ToolParameter]] = None
-    metadata: Optional[Dict[str, Any]] = None
+    description: str | None = None
+    parameters: list[ToolParameter] | None = None
+    metadata: dict[str, Any] | None = None


 @json_schema_type
 class ToolGroupInput(BaseModel):
     toolgroup_id: str
     provider_id: str
-    args: Optional[Dict[str, Any]] = None
-    mcp_endpoint: Optional[URL] = None
+    args: dict[str, Any] | None = None
+    mcp_endpoint: URL | None = None


 @json_schema_type
 class ToolGroup(Resource):
     type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
-    mcp_endpoint: Optional[URL] = None
-    args: Optional[Dict[str, Any]] = None
+    mcp_endpoint: URL | None = None
+    args: dict[str, Any] | None = None


 @json_schema_type
 class ToolInvocationResult(BaseModel):
-    content: Optional[InterleavedContent] = None
-    error_message: Optional[str] = None
-    error_code: Optional[int] = None
-    metadata: Optional[Dict[str, Any]] = None
+    content: InterleavedContent | None = None
+    error_message: str | None = None
+    error_code: int | None = None
+    metadata: dict[str, Any] | None = None


 class ToolStore(Protocol):
@@ -81,11 +81,11 @@ class ToolStore(Protocol):


 class ListToolGroupsResponse(BaseModel):
-    data: List[ToolGroup]
+    data: list[ToolGroup]


 class ListToolsResponse(BaseModel):
-    data: List[Tool]
+    data: list[Tool]


 class ListToolDefsResponse(BaseModel):
@@ -100,8 +100,8 @@ class ToolGroups(Protocol):
         self,
         toolgroup_id: str,
         provider_id: str,
-        mcp_endpoint: Optional[URL] = None,
-        args: Optional[Dict[str, Any]] = None,
+        mcp_endpoint: URL | None = None,
+        args: dict[str, Any] | None = None,
     ) -> None:
         """Register a tool group"""
         ...

@@ -118,7 +118,7 @@ class ToolGroups(Protocol):
         ...

     @webmethod(route="/tools", method="GET")
-    async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
+    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
         """List tools with optional tool group"""
         ...

@@ -151,10 +151,10 @@ class ToolRuntime(Protocol):
     # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
     @webmethod(route="/tool-runtime/list-tools", method="GET")
     async def list_runtime_tools(
-        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
     ) -> ListToolDefsResponse: ...

     @webmethod(route="/tool-runtime/invoke", method="POST")
-    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
+    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
         """Run a tool with the given arguments"""
         ...
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import List, Literal, Optional, Protocol, runtime_checkable
+from typing import Literal, Protocol, runtime_checkable

 from pydantic import BaseModel

@@ -33,11 +33,11 @@ class VectorDBInput(BaseModel):
     vector_db_id: str
     embedding_model: str
     embedding_dimension: int
-    provider_vector_db_id: Optional[str] = None
+    provider_vector_db_id: str | None = None


 class ListVectorDBsResponse(BaseModel):
-    data: List[VectorDB]
+    data: list[VectorDB]


 @runtime_checkable
@@ -57,9 +57,9 @@ class VectorDBs(Protocol):
         self,
         vector_db_id: str,
         embedding_model: str,
-        embedding_dimension: Optional[int] = 384,
-        provider_id: Optional[str] = None,
-        provider_vector_db_id: Optional[str] = None,
+        embedding_dimension: int | None = 384,
+        provider_id: str | None = None,
+        provider_vector_db_id: str | None = None,
     ) -> VectorDB: ...

     @webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
@@ -8,7 +8,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from typing import Any, Protocol, runtime_checkable

 from pydantic import BaseModel, Field

@@ -20,17 +20,17 @@ from llama_stack.schema_utils import json_schema_type, webmethod

 class Chunk(BaseModel):
     content: InterleavedContent
-    metadata: Dict[str, Any] = Field(default_factory=dict)
+    metadata: dict[str, Any] = Field(default_factory=dict)


 @json_schema_type
 class QueryChunksResponse(BaseModel):
-    chunks: List[Chunk]
-    scores: List[float]
+    chunks: list[Chunk]
+    scores: list[float]


 class VectorDBStore(Protocol):
-    def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: ...
+    def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ...


 @runtime_checkable
@@ -44,8 +44,8 @@ class VectorIO(Protocol):
     async def insert_chunks(
         self,
         vector_db_id: str,
-        chunks: List[Chunk],
-        ttl_seconds: Optional[int] = None,
+        chunks: list[Chunk],
+        ttl_seconds: int | None = None,
     ) -> None: ...

     @webmethod(route="/vector-io/query", method="POST")
@@ -53,5 +53,5 @@ class VectorIO(Protocol):
         self,
         vector_db_id: str,
         query: InterleavedContent,
-        params: Optional[Dict[str, Any]] = None,
+        params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse: ...
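As the commit description notes, the schema reflection code needed a small change to cope with these new spellings: X | Y evaluates to a types.UnionType rather than a typing.Union, even though both decompose the same way. A sketch of the distinction reflection code has to handle:

    # X | Y is a types.UnionType; Union[X, Y] is a typing.Union.
    import types
    import typing

    pep604 = int | None
    legacy = typing.Optional[int]

    assert typing.get_origin(pep604) is types.UnionType
    assert typing.get_origin(legacy) is typing.Union
    # Both decompose into the same argument tuple:
    assert typing.get_args(pep604) == typing.get_args(legacy) == (int, type(None))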