Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
chore: enable pyupgrade fixes (#1806)
# What does this PR do?

The goal of this PR is codebase modernization. Schema reflection code needed a minor adjustment to handle `types.UnionType` and `collections.abc.AsyncIterator` (both are preferred in recent Python releases).

Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth noting can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where the reflection code was updated to deal with the "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
commit 9e6561a1ec (parent ffe3d0b2cd)
319 changed files with 2843 additions and 3033 deletions
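Most of the diff below is this mechanical rewrite: `typing` aliases become builtin generics and `Optional[X]` becomes the PEP 604 form `X | None`. As a minimal before/after sketch of what pyupgrade produces (the model here is illustrative, not one of the files in this commit):

```python
from datetime import datetime
from typing import Dict, Optional

from pydantic import BaseModel


class StepBefore(BaseModel):
    # pre-pyupgrade style: typing aliases and Optional[...]
    step_id: str
    started_at: Optional[datetime] = None
    metadata: Optional[Dict[str, str]] = None


class StepAfter(BaseModel):
    # post-pyupgrade style: builtin generics and PEP 604 "X | None"
    step_id: str
    started_at: datetime | None = None
    metadata: dict[str, str] | None = None
```

Both spellings are equivalent at runtime on Python 3.10+; the new form simply drops the `typing` imports.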
@@ -6,6 +6,7 @@
 
 import hashlib
 import ipaddress
+import types
 import typing
 from dataclasses import make_dataclass
 from typing import Any, Dict, Set, Union

@@ -189,7 +190,7 @@ class ContentBuilder:
        else:
            return "application/json"
 
-        if typing.get_origin(payload_type) is typing.Union:
+        if typing.get_origin(payload_type) in (typing.Union, types.UnionType):
            media_types = []
            item_types = []
            for x in typing.get_args(payload_type):
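The reflection change above is the one behavior-relevant edit: `typing.get_origin()` reports a different origin for PEP 604 unions than for `typing.Union[...]`, so the generator now accepts both. A small standalone check (the `is_union` helper is hypothetical, shown only to mirror the condition in the hunk above):

```python
import types
import typing

# typing.Union[...] and the PEP 604 "X | Y" spelling have different origins.
assert typing.get_origin(typing.Union[int, str]) is typing.Union
assert typing.get_origin(int | str) is types.UnionType


def is_union(tp: object) -> bool:
    # Mirrors the updated condition in the generator above.
    return typing.get_origin(tp) in (typing.Union, types.UnionType)


# Optional[int] is Union[int, None]; "int | None" is a types.UnionType.
assert is_union(typing.Optional[int])
assert is_union(int | None)
assert typing.get_args(int | None) == (int, type(None))
```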
@@ -4,20 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from collections.abc import AsyncIterator
 from datetime import datetime
 from enum import Enum
-from typing import (
-    Annotated,
-    Any,
-    AsyncIterator,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Protocol,
-    Union,
-    runtime_checkable,
-)
+from typing import Annotated, Any, Literal, Protocol, runtime_checkable
 
 from pydantic import BaseModel, ConfigDict, Field
 

@@ -79,8 +69,8 @@ class StepCommon(BaseModel):
 
     turn_id: str
     step_id: str
-    started_at: Optional[datetime] = None
-    completed_at: Optional[datetime] = None
+    started_at: datetime | None = None
+    completed_at: datetime | None = None
 
 
 class StepType(Enum):

@@ -120,8 +110,8 @@ class ToolExecutionStep(StepCommon):
     """
 
     step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
-    tool_calls: List[ToolCall]
-    tool_responses: List[ToolResponse]
+    tool_calls: list[ToolCall]
+    tool_responses: list[ToolResponse]
 
 
 @json_schema_type

@@ -132,7 +122,7 @@ class ShieldCallStep(StepCommon):
     """
 
     step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
-    violation: Optional[SafetyViolation]
+    violation: SafetyViolation | None
 
 
 @json_schema_type

@@ -150,12 +140,7 @@ class MemoryRetrievalStep(StepCommon):
 
 
 Step = Annotated[
-    Union[
-        InferenceStep,
-        ToolExecutionStep,
-        ShieldCallStep,
-        MemoryRetrievalStep,
-    ],
+    InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
     Field(discriminator="step_type"),
 ]
 

@@ -166,18 +151,13 @@ class Turn(BaseModel):
 
     turn_id: str
     session_id: str
-    input_messages: List[
-        Union[
-            UserMessage,
-            ToolResponseMessage,
-        ]
-    ]
-    steps: List[Step]
+    input_messages: list[UserMessage | ToolResponseMessage]
+    steps: list[Step]
     output_message: CompletionMessage
-    output_attachments: Optional[List[Attachment]] = Field(default_factory=list)
+    output_attachments: list[Attachment] | None = Field(default_factory=list)
 
     started_at: datetime
-    completed_at: Optional[datetime] = None
+    completed_at: datetime | None = None
 
 
 @json_schema_type

@@ -186,34 +166,31 @@ class Session(BaseModel):
 
     session_id: str
     session_name: str
-    turns: List[Turn]
+    turns: list[Turn]
     started_at: datetime
 
 
 class AgentToolGroupWithArgs(BaseModel):
     name: str
-    args: Dict[str, Any]
+    args: dict[str, Any]
 
 
-AgentToolGroup = Union[
-    str,
-    AgentToolGroupWithArgs,
-]
+AgentToolGroup = str | AgentToolGroupWithArgs
 register_schema(AgentToolGroup, name="AgentTool")
 
 
 class AgentConfigCommon(BaseModel):
-    sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
+    sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
 
-    input_shields: Optional[List[str]] = Field(default_factory=list)
-    output_shields: Optional[List[str]] = Field(default_factory=list)
-    toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
-    client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
-    tool_choice: Optional[ToolChoice] = Field(default=None, deprecated="use tool_config instead")
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
-    tool_config: Optional[ToolConfig] = Field(default=None)
+    input_shields: list[str] | None = Field(default_factory=list)
+    output_shields: list[str] | None = Field(default_factory=list)
+    toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
+    client_tools: list[ToolDef] | None = Field(default_factory=list)
+    tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
+    tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
+    tool_config: ToolConfig | None = Field(default=None)
 
-    max_infer_iters: Optional[int] = 10
+    max_infer_iters: int | None = 10
 
     def model_post_init(self, __context):
         if self.tool_config:

@@ -243,9 +220,9 @@ class AgentConfig(AgentConfigCommon):
 
     model: str
     instructions: str
-    name: Optional[str] = None
-    enable_session_persistence: Optional[bool] = False
-    response_format: Optional[ResponseFormat] = None
+    name: str | None = None
+    enable_session_persistence: bool | None = False
+    response_format: ResponseFormat | None = None
 
 
 @json_schema_type

@@ -257,16 +234,16 @@ class Agent(BaseModel):
 
 @json_schema_type
 class ListAgentsResponse(BaseModel):
-    data: List[Agent]
+    data: list[Agent]
 
 
 @json_schema_type
 class ListAgentSessionsResponse(BaseModel):
-    data: List[Session]
+    data: list[Session]
 
 
 class AgentConfigOverridablePerTurn(AgentConfigCommon):
-    instructions: Optional[str] = None
+    instructions: str | None = None
 
 
 class AgentTurnResponseEventType(Enum):

@@ -284,7 +261,7 @@ class AgentTurnResponseStepStartPayload(BaseModel):
     event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
     step_type: StepType
     step_id: str
-    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
+    metadata: dict[str, Any] | None = Field(default_factory=dict)
 
 
 @json_schema_type

@@ -327,14 +304,12 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
 
 
 AgentTurnResponseEventPayload = Annotated[
-    Union[
-        AgentTurnResponseStepStartPayload,
-        AgentTurnResponseStepProgressPayload,
-        AgentTurnResponseStepCompletePayload,
-        AgentTurnResponseTurnStartPayload,
-        AgentTurnResponseTurnCompletePayload,
-        AgentTurnResponseTurnAwaitingInputPayload,
-    ],
+    AgentTurnResponseStepStartPayload
+    | AgentTurnResponseStepProgressPayload
+    | AgentTurnResponseStepCompletePayload
+    | AgentTurnResponseTurnStartPayload
+    | AgentTurnResponseTurnCompletePayload
+    | AgentTurnResponseTurnAwaitingInputPayload,
     Field(discriminator="event_type"),
 ]
 register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")

@@ -363,18 +338,13 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     # TODO: figure out how we can simplify this and make why
     # ToolResponseMessage needs to be here (it is function call
     # execution from outside the system)
-    messages: List[
-        Union[
-            UserMessage,
-            ToolResponseMessage,
-        ]
-    ]
+    messages: list[UserMessage | ToolResponseMessage]
 
-    documents: Optional[List[Document]] = None
-    toolgroups: Optional[List[AgentToolGroup]] = None
+    documents: list[Document] | None = None
+    toolgroups: list[AgentToolGroup] | None = None
 
-    stream: Optional[bool] = False
-    tool_config: Optional[ToolConfig] = None
+    stream: bool | None = False
+    tool_config: ToolConfig | None = None
 
 
 @json_schema_type

@@ -382,8 +352,8 @@ class AgentTurnResumeRequest(BaseModel):
     agent_id: str
     session_id: str
     turn_id: str
-    tool_responses: List[ToolResponse]
-    stream: Optional[bool] = False
+    tool_responses: list[ToolResponse]
+    stream: bool | None = False
 
 
 @json_schema_type

@@ -429,17 +399,12 @@ class Agents(Protocol):
         self,
         agent_id: str,
         session_id: str,
-        messages: List[
-            Union[
-                UserMessage,
-                ToolResponseMessage,
-            ]
-        ],
-        stream: Optional[bool] = False,
-        documents: Optional[List[Document]] = None,
-        toolgroups: Optional[List[AgentToolGroup]] = None,
-        tool_config: Optional[ToolConfig] = None,
-    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
+        messages: list[UserMessage | ToolResponseMessage],
+        stream: bool | None = False,
+        documents: list[Document] | None = None,
+        toolgroups: list[AgentToolGroup] | None = None,
+        tool_config: ToolConfig | None = None,
+    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
         """Create a new turn for an agent.
 
         :param agent_id: The ID of the agent to create the turn for.

@@ -463,9 +428,9 @@
         agent_id: str,
         session_id: str,
         turn_id: str,
-        tool_responses: List[ToolResponse],
-        stream: Optional[bool] = False,
-    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
+        tool_responses: list[ToolResponse],
+        stream: bool | None = False,
+    ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
         """Resume an agent turn with executed tool call responses.
 
         When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.

@@ -538,7 +503,7 @@
         self,
         session_id: str,
         agent_id: str,
-        turn_ids: Optional[List[str]] = None,
+        turn_ids: list[str] | None = None,
     ) -> Session:
         """Retrieve an agent session by its ID.
 

@@ -623,14 +588,14 @@
     @webmethod(route="/openai/v1/responses", method="POST")
     async def create_openai_response(
         self,
-        input: Union[str, List[OpenAIResponseInputMessage]],
+        input: str | list[OpenAIResponseInputMessage],
         model: str,
-        previous_response_id: Optional[str] = None,
-        store: Optional[bool] = True,
-        stream: Optional[bool] = False,
-        temperature: Optional[float] = None,
-        tools: Optional[List[OpenAIResponseInputTool]] = None,
-    ) -> Union[OpenAIResponseObject, AsyncIterator[OpenAIResponseObjectStream]]:
+        previous_response_id: str | None = None,
+        store: bool | None = True,
+        stream: bool | None = False,
+        temperature: float | None = None,
+        tools: list[OpenAIResponseInputTool] | None = None,
+    ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
         """Create a new OpenAI response.
 
         :param input: Input message(s) to create the response.
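Several of the rewritten signatures above return `Turn | AsyncIterator[AgentTurnResponseStreamChunk]`, with `AsyncIterator` now imported from `collections.abc` rather than `typing` (the `typing` alias has been deprecated since Python 3.9). A minimal sketch of that pattern, with a plain string standing in for the real stream-chunk type:

```python
import asyncio
from collections.abc import AsyncIterator


async def stream_chunks(n: int) -> AsyncIterator[str]:
    # An async generator satisfies an AsyncIterator[...] return annotation.
    for i in range(n):
        yield f"chunk-{i}"


async def main() -> None:
    async for chunk in stream_chunks(3):
        print(chunk)


asyncio.run(main())
```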
@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List, Literal, Optional, Union
+from typing import Annotated, Literal
 
 from pydantic import BaseModel, Field
-from typing_extensions import Annotated
 
 from llama_stack.schema_utils import json_schema_type, register_schema
 

@@ -25,7 +24,7 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
 
 
 OpenAIResponseOutputMessageContent = Annotated[
-    Union[OpenAIResponseOutputMessageContentOutputText,],
+    OpenAIResponseOutputMessageContentOutputText,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")

@@ -34,7 +33,7 @@ register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMe
 @json_schema_type
 class OpenAIResponseOutputMessage(BaseModel):
     id: str
-    content: List[OpenAIResponseOutputMessageContent]
+    content: list[OpenAIResponseOutputMessageContent]
     role: Literal["assistant"] = "assistant"
     status: str
     type: Literal["message"] = "message"

@@ -48,10 +47,7 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
 
 
 OpenAIResponseOutput = Annotated[
-    Union[
-        OpenAIResponseOutputMessage,
-        OpenAIResponseOutputMessageWebSearchToolCall,
-    ],
+    OpenAIResponseOutputMessage | OpenAIResponseOutputMessageWebSearchToolCall,
    Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")

@@ -60,18 +56,18 @@ register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
 @json_schema_type
 class OpenAIResponseObject(BaseModel):
     created_at: int
-    error: Optional[OpenAIResponseError] = None
+    error: OpenAIResponseError | None = None
     id: str
     model: str
     object: Literal["response"] = "response"
-    output: List[OpenAIResponseOutput]
+    output: list[OpenAIResponseOutput]
     parallel_tool_calls: bool = False
-    previous_response_id: Optional[str] = None
+    previous_response_id: str | None = None
     status: str
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
-    truncation: Optional[str] = None
-    user: Optional[str] = None
+    temperature: float | None = None
+    top_p: float | None = None
+    truncation: str | None = None
+    user: str | None = None
 
 
 @json_schema_type

@@ -87,10 +83,7 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
 
 
 OpenAIResponseObjectStream = Annotated[
-    Union[
-        OpenAIResponseObjectStreamResponseCreated,
-        OpenAIResponseObjectStreamResponseCompleted,
-    ],
+    OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")

@@ -107,12 +100,12 @@ class OpenAIResponseInputMessageContentImage(BaseModel):
     detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
     type: Literal["input_image"] = "input_image"
     # TODO: handle file_id
-    image_url: Optional[str] = None
+    image_url: str | None = None
 
 
 # TODO: handle file content types
 OpenAIResponseInputMessageContent = Annotated[
-    Union[OpenAIResponseInputMessageContentText, OpenAIResponseInputMessageContentImage],
+    OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")

@@ -120,21 +113,21 @@ register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMess
 
 @json_schema_type
 class OpenAIResponseInputMessage(BaseModel):
-    content: Union[str, List[OpenAIResponseInputMessageContent]]
+    content: str | list[OpenAIResponseInputMessageContent]
     role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
-    type: Optional[Literal["message"]] = "message"
+    type: Literal["message"] | None = "message"
 
 
 @json_schema_type
 class OpenAIResponseInputToolWebSearch(BaseModel):
     type: Literal["web_search"] | Literal["web_search_preview_2025_03_11"] = "web_search"
     # TODO: actually use search_context_size somewhere...
-    search_context_size: Optional[str] = Field(default="medium", pattern="^low|medium|high$")
+    search_context_size: str | None = Field(default="medium", pattern="^low|medium|high$")
     # TODO: add user_location
 
 
 OpenAIResponseInputTool = Annotated[
-    Union[OpenAIResponseInputToolWebSearch,],
+    OpenAIResponseInputToolWebSearch,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List, Optional, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable
 
 from llama_stack.apis.common.job_types import Job
 from llama_stack.apis.inference import (

@@ -34,22 +34,22 @@ class BatchInference(Protocol):
     async def completion(
         self,
         model: str,
-        content_batch: List[InterleavedContent],
-        sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        content_batch: list[InterleavedContent],
+        sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> Job: ...
 
     @webmethod(route="/batch-inference/chat-completion", method="POST")
     async def chat_completion(
         self,
         model: str,
-        messages_batch: List[List[Message]],
-        sampling_params: Optional[SamplingParams] = None,
+        messages_batch: list[list[Message]],
+        sampling_params: SamplingParams | None = None,
         # zero-shot tool definitions as input to the model
-        tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
-        response_format: Optional[ResponseFormat] = None,
-        logprobs: Optional[LogProbConfig] = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_choice: ToolChoice | None = ToolChoice.auto,
+        tool_prompt_format: ToolPromptFormat | None = None,
+        response_format: ResponseFormat | None = None,
+        logprobs: LogProbConfig | None = None,
     ) -> Job: ...
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
+from typing import Any, Literal, Protocol, runtime_checkable
 
 from pydantic import BaseModel, Field
 

@@ -13,8 +13,8 @@ from llama_stack.schema_utils import json_schema_type, webmethod
 
 class CommonBenchmarkFields(BaseModel):
     dataset_id: str
-    scoring_functions: List[str]
-    metadata: Dict[str, Any] = Field(
+    scoring_functions: list[str]
+    metadata: dict[str, Any] = Field(
         default_factory=dict,
         description="Metadata for this evaluation task",
     )

@@ -35,12 +35,12 @@ class Benchmark(CommonBenchmarkFields, Resource):
 
 class BenchmarkInput(CommonBenchmarkFields, BaseModel):
     benchmark_id: str
-    provider_id: Optional[str] = None
-    provider_benchmark_id: Optional[str] = None
+    provider_id: str | None = None
+    provider_benchmark_id: str | None = None
 
 
 class ListBenchmarksResponse(BaseModel):
-    data: List[Benchmark]
+    data: list[Benchmark]
 
 
 @runtime_checkable

@@ -59,8 +59,8 @@ class Benchmarks(Protocol):
         self,
         benchmark_id: str,
         dataset_id: str,
-        scoring_functions: List[str],
-        provider_benchmark_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        scoring_functions: list[str],
+        provider_benchmark_id: str | None = None,
+        provider_id: str | None = None,
+        metadata: dict[str, Any] | None = None,
     ) -> None: ...
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Annotated, List, Literal, Optional, Union
+from typing import Annotated, Literal
 
 from pydantic import BaseModel, Field, model_validator
 

@@ -26,9 +26,9 @@ class _URLOrData(BaseModel):
     :param data: base64 encoded image data as string
     """
 
-    url: Optional[URL] = None
+    url: URL | None = None
     # data is a base64 encoded string, hint with contentEncoding=base64
-    data: Optional[str] = Field(contentEncoding="base64", default=None)
+    data: str | None = Field(contentEncoding="base64", default=None)
 
     @model_validator(mode="before")
     @classmethod

@@ -64,13 +64,13 @@ class TextContentItem(BaseModel):
 
 # other modalities can be added here
 InterleavedContentItem = Annotated[
-    Union[ImageContentItem, TextContentItem],
+    ImageContentItem | TextContentItem,
     Field(discriminator="type"),
 ]
 register_schema(InterleavedContentItem, name="InterleavedContentItem")
 
 # accept a single "str" as a special case since it is common
-InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
+InterleavedContent = str | InterleavedContentItem | list[InterleavedContentItem]
 register_schema(InterleavedContent, name="InterleavedContent")
 
 

@@ -100,13 +100,13 @@ class ToolCallDelta(BaseModel):
     # you either send an in-progress tool call so the client can stream a long
     # code generation or you send the final parsed tool call at the end of the
     # stream
-    tool_call: Union[str, ToolCall]
+    tool_call: str | ToolCall
     parse_status: ToolCallParseStatus
 
 
 # streaming completions send a stream of ContentDeltas
 ContentDelta = Annotated[
-    Union[TextDelta, ImageDelta, ToolCallDelta],
+    TextDelta | ImageDelta | ToolCallDelta,
     Field(discriminator="type"),
 ]
 register_schema(ContentDelta, name="ContentDelta")
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Any
 
 from pydantic import BaseModel
 

@@ -25,6 +25,6 @@ class RestAPIMethod(Enum):
 class RestAPIExecutionConfig(BaseModel):
     url: URL
     method: RestAPIMethod
-    params: Optional[Dict[str, Any]] = None
-    headers: Optional[Dict[str, Any]] = None
-    body: Optional[Dict[str, Any]] = None
+    params: dict[str, Any] | None = None
+    headers: dict[str, Any] | None = None
+    body: dict[str, Any] | None = None
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List
+from typing import Any
 
 from pydantic import BaseModel
 

@@ -19,5 +19,5 @@ class PaginatedResponse(BaseModel):
     :param has_more: Whether there are more items available after this set
     """
 
-    data: List[Dict[str, Any]]
+    data: list[dict[str, Any]]
     has_more: bool
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 from datetime import datetime
-from typing import Optional
 
 from pydantic import BaseModel
 

@@ -27,4 +26,4 @@ class Checkpoint(BaseModel):
     epoch: int
     post_training_job_id: str
     path: str
-    training_metrics: Optional[PostTrainingMetric] = None
+    training_metrics: PostTrainingMetric | None = None
@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Literal, Union
+from typing import Annotated, Literal
 
 from pydantic import BaseModel, Field
-from typing_extensions import Annotated
 
 from llama_stack.schema_utils import json_schema_type, register_schema
 

@@ -73,18 +72,16 @@ class DialogType(BaseModel):
 
 
 ParamType = Annotated[
-    Union[
-        StringType,
-        NumberType,
-        BooleanType,
-        ArrayType,
-        ObjectType,
-        JsonType,
-        UnionType,
-        ChatCompletionInputType,
-        CompletionInputType,
-        AgentTurnInputType,
-    ],
+    StringType
+    | NumberType
+    | BooleanType
+    | ArrayType
+    | ObjectType
+    | JsonType
+    | UnionType
+    | ChatCompletionInputType
+    | CompletionInputType
+    | AgentTurnInputType,
     Field(discriminator="type"),
 ]
 register_schema(ParamType, name="ParamType")
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from typing import Any, Protocol, runtime_checkable
 
 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.datasets import Dataset

@@ -24,8 +24,8 @@ class DatasetIO(Protocol):
     async def iterrows(
         self,
         dataset_id: str,
-        start_index: Optional[int] = None,
-        limit: Optional[int] = None,
+        start_index: int | None = None,
+        limit: int | None = None,
     ) -> PaginatedResponse:
         """Get a paginated list of rows from a dataset.
 

@@ -44,4 +44,4 @@ class DatasetIO(Protocol):
         ...
 
     @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
-    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
+    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: ...
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Annotated, Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol
 
 from pydantic import BaseModel, Field
 

@@ -81,11 +81,11 @@ class RowsDataSource(BaseModel):
     """
 
     type: Literal["rows"] = "rows"
-    rows: List[Dict[str, Any]]
+    rows: list[dict[str, Any]]
 
 
 DataSource = Annotated[
-    Union[URIDataSource, RowsDataSource],
+    URIDataSource | RowsDataSource,
     Field(discriminator="type"),
 ]
 register_schema(DataSource, name="DataSource")

@@ -98,7 +98,7 @@ class CommonDatasetFields(BaseModel):
 
     purpose: DatasetPurpose
     source: DataSource
-    metadata: Dict[str, Any] = Field(
+    metadata: dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional metadata for this dataset",
     )

@@ -122,7 +122,7 @@ class DatasetInput(CommonDatasetFields, BaseModel):
 
 
 class ListDatasetsResponse(BaseModel):
-    data: List[Dataset]
+    data: list[Dataset]
 
 
 class Datasets(Protocol):

@@ -131,8 +131,8 @@ class Datasets(Protocol):
         self,
         purpose: DatasetPurpose,
         source: DataSource,
-        metadata: Optional[Dict[str, Any]] = None,
-        dataset_id: Optional[str] = None,
+        metadata: dict[str, Any] | None = None,
+        dataset_id: str | None = None,
     ) -> Dataset:
         """
         Register a new dataset.
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Optional
 
 from pydantic import BaseModel
 

@@ -54,4 +53,4 @@ class Error(BaseModel):
     status: int
     title: str
     detail: str
-    instance: Optional[str] = None
+    instance: str | None = None
@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol
 
 from pydantic import BaseModel, Field
-from typing_extensions import Annotated
 
 from llama_stack.apis.agents import AgentConfig
 from llama_stack.apis.common.job_types import Job

@@ -29,7 +28,7 @@ class ModelCandidate(BaseModel):
     type: Literal["model"] = "model"
     model: str
     sampling_params: SamplingParams
-    system_message: Optional[SystemMessage] = None
+    system_message: SystemMessage | None = None
 
 
 @json_schema_type

@@ -43,7 +42,7 @@ class AgentCandidate(BaseModel):
     config: AgentConfig
 
 
-EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
+EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
 register_schema(EvalCandidate, name="EvalCandidate")
 
 

@@ -57,11 +56,11 @@ class BenchmarkConfig(BaseModel):
     """
 
     eval_candidate: EvalCandidate
-    scoring_params: Dict[str, ScoringFnParams] = Field(
+    scoring_params: dict[str, ScoringFnParams] = Field(
         description="Map between scoring function id and parameters for each scoring function you want to run",
         default_factory=dict,
     )
-    num_examples: Optional[int] = Field(
+    num_examples: int | None = Field(
         description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
         default=None,
     )

@@ -76,9 +75,9 @@ class EvaluateResponse(BaseModel):
     :param scores: The scores from the evaluation.
     """
 
-    generations: List[Dict[str, Any]]
+    generations: list[dict[str, Any]]
     # each key in the dict is a scoring function name
-    scores: Dict[str, ScoringResult]
+    scores: dict[str, ScoringResult]
 
 
 class Eval(Protocol):

@@ -101,8 +100,8 @@ class Eval(Protocol):
     async def evaluate_rows(
         self,
         benchmark_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
+        input_rows: list[dict[str, Any]],
+        scoring_functions: list[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
         """Evaluate a list of rows on a benchmark.
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List, Optional, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable
 
 from pydantic import BaseModel
 

@@ -42,7 +42,7 @@ class ListBucketResponse(BaseModel):
     :param data: List of FileResponse entries
     """
 
-    data: List[BucketResponse]
+    data: list[BucketResponse]
 
 
 @json_schema_type

@@ -74,7 +74,7 @@ class ListFileResponse(BaseModel):
     :param data: List of FileResponse entries
     """
 
-    data: List[FileResponse]
+    data: list[FileResponse]
 
 
 @runtime_checkable

@@ -102,7 +102,7 @@ class Files(Protocol):
     async def upload_content_to_session(
         self,
         upload_id: str,
-    ) -> Optional[FileResponse]:
+    ) -> FileResponse | None:
         """
         Upload file content to an existing upload session.
         On the server, request body will have the raw bytes that are uploaded.
@ -4,21 +4,18 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import (
|
from typing import (
|
||||||
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
|
||||||
Protocol,
|
Protocol,
|
||||||
Union,
|
|
||||||
runtime_checkable,
|
runtime_checkable,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from typing_extensions import Annotated, TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
|
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
|
@ -47,8 +44,8 @@ class GreedySamplingStrategy(BaseModel):
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class TopPSamplingStrategy(BaseModel):
|
class TopPSamplingStrategy(BaseModel):
|
||||||
type: Literal["top_p"] = "top_p"
|
type: Literal["top_p"] = "top_p"
|
||||||
temperature: Optional[float] = Field(..., gt=0.0)
|
temperature: float | None = Field(..., gt=0.0)
|
||||||
top_p: Optional[float] = 0.95
|
top_p: float | None = 0.95
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -58,7 +55,7 @@ class TopKSamplingStrategy(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
SamplingStrategy = Annotated[
|
SamplingStrategy = Annotated[
|
||||||
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
|
GreedySamplingStrategy | TopPSamplingStrategy | TopKSamplingStrategy,
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(SamplingStrategy, name="SamplingStrategy")
|
register_schema(SamplingStrategy, name="SamplingStrategy")
|
||||||
|
@ -79,9 +76,9 @@ class SamplingParams(BaseModel):
|
||||||
|
|
||||||
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
|
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
|
||||||
|
|
||||||
max_tokens: Optional[int] = 0
|
max_tokens: int | None = 0
|
||||||
repetition_penalty: Optional[float] = 1.0
|
repetition_penalty: float | None = 1.0
|
||||||
stop: Optional[List[str]] = None
|
stop: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class LogProbConfig(BaseModel):
|
class LogProbConfig(BaseModel):
|
||||||
|
@ -90,7 +87,7 @@ class LogProbConfig(BaseModel):
|
||||||
:param top_k: How many tokens (for each position) to return log probabilities for.
|
:param top_k: How many tokens (for each position) to return log probabilities for.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
top_k: Optional[int] = 0
|
top_k: int | None = 0
|
||||||
|
|
||||||
|
|
||||||
class QuantizationType(Enum):
|
class QuantizationType(Enum):
|
||||||
|
@ -125,11 +122,11 @@ class Int4QuantizationConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: Literal["int4_mixed"] = "int4_mixed"
|
type: Literal["int4_mixed"] = "int4_mixed"
|
||||||
scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
|
scheme: str | None = "int4_weight_int8_dynamic_activation"
|
||||||
|
|
||||||
|
|
||||||
QuantizationConfig = Annotated[
|
QuantizationConfig = Annotated[
|
||||||
Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
|
Bf16QuantizationConfig | Fp8QuantizationConfig | Int4QuantizationConfig,
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -145,7 +142,7 @@ class UserMessage(BaseModel):
|
||||||
|
|
||||||
role: Literal["user"] = "user"
|
role: Literal["user"] = "user"
|
||||||
content: InterleavedContent
|
content: InterleavedContent
|
||||||
context: Optional[InterleavedContent] = None
|
context: InterleavedContent | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -190,16 +187,11 @@ class CompletionMessage(BaseModel):
|
||||||
role: Literal["assistant"] = "assistant"
|
role: Literal["assistant"] = "assistant"
|
||||||
content: InterleavedContent
|
content: InterleavedContent
|
||||||
stop_reason: StopReason
|
stop_reason: StopReason
|
||||||
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
|
tool_calls: list[ToolCall] | None = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
Message = Annotated[
|
Message = Annotated[
|
||||||
Union[
|
UserMessage | SystemMessage | ToolResponseMessage | CompletionMessage,
|
||||||
UserMessage,
|
|
||||||
SystemMessage,
|
|
||||||
ToolResponseMessage,
|
|
||||||
CompletionMessage,
|
|
||||||
],
|
|
||||||
Field(discriminator="role"),
|
Field(discriminator="role"),
|
||||||
]
|
]
|
||||||
register_schema(Message, name="Message")
|
register_schema(Message, name="Message")
|
||||||
|
@ -208,9 +200,9 @@ register_schema(Message, name="Message")
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ToolResponse(BaseModel):
|
class ToolResponse(BaseModel):
|
||||||
call_id: str
|
call_id: str
|
||||||
tool_name: Union[BuiltinTool, str]
|
tool_name: BuiltinTool | str
|
||||||
content: InterleavedContent
|
content: InterleavedContent
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
@field_validator("tool_name", mode="before")
|
@field_validator("tool_name", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -243,7 +235,7 @@ class TokenLogProbs(BaseModel):
|
||||||
:param logprobs_by_token: Dictionary mapping tokens to their log probabilities
|
:param logprobs_by_token: Dictionary mapping tokens to their log probabilities
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logprobs_by_token: Dict[str, float]
|
logprobs_by_token: dict[str, float]
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionResponseEventType(Enum):
|
class ChatCompletionResponseEventType(Enum):
|
||||||
|
@ -271,8 +263,8 @@ class ChatCompletionResponseEvent(BaseModel):
|
||||||
|
|
||||||
event_type: ChatCompletionResponseEventType
|
event_type: ChatCompletionResponseEventType
|
||||||
delta: ContentDelta
|
delta: ContentDelta
|
||||||
logprobs: Optional[List[TokenLogProbs]] = None
|
logprobs: list[TokenLogProbs] | None = None
|
||||||
stop_reason: Optional[StopReason] = None
|
stop_reason: StopReason | None = None
|
||||||
|
|
||||||
|
|
||||||
class ResponseFormatType(Enum):
|
class ResponseFormatType(Enum):
|
||||||
|
@ -295,7 +287,7 @@ class JsonSchemaResponseFormat(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
|
type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
|
||||||
json_schema: Dict[str, Any]
|
json_schema: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -307,11 +299,11 @@ class GrammarResponseFormat(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
|
type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
|
||||||
bnf: Dict[str, Any]
|
bnf: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
ResponseFormat = Annotated[
|
ResponseFormat = Annotated[
|
||||||
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
|
JsonSchemaResponseFormat | GrammarResponseFormat,
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(ResponseFormat, name="ResponseFormat")
|
register_schema(ResponseFormat, name="ResponseFormat")
|
||||||
|
@ -321,10 +313,10 @@ register_schema(ResponseFormat, name="ResponseFormat")
|
||||||
class CompletionRequest(BaseModel):
|
class CompletionRequest(BaseModel):
|
||||||
model: str
|
model: str
|
||||||
content: InterleavedContent
|
content: InterleavedContent
|
||||||
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
|
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
|
||||||
response_format: Optional[ResponseFormat] = None
|
response_format: ResponseFormat | None = None
|
||||||
stream: Optional[bool] = False
|
stream: bool | None = False
|
||||||
logprobs: Optional[LogProbConfig] = None
|
logprobs: LogProbConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -338,7 +330,7 @@ class CompletionResponse(MetricResponseMixin):
|
||||||
|
|
||||||
content: str
|
content: str
|
||||||
stop_reason: StopReason
|
stop_reason: StopReason
|
||||||
logprobs: Optional[List[TokenLogProbs]] = None
|
logprobs: list[TokenLogProbs] | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -351,8 +343,8 @@ class CompletionResponseStreamChunk(MetricResponseMixin):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
delta: str
|
delta: str
|
||||||
stop_reason: Optional[StopReason] = None
|
stop_reason: StopReason | None = None
|
||||||
logprobs: Optional[List[TokenLogProbs]] = None
|
logprobs: list[TokenLogProbs] | None = None
|
||||||
|
|
||||||
|
|
||||||
class SystemMessageBehavior(Enum):
|
class SystemMessageBehavior(Enum):
|
||||||
|
@ -383,9 +375,9 @@ class ToolConfig(BaseModel):
|
||||||
'{{function_definitions}}' to indicate where the function definitions should be inserted.
|
'{{function_definitions}}' to indicate where the function definitions should be inserted.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tool_choice: Optional[ToolChoice | str] = Field(default=ToolChoice.auto)
|
tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
|
tool_prompt_format: ToolPromptFormat | None = Field(default=None)
|
||||||
system_message_behavior: Optional[SystemMessageBehavior] = Field(default=SystemMessageBehavior.append)
|
system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
|
||||||
|
|
||||||
def model_post_init(self, __context: Any) -> None:
|
def model_post_init(self, __context: Any) -> None:
|
||||||
if isinstance(self.tool_choice, str):
|
if isinstance(self.tool_choice, str):
|
||||||
|
@ -399,15 +391,15 @@ class ToolConfig(BaseModel):
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ChatCompletionRequest(BaseModel):
|
class ChatCompletionRequest(BaseModel):
|
||||||
model: str
|
model: str
|
||||||
messages: List[Message]
|
messages: list[Message]
|
||||||
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
|
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
|
||||||
|
|
||||||
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
|
tools: list[ToolDefinition] | None = Field(default_factory=list)
|
||||||
tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
|
tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
|
||||||
|
|
||||||
response_format: Optional[ResponseFormat] = None
|
response_format: ResponseFormat | None = None
|
||||||
stream: Optional[bool] = False
|
stream: bool | None = False
|
||||||
logprobs: Optional[LogProbConfig] = None
|
logprobs: LogProbConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -429,7 +421,7 @@ class ChatCompletionResponse(MetricResponseMixin):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
completion_message: CompletionMessage
|
completion_message: CompletionMessage
|
||||||
logprobs: Optional[List[TokenLogProbs]] = None
|
logprobs: list[TokenLogProbs] | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -439,7 +431,7 @@ class EmbeddingsResponse(BaseModel):
|
||||||
:param embeddings: List of embedding vectors, one per input content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
|
:param embeddings: List of embedding vectors, one per input content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
embeddings: List[List[float]]
|
embeddings: list[list[float]]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -451,7 +443,7 @@ class OpenAIChatCompletionContentPartTextParam(BaseModel):
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OpenAIImageURL(BaseModel):
|
class OpenAIImageURL(BaseModel):
|
||||||
url: str
|
url: str
|
||||||
detail: Optional[str] = None
|
detail: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -461,16 +453,13 @@ class OpenAIChatCompletionContentPartImageParam(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
OpenAIChatCompletionContentPartParam = Annotated[
|
OpenAIChatCompletionContentPartParam = Annotated[
|
||||||
Union[
|
OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||||
OpenAIChatCompletionContentPartTextParam,
|
|
||||||
OpenAIChatCompletionContentPartImageParam,
|
|
||||||
],
|
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
|
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
|
||||||
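
A short sketch (not part of the diff) of the discriminated-union pattern shown above, now written with the `|` operator inside `Annotated` instead of `Union[...]`. It assumes Pydantic v2 (`TypeAdapter`); `TextPart` and `ImagePart` are hypothetical stand-ins for the content-part models.

# Illustrative sketch, not part of the diff: a discriminated union written with
# the `|` operator behaves the same as typing.Union inside Annotated.
from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class TextPart(BaseModel):  # hypothetical stand-in
    type: Literal["text"] = "text"
    text: str


class ImagePart(BaseModel):  # hypothetical stand-in
    type: Literal["image_url"] = "image_url"
    url: str


ContentPart = Annotated[TextPart | ImagePart, Field(discriminator="type")]

part = TypeAdapter(ContentPart).validate_python({"type": "text", "text": "hi"})
assert isinstance(part, TextPart)
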
|
|
||||||
|
|
||||||
OpenAIChatCompletionMessageContent = Union[str, List[OpenAIChatCompletionContentPartParam]]
|
OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -484,7 +473,7 @@ class OpenAIUserMessageParam(BaseModel):
|
||||||
|
|
||||||
role: Literal["user"] = "user"
|
role: Literal["user"] = "user"
|
||||||
content: OpenAIChatCompletionMessageContent
|
content: OpenAIChatCompletionMessageContent
|
||||||
name: Optional[str] = None
|
name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -498,21 +487,21 @@ class OpenAISystemMessageParam(BaseModel):
|
||||||
|
|
||||||
role: Literal["system"] = "system"
|
role: Literal["system"] = "system"
|
||||||
content: OpenAIChatCompletionMessageContent
|
content: OpenAIChatCompletionMessageContent
|
||||||
name: Optional[str] = None
|
name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OpenAIChatCompletionToolCallFunction(BaseModel):
|
class OpenAIChatCompletionToolCallFunction(BaseModel):
|
||||||
name: Optional[str] = None
|
name: str | None = None
|
||||||
arguments: Optional[str] = None
|
arguments: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OpenAIChatCompletionToolCall(BaseModel):
|
class OpenAIChatCompletionToolCall(BaseModel):
|
||||||
index: Optional[int] = None
|
index: int | None = None
|
||||||
id: Optional[str] = None
|
id: str | None = None
|
||||||
type: Literal["function"] = "function"
|
type: Literal["function"] = "function"
|
||||||
function: Optional[OpenAIChatCompletionToolCallFunction] = None
|
function: OpenAIChatCompletionToolCallFunction | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -526,9 +515,9 @@ class OpenAIAssistantMessageParam(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
role: Literal["assistant"] = "assistant"
|
role: Literal["assistant"] = "assistant"
|
||||||
content: Optional[OpenAIChatCompletionMessageContent] = None
|
content: OpenAIChatCompletionMessageContent | None = None
|
||||||
name: Optional[str] = None
|
name: str | None = None
|
||||||
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
|
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -556,17 +545,15 @@ class OpenAIDeveloperMessageParam(BaseModel):
|
||||||
|
|
||||||
role: Literal["developer"] = "developer"
|
role: Literal["developer"] = "developer"
|
||||||
content: OpenAIChatCompletionMessageContent
|
content: OpenAIChatCompletionMessageContent
|
||||||
name: Optional[str] = None
|
name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
OpenAIMessageParam = Annotated[
|
OpenAIMessageParam = Annotated[
|
||||||
Union[
|
OpenAIUserMessageParam
|
||||||
OpenAIUserMessageParam,
|
| OpenAISystemMessageParam
|
||||||
OpenAISystemMessageParam,
|
| OpenAIAssistantMessageParam
|
||||||
OpenAIAssistantMessageParam,
|
| OpenAIToolMessageParam
|
||||||
OpenAIToolMessageParam,
|
| OpenAIDeveloperMessageParam,
|
||||||
OpenAIDeveloperMessageParam,
|
|
||||||
],
|
|
||||||
Field(discriminator="role"),
|
Field(discriminator="role"),
|
||||||
]
|
]
|
||||||
register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
|
register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
|
||||||
|
@ -580,14 +567,14 @@ class OpenAIResponseFormatText(BaseModel):
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OpenAIJSONSchema(TypedDict, total=False):
|
class OpenAIJSONSchema(TypedDict, total=False):
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
strict: Optional[bool] = None
|
strict: bool | None = None
|
||||||
|
|
||||||
# Pydantic BaseModel cannot be used with a schema param, since it already
|
# Pydantic BaseModel cannot be used with a schema param, since it already
|
||||||
# has one. And, we don't want to alias here because then have to handle
|
# has one. And, we don't want to alias here because then have to handle
|
||||||
# that alias when converting to OpenAI params. So, to support schema,
|
# that alias when converting to OpenAI params. So, to support schema,
|
||||||
# we use a TypedDict.
|
# we use a TypedDict.
|
||||||
schema: Optional[Dict[str, Any]] = None
|
schema: dict[str, Any] | None = None
|
||||||
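
A small sketch (not part of the diff) of the `TypedDict` pattern used above: with `total=False`, every key is optional, and the value annotations can use built-in generics and `| None` just like the Pydantic fields elsewhere in this change. The class and keys below are hypothetical and only mirror the shape of the real type.

# Illustrative sketch, not part of the diff: total=False makes every key optional,
# so a plain dict containing any subset of the typed keys type-checks.
from typing import Any, TypedDict


class ExampleJSONSchema(TypedDict, total=False):  # hypothetical
    name: str
    description: str | None
    schema: dict[str, Any] | None


fmt: ExampleJSONSchema = {"name": "my_schema"}  # other keys may be omitted
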
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -602,11 +589,7 @@ class OpenAIResponseFormatJSONObject(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
OpenAIResponseFormatParam = Annotated[
|
OpenAIResponseFormatParam = Annotated[
|
||||||
Union[
|
OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
|
||||||
OpenAIResponseFormatText,
|
|
||||||
OpenAIResponseFormatJSONSchema,
|
|
||||||
OpenAIResponseFormatJSONObject,
|
|
||||||
],
|
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
|
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
|
||||||
|
@ -622,7 +605,7 @@ class OpenAITopLogProb(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
token: str
|
token: str
|
||||||
bytes: Optional[List[int]] = None
|
bytes: list[int] | None = None
|
||||||
logprob: float
|
logprob: float
|
||||||
|
|
||||||
|
|
||||||
|
@ -637,9 +620,9 @@ class OpenAITokenLogProb(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
token: str
|
token: str
|
||||||
bytes: Optional[List[int]] = None
|
bytes: list[int] | None = None
|
||||||
logprob: float
|
logprob: float
|
||||||
top_logprobs: List[OpenAITopLogProb]
|
top_logprobs: list[OpenAITopLogProb]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -650,8 +633,8 @@ class OpenAIChoiceLogprobs(BaseModel):
|
||||||
:param refusal: (Optional) The log probabilities for the tokens in the message
|
:param refusal: (Optional) The log probabilities for the tokens in the message
|
||||||
"""
|
"""
|
||||||
|
|
||||||
content: Optional[List[OpenAITokenLogProb]] = None
|
content: list[OpenAITokenLogProb] | None = None
|
||||||
refusal: Optional[List[OpenAITokenLogProb]] = None
|
refusal: list[OpenAITokenLogProb] | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -664,10 +647,10 @@ class OpenAIChoiceDelta(BaseModel):
|
||||||
:param tool_calls: (Optional) The tool calls of the delta
|
:param tool_calls: (Optional) The tool calls of the delta
|
||||||
"""
|
"""
|
||||||
|
|
||||||
content: Optional[str] = None
|
content: str | None = None
|
||||||
refusal: Optional[str] = None
|
refusal: str | None = None
|
||||||
role: Optional[str] = None
|
role: str | None = None
|
||||||
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
|
tool_calls: list[OpenAIChatCompletionToolCall] | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -683,7 +666,7 @@ class OpenAIChunkChoice(BaseModel):
|
||||||
delta: OpenAIChoiceDelta
|
delta: OpenAIChoiceDelta
|
||||||
finish_reason: str
|
finish_reason: str
|
||||||
index: int
|
index: int
|
||||||
logprobs: Optional[OpenAIChoiceLogprobs] = None
|
logprobs: OpenAIChoiceLogprobs | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -699,7 +682,7 @@ class OpenAIChoice(BaseModel):
|
||||||
message: OpenAIMessageParam
|
message: OpenAIMessageParam
|
||||||
finish_reason: str
|
finish_reason: str
|
||||||
index: int
|
index: int
|
||||||
logprobs: Optional[OpenAIChoiceLogprobs] = None
|
logprobs: OpenAIChoiceLogprobs | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -714,7 +697,7 @@ class OpenAIChatCompletion(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
choices: List[OpenAIChoice]
|
choices: list[OpenAIChoice]
|
||||||
object: Literal["chat.completion"] = "chat.completion"
|
object: Literal["chat.completion"] = "chat.completion"
|
||||||
created: int
|
created: int
|
||||||
model: str
|
model: str
|
||||||
|
@ -732,7 +715,7 @@ class OpenAIChatCompletionChunk(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
choices: List[OpenAIChunkChoice]
|
choices: list[OpenAIChunkChoice]
|
||||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||||
created: int
|
created: int
|
||||||
model: str
|
model: str
|
||||||
|
@ -748,10 +731,10 @@ class OpenAICompletionLogprobs(BaseModel):
|
||||||
:top_logprobs: (Optional) The top log probabilities for the tokens
|
:top_logprobs: (Optional) The top log probabilities for the tokens
|
||||||
"""
|
"""
|
||||||
|
|
||||||
text_offset: Optional[List[int]] = None
|
text_offset: list[int] | None = None
|
||||||
token_logprobs: Optional[List[float]] = None
|
token_logprobs: list[float] | None = None
|
||||||
tokens: Optional[List[str]] = None
|
tokens: list[str] | None = None
|
||||||
top_logprobs: Optional[List[Dict[str, float]]] = None
|
top_logprobs: list[dict[str, float]] | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -767,7 +750,7 @@ class OpenAICompletionChoice(BaseModel):
|
||||||
finish_reason: str
|
finish_reason: str
|
||||||
text: str
|
text: str
|
||||||
index: int
|
index: int
|
||||||
logprobs: Optional[OpenAIChoiceLogprobs] = None
|
logprobs: OpenAIChoiceLogprobs | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -782,7 +765,7 @@ class OpenAICompletion(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
choices: List[OpenAICompletionChoice]
|
choices: list[OpenAICompletionChoice]
|
||||||
created: int
|
created: int
|
||||||
model: str
|
model: str
|
||||||
object: Literal["text_completion"] = "text_completion"
|
object: Literal["text_completion"] = "text_completion"
|
||||||
|
@ -818,12 +801,12 @@ class EmbeddingTaskType(Enum):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class BatchCompletionResponse(BaseModel):
|
class BatchCompletionResponse(BaseModel):
|
||||||
batch: List[CompletionResponse]
|
batch: list[CompletionResponse]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class BatchChatCompletionResponse(BaseModel):
|
class BatchChatCompletionResponse(BaseModel):
|
||||||
batch: List[ChatCompletionResponse]
|
batch: list[ChatCompletionResponse]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
@ -843,11 +826,11 @@ class Inference(Protocol):
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
content: InterleavedContent,
|
content: InterleavedContent,
|
||||||
sampling_params: Optional[SamplingParams] = None,
|
sampling_params: SamplingParams | None = None,
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: ResponseFormat | None = None,
|
||||||
stream: Optional[bool] = False,
|
stream: bool | None = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: LogProbConfig | None = None,
|
||||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
|
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
|
||||||
"""Generate a completion for the given content using the specified model.
|
"""Generate a completion for the given content using the specified model.
|
||||||
|
|
||||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||||
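
A brief sketch (not part of the diff) of the `X | AsyncIterator[Y]` return annotation used by the streaming methods above: an async generator satisfies the iterator half of the union, and callers branch on the concrete type. The function names are hypothetical; `AsyncIterator` is assumed to come from `collections.abc` on modern Python.

# Illustrative sketch, not part of the diff: an async generator function
# satisfies an `X | AsyncIterator[Y]` return annotation for the streaming case.
import asyncio
from collections.abc import AsyncIterator


async def fake_completion(stream: bool) -> str | AsyncIterator[str]:  # hypothetical
    if not stream:
        return "full response"

    async def chunks() -> AsyncIterator[str]:
        for piece in ("par", "tial"):
            yield piece

    return chunks()


async def main() -> None:
    result = await fake_completion(stream=True)
    if isinstance(result, str):
        print(result)
    else:
        print([c async for c in result])


asyncio.run(main())
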
|
@ -865,10 +848,10 @@ class Inference(Protocol):
|
||||||
async def batch_completion(
|
async def batch_completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
content_batch: List[InterleavedContent],
|
content_batch: list[InterleavedContent],
|
||||||
sampling_params: Optional[SamplingParams] = None,
|
sampling_params: SamplingParams | None = None,
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: ResponseFormat | None = None,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: LogProbConfig | None = None,
|
||||||
) -> BatchCompletionResponse:
|
) -> BatchCompletionResponse:
|
||||||
raise NotImplementedError("Batch completion is not implemented")
|
raise NotImplementedError("Batch completion is not implemented")
|
||||||
|
|
||||||
|
@ -876,16 +859,16 @@ class Inference(Protocol):
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
messages: List[Message],
|
messages: list[Message],
|
||||||
sampling_params: Optional[SamplingParams] = None,
|
sampling_params: SamplingParams | None = None,
|
||||||
tools: Optional[List[ToolDefinition]] = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
tool_prompt_format: ToolPromptFormat | None = None,
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: ResponseFormat | None = None,
|
||||||
stream: Optional[bool] = False,
|
stream: bool | None = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: LogProbConfig | None = None,
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: ToolConfig | None = None,
|
||||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||||
"""Generate a chat completion for the given messages using the specified model.
|
"""Generate a chat completion for the given messages using the specified model.
|
||||||
|
|
||||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||||
|
@ -916,12 +899,12 @@ class Inference(Protocol):
|
||||||
async def batch_chat_completion(
|
async def batch_chat_completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
messages_batch: List[List[Message]],
|
messages_batch: list[list[Message]],
|
||||||
sampling_params: Optional[SamplingParams] = None,
|
sampling_params: SamplingParams | None = None,
|
||||||
tools: Optional[List[ToolDefinition]] = None,
|
tools: list[ToolDefinition] | None = None,
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: ToolConfig | None = None,
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: ResponseFormat | None = None,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: LogProbConfig | None = None,
|
||||||
) -> BatchChatCompletionResponse:
|
) -> BatchChatCompletionResponse:
|
||||||
raise NotImplementedError("Batch chat completion is not implemented")
|
raise NotImplementedError("Batch chat completion is not implemented")
|
||||||
|
|
||||||
|
@ -929,10 +912,10 @@ class Inference(Protocol):
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
contents: list[str] | list[InterleavedContentItem],
|
||||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||||
output_dimension: Optional[int] = None,
|
output_dimension: int | None = None,
|
||||||
task_type: Optional[EmbeddingTaskType] = None,
|
task_type: EmbeddingTaskType | None = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
"""Generate embeddings for content pieces using the specified model.
|
"""Generate embeddings for content pieces using the specified model.
|
||||||
|
|
||||||
|
@ -950,25 +933,25 @@ class Inference(Protocol):
|
||||||
self,
|
self,
|
||||||
# Standard OpenAI completion parameters
|
# Standard OpenAI completion parameters
|
||||||
model: str,
|
model: str,
|
||||||
prompt: Union[str, List[str], List[int], List[List[int]]],
|
prompt: str | list[str] | list[int] | list[list[int]],
|
||||||
best_of: Optional[int] = None,
|
best_of: int | None = None,
|
||||||
echo: Optional[bool] = None,
|
echo: bool | None = None,
|
||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: float | None = None,
|
||||||
logit_bias: Optional[Dict[str, float]] = None,
|
logit_bias: dict[str, float] | None = None,
|
||||||
logprobs: Optional[bool] = None,
|
logprobs: bool | None = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: int | None = None,
|
||||||
n: Optional[int] = None,
|
n: int | None = None,
|
||||||
presence_penalty: Optional[float] = None,
|
presence_penalty: float | None = None,
|
||||||
seed: Optional[int] = None,
|
seed: int | None = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: str | list[str] | None = None,
|
||||||
stream: Optional[bool] = None,
|
stream: bool | None = None,
|
||||||
stream_options: Optional[Dict[str, Any]] = None,
|
stream_options: dict[str, Any] | None = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: float | None = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: float | None = None,
|
||||||
user: Optional[str] = None,
|
user: str | None = None,
|
||||||
# vLLM-specific parameters
|
# vLLM-specific parameters
|
||||||
guided_choice: Optional[List[str]] = None,
|
guided_choice: list[str] | None = None,
|
||||||
prompt_logprobs: Optional[int] = None,
|
prompt_logprobs: int | None = None,
|
||||||
) -> OpenAICompletion:
|
) -> OpenAICompletion:
|
||||||
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
|
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
|
||||||
|
|
||||||
|
@ -996,29 +979,29 @@ class Inference(Protocol):
|
||||||
async def openai_chat_completion(
|
async def openai_chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[OpenAIMessageParam],
|
messages: list[OpenAIMessageParam],
|
||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: float | None = None,
|
||||||
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
function_call: str | dict[str, Any] | None = None,
|
||||||
functions: Optional[List[Dict[str, Any]]] = None,
|
functions: list[dict[str, Any]] | None = None,
|
||||||
logit_bias: Optional[Dict[str, float]] = None,
|
logit_bias: dict[str, float] | None = None,
|
||||||
logprobs: Optional[bool] = None,
|
logprobs: bool | None = None,
|
||||||
max_completion_tokens: Optional[int] = None,
|
max_completion_tokens: int | None = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: int | None = None,
|
||||||
n: Optional[int] = None,
|
n: int | None = None,
|
||||||
parallel_tool_calls: Optional[bool] = None,
|
parallel_tool_calls: bool | None = None,
|
||||||
presence_penalty: Optional[float] = None,
|
presence_penalty: float | None = None,
|
||||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
response_format: OpenAIResponseFormatParam | None = None,
|
||||||
seed: Optional[int] = None,
|
seed: int | None = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: str | list[str] | None = None,
|
||||||
stream: Optional[bool] = None,
|
stream: bool | None = None,
|
||||||
stream_options: Optional[Dict[str, Any]] = None,
|
stream_options: dict[str, Any] | None = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: float | None = None,
|
||||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
tools: Optional[List[Dict[str, Any]]] = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
top_logprobs: Optional[int] = None,
|
top_logprobs: int | None = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: float | None = None,
|
||||||
user: Optional[str] = None,
|
user: str | None = None,
|
||||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||||
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
||||||
|
|
||||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Protocol, runtime_checkable
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
class RouteInfo(BaseModel):
|
class RouteInfo(BaseModel):
|
||||||
route: str
|
route: str
|
||||||
method: str
|
method: str
|
||||||
provider_types: List[str]
|
provider_types: list[str]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -30,7 +30,7 @@ class VersionInfo(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class ListRoutesResponse(BaseModel):
|
class ListRoutesResponse(BaseModel):
|
||||||
data: List[RouteInfo]
|
data: list[RouteInfo]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
from typing import Any, Literal, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
class CommonModelFields(BaseModel):
|
class CommonModelFields(BaseModel):
|
||||||
metadata: Dict[str, Any] = Field(
|
metadata: dict[str, Any] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="Any additional metadata for this model",
|
description="Any additional metadata for this model",
|
||||||
)
|
)
|
||||||
|
@ -46,14 +46,14 @@ class Model(CommonModelFields, Resource):
|
||||||
|
|
||||||
class ModelInput(CommonModelFields):
|
class ModelInput(CommonModelFields):
|
||||||
model_id: str
|
model_id: str
|
||||||
provider_id: Optional[str] = None
|
provider_id: str | None = None
|
||||||
provider_model_id: Optional[str] = None
|
provider_model_id: str | None = None
|
||||||
model_type: Optional[ModelType] = ModelType.llm
|
model_type: ModelType | None = ModelType.llm
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class ListModelsResponse(BaseModel):
|
class ListModelsResponse(BaseModel):
|
||||||
data: List[Model]
|
data: list[Model]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -73,7 +73,7 @@ class OpenAIModel(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class OpenAIListModelsResponse(BaseModel):
|
class OpenAIListModelsResponse(BaseModel):
|
||||||
data: List[OpenAIModel]
|
data: list[OpenAIModel]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
@ -95,10 +95,10 @@ class Models(Protocol):
|
||||||
async def register_model(
|
async def register_model(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
provider_model_id: Optional[str] = None,
|
provider_model_id: str | None = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: str | None = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
model_type: Optional[ModelType] = None,
|
model_type: ModelType | None = None,
|
||||||
) -> Model: ...
|
) -> Model: ...
|
||||||
|
|
||||||
@webmethod(route="/models/{model_id:path}", method="DELETE")
|
@webmethod(route="/models/{model_id:path}", method="DELETE")
|
||||||
|
|
|
@ -6,10 +6,9 @@
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
from typing import Annotated, Any, Literal, Protocol
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.common.job_types import JobStatus
|
from llama_stack.apis.common.job_types import JobStatus
|
||||||
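
A tiny sketch (not part of the diff) of the import consolidation above: on Python 3.9+, `Annotated` is available from `typing` itself, so the separate `typing_extensions` import can be dropped. The model below is hypothetical and assumes Pydantic v2.

# Illustrative sketch, not part of the diff: Annotated now comes from typing.
from typing import Annotated

from pydantic import BaseModel, Field


class Hyperparams(BaseModel):  # hypothetical model, for illustration only
    learning_rate: Annotated[float, Field(gt=0)] = 1e-4


assert Hyperparams().learning_rate == 1e-4
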
|
@ -36,9 +35,9 @@ class DataConfig(BaseModel):
|
||||||
batch_size: int
|
batch_size: int
|
||||||
shuffle: bool
|
shuffle: bool
|
||||||
data_format: DatasetFormat
|
data_format: DatasetFormat
|
||||||
validation_dataset_id: Optional[str] = None
|
validation_dataset_id: str | None = None
|
||||||
packed: Optional[bool] = False
|
packed: bool | None = False
|
||||||
train_on_input: Optional[bool] = False
|
train_on_input: bool | None = False
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -51,10 +50,10 @@ class OptimizerConfig(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class EfficiencyConfig(BaseModel):
|
class EfficiencyConfig(BaseModel):
|
||||||
enable_activation_checkpointing: Optional[bool] = False
|
enable_activation_checkpointing: bool | None = False
|
||||||
enable_activation_offloading: Optional[bool] = False
|
enable_activation_offloading: bool | None = False
|
||||||
memory_efficient_fsdp_wrap: Optional[bool] = False
|
memory_efficient_fsdp_wrap: bool | None = False
|
||||||
fsdp_cpu_offload: Optional[bool] = False
|
fsdp_cpu_offload: bool | None = False
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -62,23 +61,23 @@ class TrainingConfig(BaseModel):
|
||||||
n_epochs: int
|
n_epochs: int
|
||||||
max_steps_per_epoch: int = 1
|
max_steps_per_epoch: int = 1
|
||||||
gradient_accumulation_steps: int = 1
|
gradient_accumulation_steps: int = 1
|
||||||
max_validation_steps: Optional[int] = 1
|
max_validation_steps: int | None = 1
|
||||||
data_config: Optional[DataConfig] = None
|
data_config: DataConfig | None = None
|
||||||
optimizer_config: Optional[OptimizerConfig] = None
|
optimizer_config: OptimizerConfig | None = None
|
||||||
efficiency_config: Optional[EfficiencyConfig] = None
|
efficiency_config: EfficiencyConfig | None = None
|
||||||
dtype: Optional[str] = "bf16"
|
dtype: str | None = "bf16"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class LoraFinetuningConfig(BaseModel):
|
class LoraFinetuningConfig(BaseModel):
|
||||||
type: Literal["LoRA"] = "LoRA"
|
type: Literal["LoRA"] = "LoRA"
|
||||||
lora_attn_modules: List[str]
|
lora_attn_modules: list[str]
|
||||||
apply_lora_to_mlp: bool
|
apply_lora_to_mlp: bool
|
||||||
apply_lora_to_output: bool
|
apply_lora_to_output: bool
|
||||||
rank: int
|
rank: int
|
||||||
alpha: int
|
alpha: int
|
||||||
use_dora: Optional[bool] = False
|
use_dora: bool | None = False
|
||||||
quantize_base: Optional[bool] = False
|
quantize_base: bool | None = False
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -88,7 +87,7 @@ class QATFinetuningConfig(BaseModel):
|
||||||
group_size: int
|
group_size: int
|
||||||
|
|
||||||
|
|
||||||
AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
|
AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
|
||||||
register_schema(AlgorithmConfig, name="AlgorithmConfig")
|
register_schema(AlgorithmConfig, name="AlgorithmConfig")
|
||||||
|
|
||||||
|
|
||||||
|
@ -97,7 +96,7 @@ class PostTrainingJobLogStream(BaseModel):
|
||||||
"""Stream of logs from a finetuning job."""
|
"""Stream of logs from a finetuning job."""
|
||||||
|
|
||||||
job_uuid: str
|
job_uuid: str
|
||||||
log_lines: List[str]
|
log_lines: list[str]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -131,8 +130,8 @@ class PostTrainingRLHFRequest(BaseModel):
|
||||||
training_config: TrainingConfig
|
training_config: TrainingConfig
|
||||||
|
|
||||||
# TODO: define these
|
# TODO: define these
|
||||||
hyperparam_search_config: Dict[str, Any]
|
hyperparam_search_config: dict[str, Any]
|
||||||
logger_config: Dict[str, Any]
|
logger_config: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class PostTrainingJob(BaseModel):
|
class PostTrainingJob(BaseModel):
|
||||||
|
@ -146,17 +145,17 @@ class PostTrainingJobStatusResponse(BaseModel):
|
||||||
job_uuid: str
|
job_uuid: str
|
||||||
status: JobStatus
|
status: JobStatus
|
||||||
|
|
||||||
scheduled_at: Optional[datetime] = None
|
scheduled_at: datetime | None = None
|
||||||
started_at: Optional[datetime] = None
|
started_at: datetime | None = None
|
||||||
completed_at: Optional[datetime] = None
|
completed_at: datetime | None = None
|
||||||
|
|
||||||
resources_allocated: Optional[Dict[str, Any]] = None
|
resources_allocated: dict[str, Any] | None = None
|
||||||
|
|
||||||
checkpoints: List[Checkpoint] = Field(default_factory=list)
|
checkpoints: list[Checkpoint] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class ListPostTrainingJobsResponse(BaseModel):
|
class ListPostTrainingJobsResponse(BaseModel):
|
||||||
data: List[PostTrainingJob]
|
data: list[PostTrainingJob]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -164,7 +163,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
|
||||||
"""Artifacts of a finetuning job."""
|
"""Artifacts of a finetuning job."""
|
||||||
|
|
||||||
job_uuid: str
|
job_uuid: str
|
||||||
checkpoints: List[Checkpoint] = Field(default_factory=list)
|
checkpoints: list[Checkpoint] = Field(default_factory=list)
|
||||||
|
|
||||||
# TODO(ashwin): metrics, evals
|
# TODO(ashwin): metrics, evals
|
||||||
|
|
||||||
|
@ -175,14 +174,14 @@ class PostTraining(Protocol):
|
||||||
self,
|
self,
|
||||||
job_uuid: str,
|
job_uuid: str,
|
||||||
training_config: TrainingConfig,
|
training_config: TrainingConfig,
|
||||||
hyperparam_search_config: Dict[str, Any],
|
hyperparam_search_config: dict[str, Any],
|
||||||
logger_config: Dict[str, Any],
|
logger_config: dict[str, Any],
|
||||||
model: Optional[str] = Field(
|
model: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Model descriptor for training if not in provider config`",
|
description="Model descriptor for training if not in provider config`",
|
||||||
),
|
),
|
||||||
checkpoint_dir: Optional[str] = None,
|
checkpoint_dir: str | None = None,
|
||||||
algorithm_config: Optional[AlgorithmConfig] = None,
|
algorithm_config: AlgorithmConfig | None = None,
|
||||||
) -> PostTrainingJob: ...
|
) -> PostTrainingJob: ...
|
||||||
|
|
||||||
@webmethod(route="/post-training/preference-optimize", method="POST")
|
@webmethod(route="/post-training/preference-optimize", method="POST")
|
||||||
|
@ -192,8 +191,8 @@ class PostTraining(Protocol):
|
||||||
finetuned_model: str,
|
finetuned_model: str,
|
||||||
algorithm_config: DPOAlignmentConfig,
|
algorithm_config: DPOAlignmentConfig,
|
||||||
training_config: TrainingConfig,
|
training_config: TrainingConfig,
|
||||||
hyperparam_search_config: Dict[str, Any],
|
hyperparam_search_config: dict[str, Any],
|
||||||
logger_config: Dict[str, Any],
|
logger_config: dict[str, Any],
|
||||||
) -> PostTrainingJob: ...
|
) -> PostTrainingJob: ...
|
||||||
|
|
||||||
@webmethod(route="/post-training/jobs", method="GET")
|
@webmethod(route="/post-training/jobs", method="GET")
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict, List, Protocol, runtime_checkable
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -17,12 +17,12 @@ class ProviderInfo(BaseModel):
|
||||||
api: str
|
api: str
|
||||||
provider_id: str
|
provider_id: str
|
||||||
provider_type: str
|
provider_type: str
|
||||||
config: Dict[str, Any]
|
config: dict[str, Any]
|
||||||
health: HealthResponse
|
health: HealthResponse
|
||||||
|
|
||||||
|
|
||||||
class ListProvidersResponse(BaseModel):
|
class ListProvidersResponse(BaseModel):
|
||||||
data: List[ProviderInfo]
|
data: list[ProviderInfo]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
@ -27,16 +27,16 @@ class SafetyViolation(BaseModel):
|
||||||
violation_level: ViolationLevel
|
violation_level: ViolationLevel
|
||||||
|
|
||||||
# what message should you convey to the user
|
# what message should you convey to the user
|
||||||
user_message: Optional[str] = None
|
user_message: str | None = None
|
||||||
|
|
||||||
# additional metadata (including specific violation codes) more for
|
# additional metadata (including specific violation codes) more for
|
||||||
# debugging, telemetry
|
# debugging, telemetry
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RunShieldResponse(BaseModel):
|
class RunShieldResponse(BaseModel):
|
||||||
violation: Optional[SafetyViolation] = None
|
violation: SafetyViolation | None = None
|
||||||
|
|
||||||
|
|
||||||
class ShieldStore(Protocol):
|
class ShieldStore(Protocol):
|
||||||
|
@ -52,6 +52,6 @@ class Safety(Protocol):
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
messages: List[Message],
|
messages: list[Message],
|
||||||
params: Dict[str, Any] = None,
|
params: dict[str, Any] = None,
|
||||||
) -> RunShieldResponse: ...
|
) -> RunShieldResponse: ...
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
|
||||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
# mapping of metric to value
|
# mapping of metric to value
|
||||||
ScoringResultRow = Dict[str, Any]
|
ScoringResultRow = dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -24,15 +24,15 @@ class ScoringResult(BaseModel):
|
||||||
:param aggregated_results: Map of metric name to aggregated value
|
:param aggregated_results: Map of metric name to aggregated value
|
||||||
"""
|
"""
|
||||||
|
|
||||||
score_rows: List[ScoringResultRow]
|
score_rows: list[ScoringResultRow]
|
||||||
# aggregated metrics to value
|
# aggregated metrics to value
|
||||||
aggregated_results: Dict[str, Any]
|
aggregated_results: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ScoreBatchResponse(BaseModel):
|
class ScoreBatchResponse(BaseModel):
|
||||||
dataset_id: Optional[str] = None
|
dataset_id: str | None = None
|
||||||
results: Dict[str, ScoringResult]
|
results: dict[str, ScoringResult]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -44,7 +44,7 @@ class ScoreResponse(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# each key in the dict is a scoring function name
|
# each key in the dict is a scoring function name
|
||||||
results: Dict[str, ScoringResult]
|
results: dict[str, ScoringResult]
|
||||||
|
|
||||||
|
|
||||||
class ScoringFunctionStore(Protocol):
|
class ScoringFunctionStore(Protocol):
|
||||||
|
@ -59,15 +59,15 @@ class Scoring(Protocol):
|
||||||
async def score_batch(
|
async def score_batch(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
scoring_functions: dict[str, ScoringFnParams | None],
|
||||||
save_results_dataset: bool = False,
|
save_results_dataset: bool = False,
|
||||||
) -> ScoreBatchResponse: ...
|
) -> ScoreBatchResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/scoring/score", method="POST")
|
@webmethod(route="/scoring/score", method="POST")
|
||||||
async def score(
|
async def score(
|
||||||
self,
|
self,
|
||||||
input_rows: List[Dict[str, Any]],
|
input_rows: list[dict[str, Any]],
|
||||||
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
scoring_functions: dict[str, ScoringFnParams | None],
|
||||||
) -> ScoreResponse:
|
) -> ScoreResponse:
|
||||||
"""Score a list of rows.
|
"""Score a list of rows.
|
||||||
|
|
||||||
|
|
|
@ -6,18 +6,14 @@
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import (
|
from typing import (
|
||||||
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
|
||||||
Protocol,
|
Protocol,
|
||||||
Union,
|
|
||||||
runtime_checkable,
|
runtime_checkable,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
from llama_stack.apis.common.type_system import ParamType
|
from llama_stack.apis.common.type_system import ParamType
|
||||||
from llama_stack.apis.resource import Resource, ResourceType
|
from llama_stack.apis.resource import Resource, ResourceType
|
||||||
|
@ -46,12 +42,12 @@ class AggregationFunctionType(Enum):
|
||||||
class LLMAsJudgeScoringFnParams(BaseModel):
|
class LLMAsJudgeScoringFnParams(BaseModel):
|
||||||
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
|
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
|
||||||
judge_model: str
|
judge_model: str
|
||||||
prompt_template: Optional[str] = None
|
prompt_template: str | None = None
|
||||||
judge_score_regexes: Optional[List[str]] = Field(
|
judge_score_regexes: list[str] | None = Field(
|
||||||
description="Regexes to extract the answer from generated response",
|
description="Regexes to extract the answer from generated response",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
|
aggregation_functions: list[AggregationFunctionType] | None = Field(
|
||||||
description="Aggregation functions to apply to the scores of each row",
|
description="Aggregation functions to apply to the scores of each row",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
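
A small sketch (not part of the diff) of the `list[X] | None` + `default_factory=list` combination seen above: the rewrite only touches the annotation, so the field still defaults to `[]` while `None` remains accepted when passed explicitly. The class below is a hypothetical stand-in, assuming Pydantic v2.

# Illustrative sketch, not part of the diff: annotation changes, behavior does not.
from pydantic import BaseModel, Field


class JudgeParams(BaseModel):  # hypothetical stand-in for the scoring-fn params
    judge_score_regexes: list[str] | None = Field(default_factory=list)


params = JudgeParams()
assert params.judge_score_regexes == []
params = JudgeParams(judge_score_regexes=None)  # None is still accepted explicitly
assert params.judge_score_regexes is None
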
|
@ -60,11 +56,11 @@ class LLMAsJudgeScoringFnParams(BaseModel):
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RegexParserScoringFnParams(BaseModel):
|
class RegexParserScoringFnParams(BaseModel):
|
||||||
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
|
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
|
||||||
parsing_regexes: Optional[List[str]] = Field(
|
parsing_regexes: list[str] | None = Field(
|
||||||
description="Regex to extract the answer from generated response",
|
description="Regex to extract the answer from generated response",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
|
aggregation_functions: list[AggregationFunctionType] | None = Field(
|
||||||
description="Aggregation functions to apply to the scores of each row",
|
description="Aggregation functions to apply to the scores of each row",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
|
@ -73,33 +69,29 @@ class RegexParserScoringFnParams(BaseModel):
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class BasicScoringFnParams(BaseModel):
|
class BasicScoringFnParams(BaseModel):
|
||||||
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
|
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
|
||||||
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
|
aggregation_functions: list[AggregationFunctionType] | None = Field(
|
||||||
description="Aggregation functions to apply to the scores of each row",
|
description="Aggregation functions to apply to the scores of each row",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
ScoringFnParams = Annotated[
|
ScoringFnParams = Annotated[
|
||||||
Union[
|
LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams,
|
||||||
LLMAsJudgeScoringFnParams,
|
|
||||||
RegexParserScoringFnParams,
|
|
||||||
BasicScoringFnParams,
|
|
||||||
],
|
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(ScoringFnParams, name="ScoringFnParams")
|
register_schema(ScoringFnParams, name="ScoringFnParams")
|
||||||
|
|
||||||
|
|
||||||
class CommonScoringFnFields(BaseModel):
|
class CommonScoringFnFields(BaseModel):
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
metadata: Dict[str, Any] = Field(
|
metadata: dict[str, Any] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="Any additional metadata for this definition",
|
description="Any additional metadata for this definition",
|
||||||
)
|
)
|
||||||
return_type: ParamType = Field(
|
return_type: ParamType = Field(
|
||||||
description="The return type of the deterministic function",
|
description="The return type of the deterministic function",
|
||||||
)
|
)
|
||||||
params: Optional[ScoringFnParams] = Field(
|
params: ScoringFnParams | None = Field(
|
||||||
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
|
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
@ -120,12 +112,12 @@ class ScoringFn(CommonScoringFnFields, Resource):
|
||||||
|
|
||||||
class ScoringFnInput(CommonScoringFnFields, BaseModel):
|
class ScoringFnInput(CommonScoringFnFields, BaseModel):
|
||||||
scoring_fn_id: str
|
scoring_fn_id: str
|
||||||
provider_id: Optional[str] = None
|
provider_id: str | None = None
|
||||||
provider_scoring_fn_id: Optional[str] = None
|
provider_scoring_fn_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ListScoringFunctionsResponse(BaseModel):
|
class ListScoringFunctionsResponse(BaseModel):
|
||||||
data: List[ScoringFn]
|
data: list[ScoringFn]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
@ -142,7 +134,7 @@ class ScoringFunctions(Protocol):
|
||||||
scoring_fn_id: str,
|
scoring_fn_id: str,
|
||||||
description: str,
|
description: str,
|
||||||
return_type: ParamType,
|
return_type: ParamType,
|
||||||
provider_scoring_fn_id: Optional[str] = None,
|
provider_scoring_fn_id: str | None = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: str | None = None,
|
||||||
params: Optional[ScoringFnParams] = None,
|
params: ScoringFnParams | None = None,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
from typing import Any, Literal, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
class CommonShieldFields(BaseModel):
|
class CommonShieldFields(BaseModel):
|
||||||
params: Optional[Dict[str, Any]] = None
|
params: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -34,12 +34,12 @@ class Shield(CommonShieldFields, Resource):
|
||||||
|
|
||||||
class ShieldInput(CommonShieldFields):
|
class ShieldInput(CommonShieldFields):
|
||||||
shield_id: str
|
shield_id: str
|
||||||
provider_id: Optional[str] = None
|
provider_id: str | None = None
|
||||||
provider_shield_id: Optional[str] = None
|
provider_shield_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ListShieldsResponse(BaseModel):
|
class ListShieldsResponse(BaseModel):
|
||||||
data: List[Shield]
|
data: list[Shield]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
@ -55,7 +55,7 @@ class Shields(Protocol):
|
||||||
async def register_shield(
|
async def register_shield(
|
||||||
self,
|
self,
|
||||||
shield_id: str,
|
shield_id: str,
|
||||||
provider_shield_id: Optional[str] = None,
|
provider_shield_id: str | None = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: str | None = None,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: dict[str, Any] | None = None,
|
||||||
) -> Shield: ...
|
) -> Shield: ...
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
from typing import Any, Protocol
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -28,24 +28,24 @@ class FilteringFunction(Enum):
|
||||||
class SyntheticDataGenerationRequest(BaseModel):
|
class SyntheticDataGenerationRequest(BaseModel):
|
||||||
"""Request to generate synthetic data. A small batch of prompts and a filtering function"""
|
"""Request to generate synthetic data. A small batch of prompts and a filtering function"""
|
||||||
|
|
||||||
dialogs: List[Message]
|
dialogs: list[Message]
|
||||||
filtering_function: FilteringFunction = FilteringFunction.none
|
filtering_function: FilteringFunction = FilteringFunction.none
|
||||||
model: Optional[str] = None
|
model: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class SyntheticDataGenerationResponse(BaseModel):
|
class SyntheticDataGenerationResponse(BaseModel):
|
||||||
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
|
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
|
||||||
|
|
||||||
synthetic_data: List[Dict[str, Any]]
|
synthetic_data: list[dict[str, Any]]
|
||||||
statistics: Optional[Dict[str, Any]] = None
|
statistics: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class SyntheticDataGeneration(Protocol):
|
class SyntheticDataGeneration(Protocol):
|
||||||
@webmethod(route="/synthetic-data-generation/generate")
|
@webmethod(route="/synthetic-data-generation/generate")
|
||||||
def synthetic_data_generate(
|
def synthetic_data_generate(
|
||||||
self,
|
self,
|
||||||
dialogs: List[Message],
|
dialogs: list[Message],
|
||||||
filtering_function: FilteringFunction = FilteringFunction.none,
|
filtering_function: FilteringFunction = FilteringFunction.none,
|
||||||
model: Optional[str] = None,
|
model: str | None = None,
|
||||||
) -> Union[SyntheticDataGenerationResponse]: ...
|
) -> SyntheticDataGenerationResponse: ...
|
||||||
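
A one-line sketch (not part of the diff) of the rewrite above: a single-member `Union` is just its member, so collapsing `Union[X]` to `X` is purely cosmetic. The `Response` class below is a hypothetical placeholder.

# Illustrative sketch, not part of the diff: typing.Union with one argument
# returns the argument itself, so the two annotations are identical.
from typing import Union


class Response:  # hypothetical placeholder type
    pass


assert Union[Response] is Response
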
|
|
|
@ -7,18 +7,14 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import (
|
from typing import (
|
||||||
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
|
||||||
Protocol,
|
Protocol,
|
||||||
Union,
|
|
||||||
runtime_checkable,
|
runtime_checkable,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
from llama_stack.models.llama.datatypes import Primitive
|
from llama_stack.models.llama.datatypes import Primitive
|
||||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
@ -37,11 +33,11 @@ class SpanStatus(Enum):
|
||||||
class Span(BaseModel):
|
class Span(BaseModel):
|
||||||
span_id: str
|
span_id: str
|
||||||
trace_id: str
|
trace_id: str
|
||||||
parent_span_id: Optional[str] = None
|
parent_span_id: str | None = None
|
||||||
name: str
|
name: str
|
||||||
start_time: datetime
|
start_time: datetime
|
||||||
end_time: Optional[datetime] = None
|
end_time: datetime | None = None
|
||||||
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
attributes: dict[str, Any] | None = Field(default_factory=dict)
|
||||||
|
|
||||||
def set_attribute(self, key: str, value: Any):
|
def set_attribute(self, key: str, value: Any):
|
||||||
if self.attributes is None:
|
if self.attributes is None:
|
||||||
|
@ -54,7 +50,7 @@ class Trace(BaseModel):
|
||||||
trace_id: str
|
trace_id: str
|
||||||
root_span_id: str
|
root_span_id: str
|
||||||
start_time: datetime
|
start_time: datetime
|
||||||
end_time: Optional[datetime] = None
|
end_time: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -78,7 +74,7 @@ class EventCommon(BaseModel):
|
||||||
trace_id: str
|
trace_id: str
|
||||||
span_id: str
|
span_id: str
|
||||||
timestamp: datetime
|
timestamp: datetime
|
||||||
attributes: Optional[Dict[str, Primitive]] = Field(default_factory=dict)
|
attributes: dict[str, Primitive] | None = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -92,15 +88,15 @@ class UnstructuredLogEvent(EventCommon):
|
||||||
class MetricEvent(EventCommon):
|
class MetricEvent(EventCommon):
|
||||||
type: Literal[EventType.METRIC.value] = EventType.METRIC.value
|
type: Literal[EventType.METRIC.value] = EventType.METRIC.value
|
||||||
metric: str # this would be an enum
|
metric: str # this would be an enum
|
||||||
value: Union[int, float]
|
value: int | float
|
||||||
unit: str
|
unit: str
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class MetricInResponse(BaseModel):
|
class MetricInResponse(BaseModel):
|
||||||
metric: str
|
metric: str
|
||||||
value: Union[int, float]
|
value: int | float
|
||||||
unit: Optional[str] = None
|
unit: str | None = None
|
||||||
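
A quick sketch (not part of the diff) of the `int | float` style unions used in the metric models above: on Python 3.10+ a `|` union is a runtime object and can even be passed straight to `isinstance()`. The variable below is hypothetical.

# Illustrative sketch, not part of the diff: `|` unions work with isinstance()
# on Python 3.10 and newer.
value: int | float = 0.25
assert isinstance(value, int | float)
assert not isinstance("0.25", int | float)
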
|
|
||||||
|
|
||||||
# This is a short term solution to allow inference API to return metrics
|
# This is a short term solution to allow inference API to return metrics
|
||||||
|
@ -124,7 +120,7 @@ class MetricInResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class MetricResponseMixin(BaseModel):
|
class MetricResponseMixin(BaseModel):
|
||||||
metrics: Optional[List[MetricInResponse]] = None
|
metrics: list[MetricInResponse] | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -137,7 +133,7 @@ class StructuredLogType(Enum):
|
||||||
class SpanStartPayload(BaseModel):
|
class SpanStartPayload(BaseModel):
|
||||||
type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value
|
type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value
|
||||||
name: str
|
name: str
|
||||||
parent_span_id: Optional[str] = None
|
parent_span_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -147,10 +143,7 @@ class SpanEndPayload(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
StructuredLogPayload = Annotated[
|
StructuredLogPayload = Annotated[
|
||||||
Union[
|
SpanStartPayload | SpanEndPayload,
|
||||||
SpanStartPayload,
|
|
||||||
SpanEndPayload,
|
|
||||||
],
|
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(StructuredLogPayload, name="StructuredLogPayload")
|
register_schema(StructuredLogPayload, name="StructuredLogPayload")
|
||||||
|
@ -163,11 +156,7 @@ class StructuredLogEvent(EventCommon):
|
||||||
|
|
||||||
|
|
||||||
Event = Annotated[
|
Event = Annotated[
|
||||||
Union[
|
UnstructuredLogEvent | MetricEvent | StructuredLogEvent,
|
||||||
UnstructuredLogEvent,
|
|
||||||
MetricEvent,
|
|
||||||
StructuredLogEvent,
|
|
||||||
],
|
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
register_schema(Event, name="Event")
|
register_schema(Event, name="Event")
|
||||||
|
@ -184,7 +173,7 @@ class EvalTrace(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class SpanWithStatus(Span):
|
class SpanWithStatus(Span):
|
||||||
status: Optional[SpanStatus] = None
|
status: SpanStatus | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -203,15 +192,15 @@ class QueryCondition(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class QueryTracesResponse(BaseModel):
|
class QueryTracesResponse(BaseModel):
|
||||||
data: List[Trace]
|
data: list[Trace]
|
||||||
|
|
||||||
|
|
||||||
class QuerySpansResponse(BaseModel):
|
class QuerySpansResponse(BaseModel):
|
||||||
data: List[Span]
|
data: list[Span]
|
||||||
|
|
||||||
|
|
||||||
class QuerySpanTreeResponse(BaseModel):
|
class QuerySpanTreeResponse(BaseModel):
|
||||||
data: Dict[str, SpanWithStatus]
|
data: dict[str, SpanWithStatus]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
@ -222,10 +211,10 @@ class Telemetry(Protocol):
|
||||||
@webmethod(route="/telemetry/traces", method="POST")
|
@webmethod(route="/telemetry/traces", method="POST")
|
||||||
async def query_traces(
|
async def query_traces(
|
||||||
self,
|
self,
|
||||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
attribute_filters: list[QueryCondition] | None = None,
|
||||||
limit: Optional[int] = 100,
|
limit: int | None = 100,
|
||||||
offset: Optional[int] = 0,
|
offset: int | None = 0,
|
||||||
order_by: Optional[List[str]] = None,
|
order_by: list[str] | None = None,
|
||||||
) -> QueryTracesResponse: ...
|
) -> QueryTracesResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
|
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
|
||||||
|
@ -238,23 +227,23 @@ class Telemetry(Protocol):
|
||||||
async def get_span_tree(
|
async def get_span_tree(
|
||||||
self,
|
self,
|
||||||
span_id: str,
|
span_id: str,
|
||||||
attributes_to_return: Optional[List[str]] = None,
|
attributes_to_return: list[str] | None = None,
|
||||||
max_depth: Optional[int] = None,
|
max_depth: int | None = None,
|
||||||
) -> QuerySpanTreeResponse: ...
|
) -> QuerySpanTreeResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/spans", method="POST")
|
@webmethod(route="/telemetry/spans", method="POST")
|
||||||
async def query_spans(
|
async def query_spans(
|
||||||
self,
|
self,
|
||||||
attribute_filters: List[QueryCondition],
|
attribute_filters: list[QueryCondition],
|
||||||
attributes_to_return: List[str],
|
attributes_to_return: list[str],
|
||||||
max_depth: Optional[int] = None,
|
max_depth: int | None = None,
|
||||||
) -> QuerySpansResponse: ...
|
) -> QuerySpansResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/spans/export", method="POST")
|
@webmethod(route="/telemetry/spans/export", method="POST")
|
||||||
async def save_spans_to_dataset(
|
async def save_spans_to_dataset(
|
||||||
self,
|
self,
|
||||||
attribute_filters: List[QueryCondition],
|
attribute_filters: list[QueryCondition],
|
||||||
attributes_to_save: List[str],
|
attributes_to_save: list[str],
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
max_depth: Optional[int] = None,
|
max_depth: int | None = None,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
|
@@ -5,10 +5,10 @@
# the root directory of this source tree.

from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Annotated, Any, Literal

from pydantic import BaseModel, Field
-from typing_extensions import Annotated, Protocol, runtime_checkable
+from typing_extensions import Protocol, runtime_checkable

from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@@ -29,13 +29,13 @@ class RAGDocument(BaseModel):
document_id: str
content: InterleavedContent | URL
mime_type: str | None = None
-metadata: Dict[str, Any] = Field(default_factory=dict)
+metadata: dict[str, Any] = Field(default_factory=dict)


@json_schema_type
class RAGQueryResult(BaseModel):
-content: Optional[InterleavedContent] = None
+content: InterleavedContent | None = None
-metadata: Dict[str, Any] = Field(default_factory=dict)
+metadata: dict[str, Any] = Field(default_factory=dict)


@json_schema_type
@@ -59,10 +59,7 @@ class LLMRAGQueryGeneratorConfig(BaseModel):


RAGQueryGeneratorConfig = Annotated[
-Union[
-DefaultRAGQueryGeneratorConfig,
-LLMRAGQueryGeneratorConfig,
-],
+DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
Field(discriminator="type"),
]
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@@ -83,7 +80,7 @@ class RAGToolRuntime(Protocol):
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
async def insert(
self,
-documents: List[RAGDocument],
+documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
@@ -94,8 +91,8 @@ class RAGToolRuntime(Protocol):
async def query(
self,
content: InterleavedContent,
-vector_db_ids: List[str],
+vector_db_ids: list[str],
-query_config: Optional[RAGQueryConfig] = None,
+query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
"""Query the RAG system for context; typically invoked by the agent"""
...

@@ -5,7 +5,7 @@
# the root directory of this source tree.

from enum import Enum
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Literal

from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable
@@ -24,7 +24,7 @@ class ToolParameter(BaseModel):
parameter_type: str
description: str
required: bool = Field(default=True)
-default: Optional[Any] = None
+default: Any | None = None


@json_schema_type
@@ -40,39 +40,39 @@ class Tool(Resource):
toolgroup_id: str
tool_host: ToolHost
description: str
-parameters: List[ToolParameter]
+parameters: list[ToolParameter]
-metadata: Optional[Dict[str, Any]] = None
+metadata: dict[str, Any] | None = None


@json_schema_type
class ToolDef(BaseModel):
name: str
-description: Optional[str] = None
+description: str | None = None
-parameters: Optional[List[ToolParameter]] = None
+parameters: list[ToolParameter] | None = None
-metadata: Optional[Dict[str, Any]] = None
+metadata: dict[str, Any] | None = None


@json_schema_type
class ToolGroupInput(BaseModel):
toolgroup_id: str
provider_id: str
-args: Optional[Dict[str, Any]] = None
+args: dict[str, Any] | None = None
-mcp_endpoint: Optional[URL] = None
+mcp_endpoint: URL | None = None


@json_schema_type
class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
-mcp_endpoint: Optional[URL] = None
+mcp_endpoint: URL | None = None
-args: Optional[Dict[str, Any]] = None
+args: dict[str, Any] | None = None


@json_schema_type
class ToolInvocationResult(BaseModel):
-content: Optional[InterleavedContent] = None
+content: InterleavedContent | None = None
-error_message: Optional[str] = None
+error_message: str | None = None
-error_code: Optional[int] = None
+error_code: int | None = None
-metadata: Optional[Dict[str, Any]] = None
+metadata: dict[str, Any] | None = None


class ToolStore(Protocol):
@@ -81,11 +81,11 @@ class ToolStore(Protocol):


class ListToolGroupsResponse(BaseModel):
-data: List[ToolGroup]
+data: list[ToolGroup]


class ListToolsResponse(BaseModel):
-data: List[Tool]
+data: list[Tool]


class ListToolDefsResponse(BaseModel):
@@ -100,8 +100,8 @@ class ToolGroups(Protocol):
self,
toolgroup_id: str,
provider_id: str,
-mcp_endpoint: Optional[URL] = None,
+mcp_endpoint: URL | None = None,
-args: Optional[Dict[str, Any]] = None,
+args: dict[str, Any] | None = None,
) -> None:
"""Register a tool group"""
...
@@ -118,7 +118,7 @@ class ToolGroups(Protocol):
...

@webmethod(route="/tools", method="GET")
-async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
+async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
"""List tools with optional tool group"""
...

@@ -151,10 +151,10 @@ class ToolRuntime(Protocol):
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools(
-self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse: ...

@webmethod(route="/tool-runtime/invoke", method="POST")
-async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
+async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
"""Run a tool with the given arguments"""
...

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-from typing import List, Literal, Optional, Protocol, runtime_checkable
+from typing import Literal, Protocol, runtime_checkable

from pydantic import BaseModel

@@ -33,11 +33,11 @@ class VectorDBInput(BaseModel):
vector_db_id: str
embedding_model: str
embedding_dimension: int
-provider_vector_db_id: Optional[str] = None
+provider_vector_db_id: str | None = None


class ListVectorDBsResponse(BaseModel):
-data: List[VectorDB]
+data: list[VectorDB]


@runtime_checkable
@@ -57,9 +57,9 @@ class VectorDBs(Protocol):
self,
vector_db_id: str,
embedding_model: str,
-embedding_dimension: Optional[int] = 384,
+embedding_dimension: int | None = 384,
-provider_id: Optional[str] = None,
+provider_id: str | None = None,
-provider_vector_db_id: Optional[str] = None,
+provider_vector_db_id: str | None = None,
) -> VectorDB: ...

@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")

@@ -8,7 +8,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from typing import Any, Protocol, runtime_checkable

from pydantic import BaseModel, Field

@@ -20,17 +20,17 @@ from llama_stack.schema_utils import json_schema_type, webmethod

class Chunk(BaseModel):
content: InterleavedContent
-metadata: Dict[str, Any] = Field(default_factory=dict)
+metadata: dict[str, Any] = Field(default_factory=dict)


@json_schema_type
class QueryChunksResponse(BaseModel):
-chunks: List[Chunk]
+chunks: list[Chunk]
-scores: List[float]
+scores: list[float]


class VectorDBStore(Protocol):
-def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: ...
+def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ...


@runtime_checkable
@@ -44,8 +44,8 @@ class VectorIO(Protocol):
async def insert_chunks(
self,
vector_db_id: str,
-chunks: List[Chunk],
+chunks: list[Chunk],
-ttl_seconds: Optional[int] = None,
+ttl_seconds: int | None = None,
) -> None: ...

@webmethod(route="/vector-io/query", method="POST")
@@ -53,5 +53,5 @@ class VectorIO(Protocol):
self,
vector_db_id: str,
query: InterleavedContent,
-params: Optional[Dict[str, Any]] = None,
+params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ...

@@ -13,7 +13,6 @@ from dataclasses import dataclass
from datetime import datetime, timezone
from functools import partial
from pathlib import Path
-from typing import Dict, List, Optional

import httpx
from pydantic import BaseModel, ConfigDict
@@ -102,7 +101,7 @@ class DownloadTask:
output_file: str
total_size: int = 0
downloaded_size: int = 0
-task_id: Optional[int] = None
+task_id: int | None = None
retries: int = 0
max_retries: int = 3

@@ -262,7 +261,7 @@ class ParallelDownloader:
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e

-def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
+def has_disk_space(self, tasks: list[DownloadTask]) -> bool:
try:
total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks)
dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
@@ -282,7 +281,7 @@ class ParallelDownloader:
except Exception as e:
raise DownloadError(f"Failed to check disk space: {str(e)}") from e

-async def download_all(self, tasks: List[DownloadTask]) -> None:
+async def download_all(self, tasks: list[DownloadTask]) -> None:
if not tasks:
raise ValueError("No download tasks provided")

@@ -391,20 +390,20 @@ def _meta_download(

class ModelEntry(BaseModel):
model_id: str
-files: Dict[str, str]
+files: dict[str, str]

model_config = ConfigDict(protected_namespaces=())


class Manifest(BaseModel):
-models: List[ModelEntry]
+models: list[ModelEntry]
expires_on: datetime


def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
from llama_stack.distribution.utils.model_utils import model_local_dir

-with open(manifest_file, "r") as f:
+with open(manifest_file) as f:
d = json.load(f)
manifest = Manifest(**d)

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

from pydantic import BaseModel, ConfigDict, Field

@@ -22,7 +22,7 @@ class PromptGuardModel(BaseModel):
max_seq_length: int = 512
is_instruct_model: bool = False
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
-arch_args: Dict[str, Any] = Field(default_factory=dict)
+arch_args: dict[str, Any] = Field(default_factory=dict)

def descriptor(self) -> str:
return self.model_id
@@ -44,11 +44,11 @@ def prompt_guard_model_skus():
]


-def prompt_guard_model_sku_map() -> Dict[str, Any]:
+def prompt_guard_model_sku_map() -> dict[str, Any]:
return {model.model_id: model for model in prompt_guard_model_skus()}


-def prompt_guard_download_info_map() -> Dict[str, LlamaDownloadInfo]:
+def prompt_guard_download_info_map() -> dict[str, LlamaDownloadInfo]:
return {
model.model_id: LlamaDownloadInfo(
folder="Prompt-Guard" if model.model_id == "Prompt-Guard-86M" else model.model_id,

@@ -13,7 +13,6 @@ import sys
import textwrap
from functools import lru_cache
from pathlib import Path
-from typing import Dict, Optional

import yaml
from prompt_toolkit import prompt
@@ -46,14 +45,14 @@ from llama_stack.providers.datatypes import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"


-@lru_cache()
+@lru_cache
-def available_templates_specs() -> Dict[str, BuildConfig]:
+def available_templates_specs() -> dict[str, BuildConfig]:
import yaml

template_specs = {}
for p in TEMPLATES_PATH.rglob("*build.yaml"):
template_name = p.parent.name
-with open(p, "r") as f:
+with open(p) as f:
build_config = BuildConfig(**yaml.safe_load(f))
template_specs[template_name] = build_config
return template_specs
@@ -178,7 +177,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
if not available_providers:
continue
api_provider = prompt(
-"> Enter provider for API {}: ".format(api.value),
+f"> Enter provider for API {api.value}: ",
completer=WordCompleter(available_providers),
complete_while_typing=True,
validator=Validator.from_callable(
@@ -201,7 +200,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:

build_config = BuildConfig(image_type=image_type, distribution_spec=distribution_spec)
else:
-with open(args.config, "r") as f:
+with open(args.config) as f:
try:
build_config = BuildConfig(**yaml.safe_load(f))
except Exception as e:
@@ -332,9 +331,9 @@ def _generate_run_config(

def _run_stack_build_command_from_build_config(
build_config: BuildConfig,
-image_name: Optional[str] = None,
+image_name: str | None = None,
-template_name: Optional[str] = None,
+template_name: str | None = None,
-config_path: Optional[str] = None,
+config_path: str | None = None,
) -> str:
image_name = image_name or build_config.image_name
if build_config.image_type == LlamaStackImageType.CONTAINER.value:

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-from typing import Iterable
+from collections.abc import Iterable

from rich.console import Console
from rich.table import Table

@@ -9,7 +9,6 @@ import hashlib
from dataclasses import dataclass
from functools import partial
from pathlib import Path
-from typing import Dict, List, Optional

from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
@@ -21,7 +20,7 @@ from llama_stack.cli.subcommand import Subcommand
class VerificationResult:
filename: str
expected_hash: str
-actual_hash: Optional[str]
+actual_hash: str | None
exists: bool
matches: bool

@@ -60,9 +59,9 @@ def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str:
return md5_hash.hexdigest()


-def load_checksums(checklist_path: Path) -> Dict[str, str]:
+def load_checksums(checklist_path: Path) -> dict[str, str]:
checksums = {}
-with open(checklist_path, "r") as f:
+with open(checklist_path) as f:
for line in f:
if line.strip():
md5sum, filepath = line.strip().split(" ", 1)
@@ -72,7 +71,7 @@ def load_checksums(checklist_path: Path) -> Dict[str, str]:
return checksums


-def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -> List[VerificationResult]:
+def verify_files(model_dir: Path, checksums: dict[str, str], console: Console) -> list[VerificationResult]:
results = []

with Progress(

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any

from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
@@ -14,8 +14,8 @@ logger = get_logger(__name__, category="core")

def check_access(
obj_identifier: str,
-obj_attributes: Optional[AccessAttributes],
+obj_attributes: AccessAttributes | None,
-user_attributes: Optional[Dict[str, Any]] = None,
+user_attributes: dict[str, Any] | None = None,
) -> bool:
"""Check if the current user has access to the given object, based on access attributes.

@@ -8,7 +8,7 @@ import inspect
import json
from collections.abc import AsyncIterator
from enum import Enum
-from typing import Any, Type, Union, get_args, get_origin
+from typing import Any, Union, get_args, get_origin

import httpx
from pydantic import BaseModel, parse_obj_as
@@ -27,7 +27,7 @@ async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
return impl


-def create_api_client_class(protocol) -> Type:
+def create_api_client_class(protocol) -> type:
if protocol in _CLIENT_CLASSES:
return _CLIENT_CLASSES[protocol]

@@ -5,7 +5,7 @@
# the root directory of this source tree.
import logging
import textwrap
-from typing import Any, Dict
+from typing import Any

from llama_stack.distribution.datatypes import (
LLAMA_STACK_RUN_CONFIG_VERSION,
@@ -24,7 +24,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
logger = logging.getLogger(__name__)


-def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provider) -> Provider:
+def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
provider_spec = registry[provider.provider_type]
config_type = instantiate_class_type(provider_spec.config_class)
try:
@@ -120,8 +120,8 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec


def upgrade_from_routing_table(
-config_dict: Dict[str, Any],
+config_dict: dict[str, Any],
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
def get_providers(entries):
return [
Provider(
@@ -163,7 +163,7 @@ def upgrade_from_routing_table(
return config_dict


-def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig:
+def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
version = config_dict.get("version", None)
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
return StackRunConfig(**config_dict)

@@ -5,7 +5,7 @@
# the root directory of this source tree.

from enum import Enum
-from typing import Annotated, Any, Dict, List, Optional, Union
+from typing import Annotated, Any

from pydantic import BaseModel, Field

@@ -30,7 +30,7 @@ LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"


-RoutingKey = Union[str, List[str]]
+RoutingKey = str | list[str]


class AccessAttributes(BaseModel):
@@ -47,17 +47,17 @@ class AccessAttributes(BaseModel):
"""

# Standard attribute categories - the minimal set we need now
-roles: Optional[List[str]] = Field(
+roles: list[str] | None = Field(
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
)

-teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
+teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")

-projects: Optional[List[str]] = Field(
+projects: list[str] | None = Field(
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
)

-namespaces: Optional[List[str]] = Field(
+namespaces: list[str] | None = Field(
default=None, description="Namespace-based access control for resource isolation"
)

@@ -106,7 +106,7 @@ class ResourceWithACL(Resource):
# ^ User must have access to the customer-insights project AND have confidential namespace
"""

-access_attributes: Optional[AccessAttributes] = None
+access_attributes: AccessAttributes | None = None


# Use the extended Resource for all routable objects
@@ -142,41 +142,21 @@ class ToolGroupWithACL(ToolGroup, ResourceWithACL):
pass


-RoutableObject = Union[
-Model,
-Shield,
-VectorDB,
-Dataset,
-ScoringFn,
-Benchmark,
-Tool,
-ToolGroup,
-]
+RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup


RoutableObjectWithProvider = Annotated[
-Union[
-ModelWithACL,
-ShieldWithACL,
-VectorDBWithACL,
-DatasetWithACL,
-ScoringFnWithACL,
-BenchmarkWithACL,
-ToolWithACL,
-ToolGroupWithACL,
-],
+ModelWithACL
+| ShieldWithACL
+| VectorDBWithACL
+| DatasetWithACL
+| ScoringFnWithACL
+| BenchmarkWithACL
+| ToolWithACL
+| ToolGroupWithACL,
Field(discriminator="type"),
]

-RoutedProtocol = Union[
-Inference,
-Safety,
-VectorIO,
-DatasetIO,
-Scoring,
-Eval,
-ToolRuntime,
-]
+RoutedProtocol = Inference | Safety | VectorIO | DatasetIO | Scoring | Eval | ToolRuntime


# Example: /inference, /safety
@@ -184,15 +164,15 @@ class AutoRoutedProviderSpec(ProviderSpec):
provider_type: str = "router"
config_class: str = ""

-container_image: Optional[str] = None
+container_image: str | None = None
routing_table_api: Api
module: str
-provider_data_validator: Optional[str] = Field(
+provider_data_validator: str | None = Field(
default=None,
)

@property
-def pip_packages(self) -> List[str]:
+def pip_packages(self) -> list[str]:
raise AssertionError("Should not be called on AutoRoutedProviderSpec")


@@ -200,20 +180,20 @@ class AutoRoutedProviderSpec(ProviderSpec):
class RoutingTableProviderSpec(ProviderSpec):
provider_type: str = "routing_table"
config_class: str = ""
-container_image: Optional[str] = None
+container_image: str | None = None

router_api: Api
module: str
-pip_packages: List[str] = Field(default_factory=list)
+pip_packages: list[str] = Field(default_factory=list)


class DistributionSpec(BaseModel):
-description: Optional[str] = Field(
+description: str | None = Field(
default="",
description="Description of the distribution",
)
-container_image: Optional[str] = None
+container_image: str | None = None
-providers: Dict[str, Union[str, List[str]]] = Field(
+providers: dict[str, str | list[str]] = Field(
default_factory=dict,
description="""
Provider Types for each of the APIs provided by this distribution. If you
@@ -225,12 +205,12 @@ in the runtime configuration to help route to the correct provider.""",
class Provider(BaseModel):
provider_id: str
provider_type: str
-config: Dict[str, Any]
+config: dict[str, Any]


class LoggingConfig(BaseModel):
-category_levels: Dict[str, str] = Field(
+category_levels: dict[str, str] = Field(
-default_factory=Dict,
+default_factory=dict,
description="""
Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
)
@@ -248,7 +228,7 @@ class AuthenticationConfig(BaseModel):
...,
description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
)
-config: Dict[str, str] = Field(
+config: dict[str, str] = Field(
...,
description="Provider-specific configuration",
)
@@ -261,15 +241,15 @@ class ServerConfig(BaseModel):
ge=1024,
le=65535,
)
-tls_certfile: Optional[str] = Field(
+tls_certfile: str | None = Field(
default=None,
description="Path to TLS certificate file for HTTPS",
)
-tls_keyfile: Optional[str] = Field(
+tls_keyfile: str | None = Field(
default=None,
description="Path to TLS key file for HTTPS",
)
-auth: Optional[AuthenticationConfig] = Field(
+auth: AuthenticationConfig | None = Field(
default=None,
description="Authentication configuration for the server",
)
@@ -285,23 +265,23 @@ Reference to the distribution this package refers to. For unregistered (adhoc) p
this could be just a hash
""",
)
-container_image: Optional[str] = Field(
+container_image: str | None = Field(
default=None,
description="Reference to the container image if this package refers to a container",
)
-apis: List[str] = Field(
+apis: list[str] = Field(
default_factory=list,
description="""
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
)

-providers: Dict[str, List[Provider]] = Field(
+providers: dict[str, list[Provider]] = Field(
description="""
One or more providers to use for each API. The same provider_type (e.g., meta-reference)
can be instantiated multiple times (with different configs) if necessary.
""",
)
-metadata_store: Optional[KVStoreConfig] = Field(
+metadata_store: KVStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used by the distribution registry. If not specified,
@@ -309,22 +289,22 @@ a default SQLite store will be used.""",
)

# registry of "resources" in the distribution
-models: List[ModelInput] = Field(default_factory=list)
+models: list[ModelInput] = Field(default_factory=list)
-shields: List[ShieldInput] = Field(default_factory=list)
+shields: list[ShieldInput] = Field(default_factory=list)
-vector_dbs: List[VectorDBInput] = Field(default_factory=list)
+vector_dbs: list[VectorDBInput] = Field(default_factory=list)
-datasets: List[DatasetInput] = Field(default_factory=list)
+datasets: list[DatasetInput] = Field(default_factory=list)
-scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
+scoring_fns: list[ScoringFnInput] = Field(default_factory=list)
-benchmarks: List[BenchmarkInput] = Field(default_factory=list)
+benchmarks: list[BenchmarkInput] = Field(default_factory=list)
-tool_groups: List[ToolGroupInput] = Field(default_factory=list)
+tool_groups: list[ToolGroupInput] = Field(default_factory=list)

-logging: Optional[LoggingConfig] = Field(default=None, description="Configuration for Llama Stack Logging")
+logging: LoggingConfig | None = Field(default=None, description="Configuration for Llama Stack Logging")

server: ServerConfig = Field(
default_factory=ServerConfig,
description="Configuration for the HTTP(S) server",
)

-external_providers_dir: Optional[str] = Field(
+external_providers_dir: str | None = Field(
default=None,
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
)
@@ -338,11 +318,11 @@ class BuildConfig(BaseModel):
default="conda",
description="Type of package to build (conda | container | venv)",
)
-image_name: Optional[str] = Field(
+image_name: str | None = Field(
default=None,
description="Name of the distribution to build",
)
-external_providers_dir: Optional[str] = Field(
+external_providers_dir: str | None = Field(
default=None,
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
"pip_packages MUST contain the provider package name.",

@@ -7,7 +7,7 @@
import glob
import importlib
import os
-from typing import Any, Dict, List
+from typing import Any

import yaml
from pydantic import BaseModel
@@ -24,7 +24,7 @@ from llama_stack.providers.datatypes import (
logger = get_logger(name=__name__, category="core")


-def stack_apis() -> List[Api]:
+def stack_apis() -> list[Api]:
return list(Api)


@@ -33,7 +33,7 @@ class AutoRoutedApiInfo(BaseModel):
router_api: Api


-def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
+def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
return [
AutoRoutedApiInfo(
routing_table_api=Api.models,
@@ -66,12 +66,12 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
]


-def providable_apis() -> List[Api]:
+def providable_apis() -> list[Api]:
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]


-def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderSpec:
+def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
adapter = AdapterSpec(**spec_data["adapter"])
spec = remote_provider_spec(
api=api,
@@ -81,7 +81,7 @@ def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderS
return spec


-def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
+def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
spec = InlineProviderSpec(
api=api,
provider_type=f"inline::{provider_name}",
@@ -98,7 +98,7 @@ def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_nam

def get_provider_registry(
config=None,
-) -> Dict[Api, Dict[str, ProviderSpec]]:
+) -> dict[Api, dict[str, ProviderSpec]]:
"""Get the provider registry, optionally including external providers.

This function loads both built-in providers and external providers from YAML files.
@@ -133,7 +133,7 @@ def get_provider_registry(
ValueError: If any provider spec is invalid
"""

-ret: Dict[Api, Dict[str, ProviderSpec]] = {}
+ret: dict[Api, dict[str, ProviderSpec]] = {}
for api in providable_apis():
name = api.name.lower()
logger.debug(f"Importing module {name}")

@@ -12,7 +12,7 @@ import os
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
-from typing import Any, Optional, TypeVar, Union, get_args, get_origin
+from typing import Any, TypeVar, Union, get_args, get_origin

import httpx
import yaml
@@ -119,8 +119,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
self,
config_path_or_template_name: str,
skip_logger_removal: bool = False,
-custom_provider_registry: Optional[ProviderRegistry] = None,
+custom_provider_registry: ProviderRegistry | None = None,
-provider_data: Optional[dict[str, Any]] = None,
+provider_data: dict[str, Any] | None = None,
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
@@ -181,8 +181,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
def __init__(
self,
config_path_or_template_name: str,
-custom_provider_registry: Optional[ProviderRegistry] = None,
+custom_provider_registry: ProviderRegistry | None = None,
-provider_data: Optional[dict[str, Any]] = None,
+provider_data: dict[str, Any] | None = None,
):
super().__init__()
# when using the library client, we should not log to console since many
@@ -371,7 +371,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
return await response.parse()

-def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
+def _convert_body(self, path: str, method: str, body: dict | None = None) -> dict:
if not body:
return {}

@@ -5,7 +5,7 @@
# the root directory of this source tree.

import asyncio
-from typing import Any, Dict
+from typing import Any

from pydantic import BaseModel

@@ -73,14 +73,14 @@ class ProviderImpl(Providers):

raise ValueError(f"Provider {provider_id} not found")

-async def get_providers_health(self) -> Dict[str, Dict[str, HealthResponse]]:
+async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]:
"""Get health status for all providers.

Returns:
Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
Each API maps to a dictionary of provider IDs to their health responses.
"""
-providers_health: Dict[str, Dict[str, HealthResponse]] = {}
+providers_health: dict[str, dict[str, HealthResponse]] = {}
timeout = 1.0

async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:

@@ -7,7 +7,8 @@
import contextvars
import json
import logging
-from typing import Any, ContextManager, Dict, List, Optional
+from contextlib import AbstractContextManager
+from typing import Any

from .utils.dynamic import instantiate_class_type

@@ -17,11 +18,11 @@ log = logging.getLogger(__name__)
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)


-class RequestProviderDataContext(ContextManager):
+class RequestProviderDataContext(AbstractContextManager):
"""Context manager for request provider data"""

def __init__(
-self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
+self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None
):
self.provider_data = provider_data or {}
if auth_attributes:
@@ -63,7 +64,7 @@ class NeedsRequestProviderData:
return None


-def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
+def parse_request_provider_data(headers: dict[str, str]) -> dict[str, Any] | None:
"""Parse provider data from request headers"""
keys = [
"X-LlamaStack-Provider-Data",
@@ -86,14 +87,14 @@ def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, A


def request_provider_data_context(
-headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None
+headers: dict[str, str], auth_attributes: dict[str, list[str]] | None = None
-) -> ContextManager:
+) -> AbstractContextManager:
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
provider_data = parse_request_provider_data(headers)
return RequestProviderDataContext(provider_data, auth_attributes)


-def get_auth_attributes() -> Optional[Dict[str, List[str]]]:
+def get_auth_attributes() -> dict[str, list[str]] | None:
"""Helper to retrieve auth attributes from the provider data context"""
provider_data = PROVIDER_DATA_VAR.get()
if not provider_data:
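`typing.ContextManager` is a deprecated alias of `contextlib.AbstractContextManager`, so the base-class swap above belongs to the same modernization as the `Dict`/`Optional` rewrites. A minimal sketch of the pattern this class implements, with the request-scoped data parked in a `ContextVar` and restored on exit; the token bookkeeping is an assumption, since the class body is not shown in this hunk.

```python
import contextvars
from contextlib import AbstractContextManager
from typing import Any

PROVIDER_DATA_VAR: contextvars.ContextVar[dict[str, Any] | None] = contextvars.ContextVar(
    "provider_data", default=None
)


class RequestProviderDataContext(AbstractContextManager):
    """Holds per-request provider data for the duration of a with-block."""

    def __init__(self, provider_data: dict[str, Any] | None = None) -> None:
        self.provider_data = provider_data or {}
        self._token: contextvars.Token | None = None

    def __enter__(self) -> "RequestProviderDataContext":
        # ContextVar keeps the data isolated per task, so concurrent requests don't clash.
        self._token = PROVIDER_DATA_VAR.set(self.provider_data)
        return self

    def __exit__(self, *exc_info) -> None:
        if self._token is not None:
            PROVIDER_DATA_VAR.reset(self._token)
```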
@@ -5,7 +5,7 @@
# the root directory of this source tree.
import importlib
import inspect
-from typing import Any, Dict, List, Set, Tuple
+from typing import Any

from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
@@ -58,7 +58,7 @@ class InvalidProviderError(Exception):
pass


-def api_protocol_map() -> Dict[Api, Any]:
+def api_protocol_map() -> dict[Api, Any]:
return {
Api.providers: ProvidersAPI,
Api.agents: Agents,
@@ -83,7 +83,7 @@ def api_protocol_map() -> Dict[Api, Any]:
}


-def additional_protocols_map() -> Dict[Api, Any]:
+def additional_protocols_map() -> dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
@@ -104,14 +104,14 @@ class ProviderWithSpec(Provider):
spec: ProviderSpec


-ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
+ProviderRegistry = dict[Api, dict[str, ProviderSpec]]


async def resolve_impls(
run_config: StackRunConfig,
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
-) -> Dict[Api, Any]:
+) -> dict[Api, Any]:
"""
Resolves provider implementations by:
1. Validating and organizing providers.
@@ -136,7 +136,7 @@ async def resolve_impls(
return await instantiate_providers(sorted_providers, router_apis, dist_registry)


-def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]:
+def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
"""Generates specifications for automatically routed APIs."""
specs = {}
for info in builtin_automatically_routed_apis():
@@ -178,10 +178,10 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,


def validate_and_prepare_providers(
-run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api]
+run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: set[Api], router_apis: set[Api]
-) -> Dict[str, Dict[str, ProviderWithSpec]]:
+) -> dict[str, dict[str, ProviderWithSpec]]:
"""Validates providers, handles deprecations, and organizes them into a spec dictionary."""
-providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {}
+providers_with_specs: dict[str, dict[str, ProviderWithSpec]] = {}

for api_str, providers in run_config.providers.items():
api = Api(api_str)
@@ -222,10 +222,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR


def sort_providers_by_deps(
-providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig
+providers_with_specs: dict[str, dict[str, ProviderWithSpec]], run_config: StackRunConfig
-) -> List[Tuple[str, ProviderWithSpec]]:
+) -> list[tuple[str, ProviderWithSpec]]:
"""Sorts providers based on their dependencies."""
-sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort(
+sorted_providers: list[tuple[str, ProviderWithSpec]] = topological_sort(
{k: list(v.values()) for k, v in providers_with_specs.items()}
)

@@ -236,11 +236,11 @@ def sort_providers_by_deps(


async def instantiate_providers(
-sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry
+sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry
-) -> Dict:
+) -> dict:
"""Instantiates providers asynchronously while managing dependencies."""
-impls: Dict[Api, Any] = {}
+impls: dict[Api, Any] = {}
-inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
+inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies}
for a in provider.spec.optional_api_dependencies:
@@ -263,9 +263,9 @@ async def instantiate_providers(


def topological_sort(
-providers_with_specs: Dict[str, List[ProviderWithSpec]],
+providers_with_specs: dict[str, list[ProviderWithSpec]],
-) -> List[Tuple[str, ProviderWithSpec]]:
+) -> list[tuple[str, ProviderWithSpec]]:
-def dfs(kv, visited: Set[str], stack: List[str]):
+def dfs(kv, visited: set[str], stack: list[str]):
api_str, providers = kv
visited.add(api_str)

@@ -280,8 +280,8 @@ def topological_sort(

stack.append(api_str)

-visited: Set[str] = set()
+visited: set[str] = set()
-stack: List[str] = []
+stack: list[str] = []

for api_str, providers in providers_with_specs.items():
if api_str not in visited:
@@ -298,8 +298,8 @@ def topological_sort(
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider: ProviderWithSpec,
-deps: Dict[Api, Any],
+deps: dict[Api, Any],
-inner_impls: Dict[str, Any],
+inner_impls: dict[str, Any],
dist_registry: DistributionRegistry,
):
protocols = api_protocol_map()
@@ -391,8 +391,8 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:

async def resolve_remote_stack_impls(
config: RemoteProviderConfig,
-apis: List[str],
+apis: list[str],
-) -> Dict[Api, Any]:
+) -> dict[Api, Any]:
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
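`sort_providers_by_deps` and `topological_sort` above order providers so that every dependency is instantiated before its dependents; this diff only modernizes their annotations. For orientation, a self-contained sketch of the same depth-first idea over a plain name-to-dependencies mapping (the real code walks `ProviderSpec` objects, not strings):

```python
def topological_sort(graph: dict[str, list[str]]) -> list[str]:
    """Return node names so every dependency appears before its dependents."""
    visited: set[str] = set()
    stack: list[str] = []

    def dfs(node: str) -> None:
        visited.add(node)
        for dep in graph.get(node, []):
            if dep not in visited:
                dfs(dep)
        stack.append(node)  # post-order: all dependencies are already placed

    for node in graph:
        if node not in visited:
            dfs(node)
    return stack


# agents needs inference and safety; safety itself needs inference
print(topological_sort({"agents": ["inference", "safety"], "safety": ["inference"], "inference": []}))
# -> ['inference', 'safety', 'agents']
```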
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

from llama_stack.distribution.datatypes import RoutedProtocol
from llama_stack.distribution.store import DistributionRegistry
@@ -23,7 +23,7 @@ from .routing_tables import (

async def get_routing_table_impl(
api: Api,
-impls_by_provider_id: Dict[str, RoutedProtocol],
+impls_by_provider_id: dict[str, RoutedProtocol],
_deps,
dist_registry: DistributionRegistry,
) -> Any:
@@ -45,7 +45,7 @@ async def get_routing_table_impl(
return impl


-async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any:
from .routers import (
DatasetIORouter,
EvalRouter,
@@ -6,12 +6,12 @@

import asyncio
import time
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
+from collections.abc import AsyncGenerator, AsyncIterator
+from typing import Annotated, Any

from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import Field, TypeAdapter
-from typing_extensions import Annotated

from llama_stack.apis.common.content_types import (
URL,
@@ -100,9 +100,9 @@ class VectorIORouter(VectorIO):
self,
vector_db_id: str,
embedding_model: str,
-embedding_dimension: Optional[int] = 384,
+embedding_dimension: int | None = 384,
-provider_id: Optional[str] = None,
+provider_id: str | None = None,
-provider_vector_db_id: Optional[str] = None,
+provider_vector_db_id: str | None = None,
) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
@@ -116,8 +116,8 @@ class VectorIORouter(VectorIO):
async def insert_chunks(
self,
vector_db_id: str,
-chunks: List[Chunk],
+chunks: list[Chunk],
-ttl_seconds: Optional[int] = None,
+ttl_seconds: int | None = None,
) -> None:
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
@@ -128,7 +128,7 @@ class VectorIORouter(VectorIO):
self,
vector_db_id: str,
query: InterleavedContent,
-params: Optional[Dict[str, Any]] = None,
+params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
@@ -140,7 +140,7 @@ class InferenceRouter(Inference):
def __init__(
self,
routing_table: RoutingTable,
-telemetry: Optional[Telemetry] = None,
+telemetry: Telemetry | None = None,
) -> None:
logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table
@@ -160,10 +160,10 @@ class InferenceRouter(Inference):
async def register_model(
self,
model_id: str,
-provider_model_id: Optional[str] = None,
+provider_model_id: str | None = None,
-provider_id: Optional[str] = None,
+provider_id: str | None = None,
-metadata: Optional[Dict[str, Any]] = None,
+metadata: dict[str, Any] | None = None,
-model_type: Optional[ModelType] = None,
+model_type: ModelType | None = None,
) -> None:
logger.debug(
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
@@ -176,7 +176,7 @@ class InferenceRouter(Inference):
completion_tokens: int,
total_tokens: int,
model: Model,
-) -> List[MetricEvent]:
+) -> list[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics.

Args:
@@ -221,7 +221,7 @@ class InferenceRouter(Inference):
completion_tokens: int,
total_tokens: int,
model: Model,
-) -> List[MetricInResponse]:
+) -> list[MetricInResponse]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry:
for metric in metrics:
@@ -230,9 +230,9 @@ class InferenceRouter(Inference):

async def _count_tokens(
self,
-messages: List[Message] | InterleavedContent,
+messages: list[Message] | InterleavedContent,
-tool_prompt_format: Optional[ToolPromptFormat] = None,
+tool_prompt_format: ToolPromptFormat | None = None,
-) -> Optional[int]:
+) -> int | None:
if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else:
@@ -242,16 +242,16 @@ class InferenceRouter(Inference):
async def chat_completion(
self,
model_id: str,
-messages: List[Message],
+messages: list[Message],
-sampling_params: Optional[SamplingParams] = None,
+sampling_params: SamplingParams | None = None,
-response_format: Optional[ResponseFormat] = None,
+response_format: ResponseFormat | None = None,
-tools: Optional[List[ToolDefinition]] = None,
+tools: list[ToolDefinition] | None = None,
-tool_choice: Optional[ToolChoice] = None,
+tool_choice: ToolChoice | None = None,
-tool_prompt_format: Optional[ToolPromptFormat] = None,
+tool_prompt_format: ToolPromptFormat | None = None,
-stream: Optional[bool] = False,
+stream: bool | None = False,
-logprobs: Optional[LogProbConfig] = None,
+logprobs: LogProbConfig | None = None,
-tool_config: Optional[ToolConfig] = None,
+tool_config: ToolConfig | None = None,
-) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
logger.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
@@ -351,12 +351,12 @@ class InferenceRouter(Inference):
async def batch_chat_completion(
self,
model_id: str,
-messages_batch: List[List[Message]],
+messages_batch: list[list[Message]],
-tools: Optional[List[ToolDefinition]] = None,
+tools: list[ToolDefinition] | None = None,
-tool_config: Optional[ToolConfig] = None,
+tool_config: ToolConfig | None = None,
-sampling_params: Optional[SamplingParams] = None,
+sampling_params: SamplingParams | None = None,
-response_format: Optional[ResponseFormat] = None,
+response_format: ResponseFormat | None = None,
-logprobs: Optional[LogProbConfig] = None,
+logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse:
logger.debug(
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
@@ -376,10 +376,10 @@ class InferenceRouter(Inference):
self,
model_id: str,
content: InterleavedContent,
-sampling_params: Optional[SamplingParams] = None,
+sampling_params: SamplingParams | None = None,
-response_format: Optional[ResponseFormat] = None,
+response_format: ResponseFormat | None = None,
-stream: Optional[bool] = False,
+stream: bool | None = False,
-logprobs: Optional[LogProbConfig] = None,
+logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@@ -439,10 +439,10 @@ class InferenceRouter(Inference):
async def batch_completion(
self,
model_id: str,
-content_batch: List[InterleavedContent],
+content_batch: list[InterleavedContent],
-sampling_params: Optional[SamplingParams] = None,
+sampling_params: SamplingParams | None = None,
-response_format: Optional[ResponseFormat] = None,
+response_format: ResponseFormat | None = None,
-logprobs: Optional[LogProbConfig] = None,
+logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
logger.debug(
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
@@ -453,10 +453,10 @@ class InferenceRouter(Inference):
async def embeddings(
self,
model_id: str,
-contents: List[str] | List[InterleavedContentItem],
+contents: list[str] | list[InterleavedContentItem],
-text_truncation: Optional[TextTruncation] = TextTruncation.none,
+text_truncation: TextTruncation | None = TextTruncation.none,
-output_dimension: Optional[int] = None,
+output_dimension: int | None = None,
-task_type: Optional[EmbeddingTaskType] = None,
+task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
logger.debug(f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id)
@@ -475,24 +475,24 @@ class InferenceRouter(Inference):
async def openai_completion(
self,
model: str,
-prompt: Union[str, List[str], List[int], List[List[int]]],
+prompt: str | list[str] | list[int] | list[list[int]],
-best_of: Optional[int] = None,
+best_of: int | None = None,
-echo: Optional[bool] = None,
+echo: bool | None = None,
-frequency_penalty: Optional[float] = None,
+frequency_penalty: float | None = None,
-logit_bias: Optional[Dict[str, float]] = None,
+logit_bias: dict[str, float] | None = None,
-logprobs: Optional[bool] = None,
+logprobs: bool | None = None,
-max_tokens: Optional[int] = None,
+max_tokens: int | None = None,
-n: Optional[int] = None,
+n: int | None = None,
-presence_penalty: Optional[float] = None,
+presence_penalty: float | None = None,
-seed: Optional[int] = None,
+seed: int | None = None,
-stop: Optional[Union[str, List[str]]] = None,
+stop: str | list[str] | None = None,
-stream: Optional[bool] = None,
+stream: bool | None = None,
-stream_options: Optional[Dict[str, Any]] = None,
+stream_options: dict[str, Any] | None = None,
-temperature: Optional[float] = None,
+temperature: float | None = None,
-top_p: Optional[float] = None,
+top_p: float | None = None,
-user: Optional[str] = None,
+user: str | None = None,
-guided_choice: Optional[List[str]] = None,
+guided_choice: list[str] | None = None,
-prompt_logprobs: Optional[int] = None,
+prompt_logprobs: int | None = None,
) -> OpenAICompletion:
logger.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
@@ -531,29 +531,29 @@ class InferenceRouter(Inference):
async def openai_chat_completion(
self,
model: str,
-messages: Annotated[List[OpenAIMessageParam], Field(..., min_length=1)],
+messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
-frequency_penalty: Optional[float] = None,
+frequency_penalty: float | None = None,
-function_call: Optional[Union[str, Dict[str, Any]]] = None,
+function_call: str | dict[str, Any] | None = None,
-functions: Optional[List[Dict[str, Any]]] = None,
+functions: list[dict[str, Any]] | None = None,
-logit_bias: Optional[Dict[str, float]] = None,
+logit_bias: dict[str, float] | None = None,
-logprobs: Optional[bool] = None,
+logprobs: bool | None = None,
-max_completion_tokens: Optional[int] = None,
+max_completion_tokens: int | None = None,
-max_tokens: Optional[int] = None,
+max_tokens: int | None = None,
-n: Optional[int] = None,
+n: int | None = None,
-parallel_tool_calls: Optional[bool] = None,
+parallel_tool_calls: bool | None = None,
-presence_penalty: Optional[float] = None,
+presence_penalty: float | None = None,
-response_format: Optional[OpenAIResponseFormatParam] = None,
+response_format: OpenAIResponseFormatParam | None = None,
-seed: Optional[int] = None,
+seed: int | None = None,
-stop: Optional[Union[str, List[str]]] = None,
+stop: str | list[str] | None = None,
-stream: Optional[bool] = None,
+stream: bool | None = None,
-stream_options: Optional[Dict[str, Any]] = None,
+stream_options: dict[str, Any] | None = None,
-temperature: Optional[float] = None,
+temperature: float | None = None,
-tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+tool_choice: str | dict[str, Any] | None = None,
-tools: Optional[List[Dict[str, Any]]] = None,
+tools: list[dict[str, Any]] | None = None,
-top_logprobs: Optional[int] = None,
+top_logprobs: int | None = None,
-top_p: Optional[float] = None,
+top_p: float | None = None,
-user: Optional[str] = None,
+user: str | None = None,
-) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
logger.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
)
@@ -602,7 +602,7 @@ class InferenceRouter(Inference):
provider = self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_chat_completion(**params)

-async def health(self) -> Dict[str, HealthResponse]:
+async def health(self) -> dict[str, HealthResponse]:
health_statuses = {}
timeout = 0.5
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
@@ -645,9 +645,9 @@ class SafetyRouter(Safety):
async def register_shield(
self,
shield_id: str,
-provider_shield_id: Optional[str] = None,
+provider_shield_id: str | None = None,
-provider_id: Optional[str] = None,
+provider_id: str | None = None,
-params: Optional[Dict[str, Any]] = None,
+params: dict[str, Any] | None = None,
) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
@@ -655,8 +655,8 @@ class SafetyRouter(Safety):
async def run_shield(
self,
shield_id: str,
-messages: List[Message],
+messages: list[Message],
-params: Dict[str, Any] = None,
+params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
@@ -686,8 +686,8 @@ class DatasetIORouter(DatasetIO):
self,
purpose: DatasetPurpose,
source: DataSource,
-metadata: Optional[Dict[str, Any]] = None,
+metadata: dict[str, Any] | None = None,
-dataset_id: Optional[str] = None,
+dataset_id: str | None = None,
) -> None:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
@@ -702,8 +702,8 @@ class DatasetIORouter(DatasetIO):
async def iterrows(
self,
dataset_id: str,
-start_index: Optional[int] = None,
+start_index: int | None = None,
-limit: Optional[int] = None,
+limit: int | None = None,
) -> PaginatedResponse:
logger.debug(
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
@@ -714,7 +714,7 @@ class DatasetIORouter(DatasetIO):
limit=limit,
)

-async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
+async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id,
@@ -741,7 +741,7 @@ class ScoringRouter(Scoring):
async def score_batch(
self,
dataset_id: str,
-scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
+scoring_functions: dict[str, ScoringFnParams | None] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
@@ -762,8 +762,8 @@ class ScoringRouter(Scoring):

async def score(
self,
-input_rows: List[Dict[str, Any]],
+input_rows: list[dict[str, Any]],
-scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
+scoring_functions: dict[str, ScoringFnParams | None] = None,
) -> ScoreResponse:
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
@@ -808,8 +808,8 @@ class EvalRouter(Eval):
async def evaluate_rows(
self,
benchmark_id: str,
-input_rows: List[Dict[str, Any]],
+input_rows: list[dict[str, Any]],
-scoring_functions: List[str],
+scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
@@ -863,8 +863,8 @@ class ToolRuntimeRouter(ToolRuntime):
async def query(
self,
content: InterleavedContent,
-vector_db_ids: List[str],
+vector_db_ids: list[str],
-query_config: Optional[RAGQueryConfig] = None,
+query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query(
@@ -873,7 +873,7 @@ class ToolRuntimeRouter(ToolRuntime):

async def insert(
self,
-documents: List[RAGDocument],
+documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
@@ -904,7 +904,7 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug("ToolRuntimeRouter.shutdown")
pass

-async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
+async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
@@ -912,7 +912,7 @@ class ToolRuntimeRouter(ToolRuntime):
)

async def list_runtime_tools(
-self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
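Nearly every hunk in this file is the same mechanical rewrite that pyupgrade (or Ruff's UP rules) applies once the minimum Python version allows it: `typing.List`/`Dict` become the built-in generics of PEP 585, and `Optional`/`Union` become the `X | Y` unions of PEP 604. A minimal before/after pair, purely for illustration:

```python
# Before: pre-3.10 typing spellings
from typing import Dict, List, Optional, Union

def lookup(keys: List[str], index: Dict[str, int]) -> Optional[Union[int, str]]:
    return index.get(keys[0]) if keys else None


# After: what the automated rewrite produces on Python 3.10+
def lookup(keys: list[str], index: dict[str, int]) -> int | str | None:
    return index.get(keys[0]) if keys else None
```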
@@ -7,7 +7,7 @@
import logging
import time
import uuid
-from typing import Any, Dict, List, Optional
+from typing import Any

from pydantic import TypeAdapter

@@ -106,20 +106,20 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
raise ValueError(f"Unregister not supported for {api}")


-Registry = Dict[str, List[RoutableObjectWithProvider]]
+Registry = dict[str, list[RoutableObjectWithProvider]]


class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
-impls_by_provider_id: Dict[str, RoutedProtocol],
+impls_by_provider_id: dict[str, RoutedProtocol],
dist_registry: DistributionRegistry,
) -> None:
self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry

async def initialize(self) -> None:
-async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
+async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
for obj in objs:
if cls is None:
obj.provider_id = provider_id
@@ -154,7 +154,7 @@ class CommonRoutingTableImpl(RoutingTable):
for p in self.impls_by_provider_id.values():
await p.shutdown()

-def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
+def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
@@ -192,7 +192,7 @@ class CommonRoutingTableImpl(RoutingTable):

raise ValueError(f"Provider not found for `{routing_key}`")

-async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
+async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
# Get from disk registry
obj = await self.dist_registry.get(type, identifier)
if not obj:
@@ -236,7 +236,7 @@ class CommonRoutingTableImpl(RoutingTable):
await self.dist_registry.register(obj)
return obj

-async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
+async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
filtered_objs = [obj for obj in objs if obj.type == type]

@@ -277,10 +277,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def register_model(
self,
model_id: str,
-provider_model_id: Optional[str] = None,
+provider_model_id: str | None = None,
-provider_id: Optional[str] = None,
+provider_id: str | None = None,
-metadata: Optional[Dict[str, Any]] = None,
+metadata: dict[str, Any] | None = None,
-model_type: Optional[ModelType] = None,
+model_type: ModelType | None = None,
) -> Model:
if provider_model_id is None:
provider_model_id = model_id
@@ -328,9 +328,9 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def register_shield(
self,
shield_id: str,
-provider_shield_id: Optional[str] = None,
+provider_shield_id: str | None = None,
-provider_id: Optional[str] = None,
+provider_id: str | None = None,
-params: Optional[Dict[str, Any]] = None,
+params: dict[str, Any] | None = None,
) -> Shield:
if provider_shield_id is None:
provider_shield_id = shield_id
@@ -368,9 +368,9 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
self,
vector_db_id: str,
embedding_model: str,
-embedding_dimension: Optional[int] = 384,
+embedding_dimension: int | None = 384,
-provider_id: Optional[str] = None,
+provider_id: str | None = None,
-provider_vector_db_id: Optional[str] = None,
+provider_vector_db_id: str | None = None,
) -> VectorDB:
if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id
@@ -423,8 +423,8 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
self,
purpose: DatasetPurpose,
source: DataSource,
-metadata: Optional[Dict[str, Any]] = None,
+metadata: dict[str, Any] | None = None,
-dataset_id: Optional[str] = None,
+dataset_id: str | None = None,
) -> Dataset:
if isinstance(source, dict):
if source["type"] == "uri":
@@ -489,9 +489,9 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
scoring_fn_id: str,
description: str,
return_type: ParamType,
-provider_scoring_fn_id: Optional[str] = None,
+provider_scoring_fn_id: str | None = None,
-provider_id: Optional[str] = None,
+provider_id: str | None = None,
-params: Optional[ScoringFnParams] = None,
+params: ScoringFnParams | None = None,
) -> None:
if provider_scoring_fn_id is None:
provider_scoring_fn_id = scoring_fn_id
@@ -528,10 +528,10 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
self,
benchmark_id: str,
dataset_id: str,
-scoring_functions: List[str],
+scoring_functions: list[str],
-metadata: Optional[Dict[str, Any]] = None,
+metadata: dict[str, Any] | None = None,
-provider_benchmark_id: Optional[str] = None,
+provider_benchmark_id: str | None = None,
-provider_id: Optional[str] = None,
+provider_id: str | None = None,
) -> None:
if metadata is None:
metadata = {}
@@ -556,7 +556,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):


class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
-async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
+async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
tools = await self.get_all_with_type("tool")
if toolgroup_id:
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
@@ -578,8 +578,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
self,
toolgroup_id: str,
provider_id: str,
-mcp_endpoint: Optional[URL] = None,
+mcp_endpoint: URL | None = None,
-args: Optional[Dict[str, Any]] = None,
+args: dict[str, Any] | None = None,
) -> None:
tools = []
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
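All of these routing tables follow the same shape: objects are registered under a type plus identifier and fetched back with an `X | None` return. A stripped-down sketch of that lookup pattern, using a plain in-memory dict rather than the real `DistributionRegistry` (names and example values are illustrative only):

```python
from dataclasses import dataclass


@dataclass
class RoutableObject:
    type: str
    identifier: str
    provider_id: str


_store: dict[str, RoutableObject] = {}


def register(obj: RoutableObject) -> None:
    # key by type and identifier so different resource kinds cannot collide
    _store[f"{obj.type}:{obj.identifier}"] = obj


def get_object_by_identifier(type: str, identifier: str) -> RoutableObject | None:
    # modern spelling of Optional[RoutableObject]
    return _store.get(f"{type}:{identifier}")


register(RoutableObject(type="model", identifier="llama-3", provider_id="provider-1"))
assert get_object_by_identifier("model", "llama-3") is not None
assert get_object_by_identifier("shield", "llama-3") is None
```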
@@ -7,7 +7,6 @@
import json
from abc import ABC, abstractmethod
from enum import Enum
-from typing import Dict, List, Optional
from urllib.parse import parse_qs

import httpx
@@ -22,7 +21,7 @@ logger = get_logger(name=__name__, category="auth")
class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""

-access_attributes: Optional[AccessAttributes] = Field(
+access_attributes: AccessAttributes | None = Field(
default=None,
description="""
Structured user attributes for attribute-based access control.
@@ -44,7 +43,7 @@ class AuthResponse(BaseModel):
""",
)

-message: Optional[str] = Field(
+message: str | None = Field(
default=None, description="Optional message providing additional context about the authentication result."
)

@@ -52,9 +51,9 @@ class AuthResponse(BaseModel):
class AuthRequestContext(BaseModel):
path: str = Field(description="The path of the request being authenticated")

-headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
+headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")

-params: Dict[str, List[str]] = Field(
+params: dict[str, list[str]] = Field(
description="Query parameters from the original request, parsed as dictionary of lists"
)

@@ -76,14 +75,14 @@ class AuthProviderConfig(BaseModel):
"""Base configuration for authentication providers."""

provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
-config: Dict[str, str] = Field(..., description="Provider-specific configuration")
+config: dict[str, str] = Field(..., description="Provider-specific configuration")


class AuthProvider(ABC):
"""Abstract base class for authentication providers."""

@abstractmethod
-async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
+async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
"""Validate a token and return access attributes."""
pass

@@ -96,7 +95,7 @@ class AuthProvider(ABC):
class KubernetesAuthProvider(AuthProvider):
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""

-def __init__(self, config: Dict[str, str]):
+def __init__(self, config: dict[str, str]):
self.api_server_url = config["api_server_url"]
self.ca_cert_path = config.get("ca_cert_path")
self._client = None
@@ -120,7 +119,7 @@ class KubernetesAuthProvider(AuthProvider):
self._client = ApiClient(configuration)
return self._client

-async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
+async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
"""Validate a Kubernetes token and return access attributes."""
try:
client = await self._get_client()
@@ -166,11 +165,11 @@ class KubernetesAuthProvider(AuthProvider):
class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint."""

-def __init__(self, config: Dict[str, str]):
+def __init__(self, config: dict[str, str]):
self.endpoint = config["endpoint"]
self._client = None

-async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
+async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
"""Validate a token using the custom authentication endpoint."""
if not self.endpoint:
raise ValueError("Authentication endpoint not configured")
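`CustomAuthProvider` above only changes its annotations here; it delegates validation to an external endpoint over `httpx`. A rough sketch of what that round trip can look like, with the request and response shapes assumed for illustration rather than taken from the project's actual contract:

```python
import httpx


async def validate_token_via_endpoint(endpoint: str, token: str, timeout: float = 10.0) -> dict | None:
    # POST the bearer token to the external auth service; return its JSON payload
    # (e.g. access attributes) on success, or None when the token is rejected.
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(endpoint, json={"api_key": token})
        if response.status_code != 200:
            return None
        return response.json()
```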
@@ -6,7 +6,6 @@

import inspect
import re
-from typing import Dict, List

from pydantic import BaseModel

@@ -29,7 +28,7 @@ def toolgroup_protocol_map():
}


-def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
+def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
apis = {}

protocols = api_protocol_map()
@@ -15,7 +15,7 @@ import warnings
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
-from typing import Any, List, Optional, Union
+from typing import Annotated, Any

import yaml
from fastapi import Body, FastAPI, HTTPException, Request
@@ -24,7 +24,6 @@ from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
-from typing_extensions import Annotated

from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
@@ -91,7 +90,7 @@ async def global_exception_handler(request: Request, exc: Exception):
return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})


-def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
+def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
if isinstance(exc, ValidationError):
exc = RequestValidationError(exc.errors())

@@ -315,7 +314,7 @@ class ClientVersionMiddleware:
return await self.app(scope, receive, send)


-def main(args: Optional[argparse.Namespace] = None):
+def main(args: argparse.Namespace | None = None):
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument(
@@ -385,7 +384,7 @@ def main(args: Optional[argparse.Namespace] = None):
raise ValueError("Either --yaml-config or --template must be provided")

logger_config = None
-with open(config_file, "r") as fp:
+with open(config_file) as fp:
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
@@ -517,7 +516,7 @@ def main(args: Optional[argparse.Namespace] = None):
uvicorn.run(**uvicorn_config)


-def extract_path_params(route: str) -> List[str]:
+def extract_path_params(route: str) -> list[str]:
segments = route.split("/")
params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
# to handle path params like {param:path}
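`extract_path_params` only gains a modern return annotation here. The visible body collects `{param}` segments from a route, and the trailing comment notes that `{param:path}` converters also need handling; one plausible way to finish the function, shown only as a sketch:

```python
import re


def extract_path_params(route: str) -> list[str]:
    segments = route.split("/")
    params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
    # to handle path params like {param:path}, keep only the name before the colon
    return [re.sub(r":.*", "", p) for p in params]


print(extract_path_params("/v1/files/{file_id}/content/{path:path}"))
# ['file_id', 'path']
```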
@@ -8,7 +8,7 @@ import importlib.resources
import os
import re
import tempfile
-from typing import Any, Dict, Optional
+from typing import Any

import yaml

@@ -90,7 +90,7 @@ RESOURCES = [
]


-async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
+async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config, rsrc)
if api not in impls:
@@ -197,7 +197,7 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
) from e


-def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConfig) -> None:
+def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
"""Add internal implementations (inspect and providers) to the implementations dictionary.

Args:
@@ -220,8 +220,8 @@ def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConf
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(
-run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
+run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
-) -> Dict[Api, Any]:
+) -> dict[Api, Any]:
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)

@@ -244,7 +244,7 @@ def get_stack_run_config_from_template(template: str) -> StackRunConfig:


def run_config_from_adhoc_config_spec(
-adhoc_config_spec: str, provider_registry: Optional[ProviderRegistry] = None
+adhoc_config_spec: str, provider_registry: ProviderRegistry | None = None
) -> StackRunConfig:
"""
Create an adhoc distribution from a list of API providers.
@@ -6,7 +6,7 @@

import asyncio
from contextlib import asynccontextmanager
-from typing import Dict, List, Optional, Protocol, Tuple
+from typing import Protocol

import pydantic

@@ -20,13 +20,13 @@ logger = get_logger(__name__, category="core")


class DistributionRegistry(Protocol):
-async def get_all(self) -> List[RoutableObjectWithProvider]: ...
+async def get_all(self) -> list[RoutableObjectWithProvider]: ...

async def initialize(self) -> None: ...

-async def get(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...
+async def get(self, identifier: str) -> RoutableObjectWithProvider | None: ...

-def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...
+def get_cached(self, identifier: str) -> RoutableObjectWithProvider | None: ...

async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ...

@@ -40,13 +40,13 @@ KEY_VERSION = "v8"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"


-def _get_registry_key_range() -> Tuple[str, str]:
+def _get_registry_key_range() -> tuple[str, str]:
"""Returns the start and end keys for the registry range query."""
start_key = f"{REGISTER_PREFIX}:{KEY_VERSION}"
return start_key, f"{start_key}\xff"


-def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]:
+def _parse_registry_values(values: list[str]) -> list[RoutableObjectWithProvider]:
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
all_objects = []
for value in values:

@@ -67,16 +67,16 @@ class DiskDistributionRegistry(DistributionRegistry):
async def initialize(self) -> None:
pass

-def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
+def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
# Disk registry does not have a cache
raise NotImplementedError("Disk registry does not have a cache")

-async def get_all(self) -> List[RoutableObjectWithProvider]:
+async def get_all(self) -> list[RoutableObjectWithProvider]:
start_key, end_key = _get_registry_key_range()
values = await self.kvstore.range(start_key, end_key)
return _parse_registry_values(values)

-async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
+async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier))
if not json_str:
return None

@@ -113,7 +113,7 @@ class DiskDistributionRegistry(DistributionRegistry):
class CachedDiskDistributionRegistry(DiskDistributionRegistry):
def __init__(self, kvstore: KVStore):
super().__init__(kvstore)
-self.cache: Dict[Tuple[str, str], RoutableObjectWithProvider] = {}
+self.cache: dict[tuple[str, str], RoutableObjectWithProvider] = {}
self._initialized = False
self._initialize_lock = asyncio.Lock()
self._cache_lock = asyncio.Lock()

@@ -147,15 +147,15 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
async def initialize(self) -> None:
await self._ensure_initialized()

-def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
+def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
return self.cache.get((type, identifier), None)

-async def get_all(self) -> List[RoutableObjectWithProvider]:
+async def get_all(self) -> list[RoutableObjectWithProvider]:
await self._ensure_initialized()
async with self._locked_cache() as cache:
return list(cache.values())

-async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
+async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
await self._ensure_initialized()
cache_key = (type, identifier)

@@ -189,7 +189,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):


async def create_dist_registry(
-metadata_store: Optional[KVStoreConfig],
+metadata_store: KVStoreConfig | None,
image_name: str,
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
# instantiate kvstore for storing and retrieving distribution metadata
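Everything in the registry diff above is an annotation-only rewrite (`Tuple`/`Dict`/`Optional` to `tuple`/`dict`/`| None`); runtime behaviour is unchanged. A small sketch of the key-range idiom with the new builtin-generic annotations; the prefix constant below is a placeholder, not the repository's real value:

```python
REGISTER_PREFIX = "registry"  # placeholder value; the actual prefix is defined elsewhere in the module
KEY_VERSION = "v8"


def _get_registry_key_range() -> tuple[str, str]:
    """Returns the start and end keys for a registry range query."""
    start_key = f"{REGISTER_PREFIX}:{KEY_VERSION}"
    # "\xff" sorts after every printable character, so this upper-bounds the range scan.
    return start_key, f"{start_key}\xff"


start, end = _get_registry_key_range()
print(start, end)
```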
@@ -5,7 +5,6 @@
# the root directory of this source tree.

import os
-from typing import Optional

from llama_stack_client import LlamaStackClient

@@ -23,7 +22,7 @@ class LlamaStackApi:
},
)

-def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]):
+def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: dict | None):
"""Run scoring on a single row"""
if not scoring_params:
scoring_params = {fn_id: None for fn_id in scoring_function_ids}
@@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any


-def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
+def redact_sensitive_fields(data: dict[str, Any]) -> dict[str, Any]:
"""Redact sensitive information from config before printing."""
sensitive_patterns = ["api_key", "api_token", "password", "secret"]

@@ -18,7 +18,7 @@ def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
return [_redact_value(i) for i in v]
return v

-def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
+def _redact_dict(d: dict[str, Any]) -> dict[str, Any]:
result = {}
for k, v in d.items():
if any(pattern in k.lower() for pattern in sensitive_patterns):
@@ -4,14 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

+from collections.abc import AsyncGenerator
from contextvars import ContextVar
-from typing import AsyncGenerator, List, TypeVar
+from typing import TypeVar

T = TypeVar("T")


def preserve_contexts_async_generator(
-gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
+gen: AsyncGenerator[T, None], context_vars: list[ContextVar]
) -> AsyncGenerator[T, None]:
"""
Wraps an async generator to preserve context variables across iterations.
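On Python 3.9+ the ABC-based generics (`AsyncGenerator`, `Callable`, `Iterator`, and friends) can be imported from `collections.abc` and parameterized directly, which is what the import swap above relies on. A hedged sketch of the idiom; the generator below is illustrative and not the repository's wrapper:

```python
import asyncio
from collections.abc import AsyncGenerator


async def numbers(limit: int) -> AsyncGenerator[int, None]:
    """Yield 0..limit-1; annotated with collections.abc.AsyncGenerator instead of typing's alias."""
    for i in range(limit):
        yield i


async def consume() -> list[int]:
    return [n async for n in numbers(3)]


print(asyncio.run(consume()))  # [0, 1, 2]
```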
@@ -8,12 +8,11 @@ import inspect
import json
import logging
from enum import Enum
-from typing import Any, List, Literal, Optional, Type, Union, get_args, get_origin
+from typing import Annotated, Any, Literal, Union, get_args, get_origin

from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType
-from typing_extensions import Annotated

log = logging.getLogger(__name__)

@@ -21,7 +20,7 @@ log = logging.getLogger(__name__)
def is_list_of_primitives(field_type):
"""Check if a field type is a List of primitive types."""
origin = get_origin(field_type)
-if origin is List or origin is list:
+if origin is list or origin is list:
args = get_args(field_type)
if len(args) == 1 and args[0] in (int, float, str, bool):
return True

@@ -53,7 +52,7 @@ def get_non_none_type(field_type):
return next(arg for arg in get_args(field_type) if arg is not type(None))


-def manually_validate_field(model: Type[BaseModel], field_name: str, value: Any):
+def manually_validate_field(model: type[BaseModel], field_name: str, value: Any):
validators = model.__pydantic_decorators__.field_validators
for _name, validator in validators.items():
if field_name in validator.info.fields:

@@ -126,7 +125,7 @@ def prompt_for_discriminated_union(
#
# doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
# unit tests for coverage.
-def prompt_for_config(config_type: type[BaseModel], existing_config: Optional[BaseModel] = None) -> BaseModel:
+def prompt_for_config(config_type: type[BaseModel], existing_config: BaseModel | None = None) -> BaseModel:
"""
Recursively prompt the user for configuration values based on a Pydantic BaseModel.
@@ -7,7 +7,6 @@
import logging
import os
from logging.config import dictConfig
-from typing import Dict, Optional

from rich.console import Console
from rich.errors import MarkupError

@@ -33,7 +32,7 @@ CATEGORIES = [
]

# Initialize category levels with default level
-_category_levels: Dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}
+_category_levels: dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}


def config_to_category_levels(category: str, level: str):

@@ -49,7 +48,7 @@ def config_to_category_levels(category: str, level: str):
Dict[str, int]: A dictionary mapping categories to their log levels.
"""

-category_levels: Dict[str, int] = {}
+category_levels: dict[str, int] = {}
level_value = logging._nameToLevel.get(str(level).upper())
if level_value is None:
logging.warning(f"Unknown log level '{level}' for category '{category}'. Falling back to default 'INFO'.")

@@ -69,7 +68,7 @@ def config_to_category_levels(category: str, level: str):
return category_levels


-def parse_yaml_config(yaml_config: LoggingConfig) -> Dict[str, int]:
+def parse_yaml_config(yaml_config: LoggingConfig) -> dict[str, int]:
"""
Helper function to parse a yaml logging configuration found in the run.yaml

@@ -86,7 +85,7 @@ def parse_yaml_config(yaml_config: LoggingConfig) -> Dict[str, int]:
return category_levels


-def parse_environment_config(env_config: str) -> Dict[str, int]:
+def parse_environment_config(env_config: str) -> dict[str, int]:
"""
Parse the LLAMA_STACK_LOGGING environment variable and return a dictionary of category log levels.

@@ -131,7 +130,7 @@ class CustomRichHandler(RichHandler):
self.markup = original_markup


-def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None:
+def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None:
"""
Configure logging based on the provided category log levels and an optional log file.

@@ -211,7 +210,7 @@ def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None


def get_logger(
-name: str, category: str = "uncategorized", config: Optional[LoggingConfig] | None = None
+name: str, category: str = "uncategorized", config: LoggingConfig | None | None = None
) -> logging.LoggerAdapter:
"""
Returns a logger with the specified name and category.
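The logging module keeps its behaviour and only trades `Dict[str, int]` for the builtin `dict[str, int]` in annotations. A short sketch of that idiom, assuming a semicolon-separated `category=level` spec; `parse_level_spec` is a hypothetical helper, not the repository's parser:

```python
import logging

DEFAULT_LOG_LEVEL = logging.INFO
CATEGORIES = ["core", "server", "router"]

# PEP 585: builtin generics are usable directly in annotations on Python 3.9+.
_category_levels: dict[str, int] = {category: DEFAULT_LOG_LEVEL for category in CATEGORIES}


def parse_level_spec(spec: str) -> dict[str, int]:
    """Turn "core=debug;server=warning" into {"core": 10, "server": 30} (hypothetical format)."""
    levels: dict[str, int] = {}
    for pair in filter(None, spec.split(";")):
        category, _, level = pair.partition("=")
        value = logging.getLevelName(level.strip().upper())
        if isinstance(value, int):  # unknown names come back as strings, so skip them
            levels[category.strip()] = value
    return levels


print(parse_level_spec("core=debug;server=warning"))
```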
@@ -7,14 +7,14 @@
import concurrent.futures
import re
from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any

import numpy as np
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size


-def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
+def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> list[int]:
"""Map a new MP rank to a list of old MP ranks given a change in MP size."""
if new_mp_size % old_mp_size == 0:
# Read old MP shard and split it into smaller ones

@@ -31,12 +31,12 @@ def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[in


def maybe_reshard_state_dict(
-ckpt_paths: List[Path],
+ckpt_paths: list[Path],
n_kv_heads: int,
-moe_num_experts: Optional[int] = None,
+moe_num_experts: int | None = None,
-map_location: Union[str, torch.device] = "cpu",
+map_location: str | torch.device = "cpu",
mmap: bool = True,
-) -> Dict[str, torch.Tensor]:
+) -> dict[str, torch.Tensor]:
if str(map_location) == "cpu":
torch.set_default_tensor_type(torch.BFloat16Tensor)
else:

@@ -97,18 +97,18 @@ _MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}


def reshard_mp(
-state_dicts: List[Dict[str, torch.Tensor]],
+state_dicts: list[dict[str, torch.Tensor]],
size: int,
rank: int,
repeat_qk_qv: int = 1,
-) -> Dict[str, torch.Tensor]:
+) -> dict[str, torch.Tensor]:
"""
Reshard a list of state dicts into a single state dict given a change in MP size.
If the list has more than one state dict, we concatenate the values of the same
key across all state dicts. Otherwise, we just slice it for the current MP rank.
"""

-def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
+def concat_or_chunk(tensors: list[torch.Tensor], dim: int) -> torch.Tensor:
if len(tensors) > 1:
return torch.cat(tensors, dim=dim)
return tensors[0].chunk(size, dim=dim)[rank].clone()

@@ -144,7 +144,7 @@ def reshard_mp(
column_regex = re.compile("|".join(column_keys))
row_regex = re.compile("|".join(row_keys))

-output: Dict[str, torch.Tensor] = {}
+output: dict[str, torch.Tensor] = {}
with concurrent.futures.ThreadPoolExecutor() as executor:
# Note: only processes keys in the first state dict.
# Assumes keys are the same across all state dicts.

@@ -154,7 +154,7 @@ def reshard_mp(
return output


-def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
+def convert_moe_weights(state_dict: dict[str, Any], num_experts: int) -> dict[str, Any]:
routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
routed_regex = re.compile("|".join(routed_keys))
keys = list(state_dict.keys())
@@ -7,10 +7,9 @@
import base64
from enum import Enum
from io import BytesIO
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Annotated, Any, Literal

from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
-from typing_extensions import Annotated

# The goal is that these set of types are relevant for all Llama models.
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to

@@ -31,21 +30,21 @@ class BuiltinTool(Enum):
code_interpreter = "code_interpreter"


-Primitive = Union[str, int, float, bool, None]
+Primitive = str | int | float | bool | None
-RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
+RecursiveType = Primitive | list[Primitive] | dict[str, Primitive]


class ToolCall(BaseModel):
call_id: str
-tool_name: Union[BuiltinTool, str]
+tool_name: BuiltinTool | str
# Plan is to deprecate the Dict in favor of a JSON string
# that is parsed on the client side instead of trying to manage
# the recursive type here.
# Making this a union so that client side can start prepping for this change.
# Eventually, we will remove both the Dict and arguments_json field,
# and arguments will just be a str
-arguments: Union[str, Dict[str, RecursiveType]]
+arguments: str | dict[str, RecursiveType]
-arguments_json: Optional[str] = None
+arguments_json: str | None = None

@field_validator("tool_name", mode="before")
@classmethod

@@ -91,15 +90,15 @@ class StopReason(Enum):

class ToolParamDefinition(BaseModel):
param_type: str
-description: Optional[str] = None
+description: str | None = None
-required: Optional[bool] = True
+required: bool | None = True
-default: Optional[Any] = None
+default: Any | None = None


class ToolDefinition(BaseModel):
-tool_name: Union[BuiltinTool, str]
+tool_name: BuiltinTool | str
-description: Optional[str] = None
+description: str | None = None
-parameters: Optional[Dict[str, ToolParamDefinition]] = None
+parameters: dict[str, ToolParamDefinition] | None = None

@field_validator("tool_name", mode="before")
@classmethod

@@ -119,7 +118,7 @@ class RawMediaItem(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

@field_serializer("data")
-def serialize_data(self, data: Optional[bytes], _info):
+def serialize_data(self, data: bytes | None, _info):
if data is None:
return None
return base64.b64encode(data).decode("utf-8")

@@ -137,9 +136,9 @@ class RawTextItem(BaseModel):
text: str


-RawContentItem = Annotated[Union[RawTextItem, RawMediaItem], Field(discriminator="type")]
+RawContentItem = Annotated[RawTextItem | RawMediaItem, Field(discriminator="type")]

-RawContent = str | RawContentItem | List[RawContentItem]
+RawContent = str | RawContentItem | list[RawContentItem]


class RawMessage(BaseModel):

@@ -147,17 +146,17 @@ class RawMessage(BaseModel):
content: RawContent

# This is for RAG but likely should be absorbed into content
-context: Optional[RawContent] = None
+context: RawContent | None = None

# These are for the output message coming from the assistant
-stop_reason: Optional[StopReason] = None
+stop_reason: StopReason | None = None
-tool_calls: List[ToolCall] = Field(default_factory=list)
+tool_calls: list[ToolCall] = Field(default_factory=list)


class GenerationResult(BaseModel):
token: int
text: str
-logprobs: Optional[List[float]] = None
+logprobs: list[float] | None = None

source: Literal["input"] | Literal["output"]
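Pydantic v2 accepts PEP 604 unions and builtin generics in field annotations, so rewrites such as `Optional[str]` to `str | None` and `Annotated[Union[A, B], Field(...)]` to `Annotated[A | B, Field(...)]` are drop-in. A minimal sketch under that assumption; the model names below are illustrative, not the repository's types:

```python
from typing import Annotated, Literal

from pydantic import BaseModel, Field


class TextItem(BaseModel):
    type: Literal["text"] = "text"
    text: str


class MediaItem(BaseModel):
    type: Literal["media"] = "media"
    url: str


# Discriminated union spelled with the PEP 604 pipe instead of typing.Union.
ContentItem = Annotated[TextItem | MediaItem, Field(discriminator="type")]


class Message(BaseModel):
    role: str
    content: str | list[ContentItem]
    stop_reason: str | None = None


print(Message(role="user", content=[{"type": "text", "text": "hi"}]).model_dump())
```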
@@ -6,7 +6,6 @@

from dataclasses import dataclass
from enum import Enum
-from typing import Optional


class QuantizationScheme(Enum):

@@ -15,8 +14,8 @@ class QuantizationScheme(Enum):

@dataclass
class QuantizationArgs:
-scheme: Optional[QuantizationScheme] = None
+scheme: QuantizationScheme | None = None
-group_size: Optional[int] = None
+group_size: int | None = None
spinquant: bool = False

def __init__(self, **kwargs):

@@ -39,10 +38,10 @@ class ModelArgs:
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
-n_kv_heads: Optional[int] = None
+n_kv_heads: int | None = None
vocab_size: int = -1
multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
-ffn_dim_multiplier: Optional[float] = None
+ffn_dim_multiplier: float | None = None
norm_eps: float = 1e-5
rope_theta: float = 500000
use_scaled_rope: bool = False

@@ -55,8 +54,8 @@ class ModelArgs:
vision_max_num_chunks: int = 4
vision_num_cross_attention_layers: int = -1

-quantization_args: Optional[QuantizationArgs] = None
+quantization_args: QuantizationArgs | None = None
-lora_args: Optional[LoRAArgs] = None
+lora_args: LoRAArgs | None = None

def __init__(self, **kwargs):
for k, v in kwargs.items():
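The same `X | None` default spelling works for plain dataclasses on Python 3.10+, which is all the `QuantizationArgs` and `ModelArgs` hunks above change. A tiny sketch with illustrative field names:

```python
from dataclasses import dataclass


@dataclass
class QuantizationSettings:
    # PEP 604 syntax in a dataclass annotation; None remains the default value.
    group_size: int | None = None
    spinquant: bool = False


print(QuantizationSettings(group_size=64))
```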
@@ -8,7 +8,6 @@ import io
import json
import uuid
from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple

from PIL import Image as PIL_Image

@@ -29,14 +28,14 @@ from .tool_utils import ToolUtils

@dataclass
class VisionInput:
-mask: List[List[int]]
+mask: list[list[int]]
-images: List[PIL_Image.Image]
+images: list[PIL_Image.Image]


@dataclass
class LLMInput:
-tokens: List[int]
+tokens: list[int]
-vision: Optional[VisionInput] = None
+vision: VisionInput | None = None


def role_str(role: Role) -> str:

@@ -50,7 +49,7 @@ def role_str(role: Role) -> str:


class ChatFormat:
-possible_headers: Dict[Role, str]
+possible_headers: dict[Role, str]

def __init__(self, tokenizer: Tokenizer):
self.tokenizer = tokenizer

@@ -58,7 +57,7 @@ class ChatFormat:
self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role}
self.vision_token = self.tokenizer.special_tokens["<|image|>"]

-def _encode_header(self, role: str) -> List[int]:
+def _encode_header(self, role: str) -> list[int]:
tokens = []
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))

@@ -70,7 +69,7 @@ class ChatFormat:
tokens, images = self._encode_content(content, bos=True)
return self._model_input_from_tokens_images(tokens, images)

-def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[PIL_Image.Image]]:
+def _encode_content(self, content: RawContent, bos: bool = False) -> tuple[list[int], list[PIL_Image.Image]]:
tokens = []
images = []

@@ -107,7 +106,7 @@ class ChatFormat:

def encode_message(
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
-) -> Tuple[List[int], List[PIL_Image.Image]]:
+) -> tuple[list[int], list[PIL_Image.Image]]:
tokens = self._encode_header(message.role)
images = []

@@ -145,8 +144,8 @@ class ChatFormat:

def encode_dialog_prompt(
self,
-messages: List[RawMessage],
+messages: list[RawMessage],
-tool_prompt_format: Optional[ToolPromptFormat] = None,
+tool_prompt_format: ToolPromptFormat | None = None,
) -> LLMInput:
tool_prompt_format = tool_prompt_format or ToolPromptFormat.json
tokens = []

@@ -163,7 +162,7 @@ class ChatFormat:
return self._model_input_from_tokens_images(tokens, images)

# TODO(this should be generic, not only for assistant messages)
-def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
+def decode_assistant_message(self, tokens: list[int], stop_reason: StopReason) -> RawMessage:
content = self.tokenizer.decode(tokens)

return self.decode_assistant_message_from_content(content, stop_reason)

@@ -234,7 +233,7 @@ class ChatFormat:
tool_calls=tool_calls,
)

-def _model_input_from_tokens_images(self, tokens: List[int], images: List[PIL_Image.Image]) -> LLMInput:
+def _model_input_from_tokens_images(self, tokens: list[int], images: list[PIL_Image.Image]) -> LLMInput:
vision_input = None
if len(images) > 0:
vision_input = VisionInput(

@@ -249,9 +248,9 @@ class ChatFormat:


def create_vision_mask(
-tokens: List[int],
+tokens: list[int],
vision_token: int,
-) -> List[List[int]]:
+) -> list[list[int]]:
vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token]
if len(vision_token_locations) == 0:
return []
@@ -15,8 +15,8 @@ import json
import os
import sys
import time
+from collections.abc import Callable, Generator
from pathlib import Path
-from typing import Callable, Generator, List, Optional

import torch
import torch.nn.functional as F

@@ -41,8 +41,8 @@ class Llama3:
ckpt_dir: str,
max_seq_len: int,
max_batch_size: int,
-world_size: Optional[int] = None,
+world_size: int | None = None,
-quantization_mode: Optional[QuantizationMode] = None,
+quantization_mode: QuantizationMode | None = None,
seed: int = 1,
device: str = "cuda",
):

@@ -82,7 +82,7 @@ class Llama3:
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
-with open(Path(ckpt_dir) / "params.json", "r") as f:
+with open(Path(ckpt_dir) / "params.json") as f:
params = json.loads(f.read())

model_args: ModelArgs = ModelArgs(

@@ -154,15 +154,15 @@ class Llama3:
@torch.inference_mode()
def generate(
self,
-llm_inputs: List[LLMInput],
+llm_inputs: list[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
-max_gen_len: Optional[int] = None,
+max_gen_len: int | None = None,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
-logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
-) -> Generator[List[GenerationResult], None, None]:
+) -> Generator[list[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
params = self.model.params

@@ -302,13 +302,13 @@ class Llama3:

def completion(
self,
-contents: List[RawContent],
+contents: list[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
-max_gen_len: Optional[int] = None,
+max_gen_len: int | None = None,
logprobs: bool = False,
echo: bool = False,
-) -> Generator[List[GenerationResult], None, None]:
+) -> Generator[list[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
model_inputs=model_inputs,

@@ -324,14 +324,14 @@ class Llama3:

def chat_completion(
self,
-messages_batch: List[List[RawMessage]],
+messages_batch: list[list[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
-max_gen_len: Optional[int] = None,
+max_gen_len: int | None = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
echo: bool = False,
-) -> Generator[List[GenerationResult], None, None]:
+) -> Generator[list[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
model_inputs=model_inputs,
@@ -12,7 +12,6 @@
# the top-level of this source tree.

from pathlib import Path
-from typing import List, Optional

from termcolor import colored

@@ -131,7 +130,7 @@ class LLama31Interface:
self.formatter = ChatFormat(self.tokenizer)
self.tool_prompt_format = tool_prompt_format

-def get_tokens(self, messages: List[RawMessage]) -> List[int]:
+def get_tokens(self, messages: list[RawMessage]) -> list[int]:
model_input = self.formatter.encode_dialog_prompt(
messages,
self.tool_prompt_format,

@@ -149,10 +148,10 @@ class LLama31Interface:

def system_messages(
self,
-builtin_tools: List[BuiltinTool],
+builtin_tools: list[BuiltinTool],
-custom_tools: List[ToolDefinition],
+custom_tools: list[ToolDefinition],
-instruction: Optional[str] = None,
+instruction: str | None = None,
-) -> List[RawMessage]:
+) -> list[RawMessage]:
messages = []

default_gen = SystemDefaultGenerator()

@@ -194,8 +193,8 @@ class LLama31Interface:
self,
content: str,
stop_reason: StopReason,
-tool_call: Optional[ToolCall] = None,
+tool_call: ToolCall | None = None,
-) -> List[RawMessage]:
+) -> list[RawMessage]:
tool_calls = []
if tool_call:
tool_calls.append(tool_call)

@@ -208,7 +207,7 @@ class LLama31Interface:
)
]

-def user_message(self, content: str) -> List[RawMessage]:
+def user_message(self, content: str) -> list[RawMessage]:
return [RawMessage(role="user", content=content)]

def display_message_as_tokens(self, message: RawMessage) -> None:

@@ -228,7 +227,7 @@ class LLama31Interface:
print("\n", end="")


-def list_jinja_templates() -> List[Template]:
+def list_jinja_templates() -> list[Template]:
return TEMPLATES
@@ -5,7 +5,6 @@
# the root directory of this source tree.

import math
-from typing import Optional, Tuple

import fairscale.nn.model_parallel.initialize as fs_init
import torch

@@ -80,7 +79,7 @@ def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)

@@ -162,7 +161,7 @@ class Attention(nn.Module):
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
-mask: Optional[torch.Tensor],
+mask: torch.Tensor | None,
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

@@ -204,7 +203,7 @@ class FeedForward(nn.Module):
dim: int,
hidden_dim: int,
multiple_of: int,
-ffn_dim_multiplier: Optional[float],
+ffn_dim_multiplier: float | None,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)

@@ -243,7 +242,7 @@ class TransformerBlock(nn.Module):
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
-mask: Optional[torch.Tensor],
+mask: torch.Tensor | None,
):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h))
@@ -14,7 +14,7 @@
import math
from collections import defaultdict
from logging import getLogger
-from typing import Any, Optional, Set, Tuple
+from typing import Any

import torch
import torchvision.transforms as tv

@@ -26,7 +26,7 @@ IMAGE_RES = 224
logger = getLogger()


-class VariableSizeImageTransform(object):
+class VariableSizeImageTransform:
"""
This class accepts images of any size and dynamically resize, pads and chunks it
based on the image aspect ratio and the number of image chunks we allow.

@@ -75,7 +75,7 @@ class VariableSizeImageTransform(object):
self.resample = tv.InterpolationMode.BILINEAR

@staticmethod
-def get_factors(n: int) -> Set[int]:
+def get_factors(n: int) -> set[int]:
"""
Calculate all factors of a given number, i.e. a dividor that leaves
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.

@@ -145,9 +145,9 @@ class VariableSizeImageTransform(object):

@staticmethod
def get_max_res_without_distortion(
-image_size: Tuple[int, int],
+image_size: tuple[int, int],
-target_size: Tuple[int, int],
+target_size: tuple[int, int],
-) -> Tuple[int, int]:
+) -> tuple[int, int]:
"""
Determines the maximum resolution to which an image can be resized to without distorting its
aspect ratio, based on the target resolution.

@@ -198,8 +198,8 @@ class VariableSizeImageTransform(object):
def resize_without_distortion(
self,
image: torch.Tensor,
-target_size: Tuple[int, int],
+target_size: tuple[int, int],
-max_upscaling_size: Optional[int],
+max_upscaling_size: int | None,
) -> torch.Tensor:
"""
Used to resize an image to target_resolution, without distortion.

@@ -261,10 +261,10 @@ class VariableSizeImageTransform(object):

def get_best_fit(
self,
-image_size: Tuple[int, int],
+image_size: tuple[int, int],
possible_resolutions: torch.Tensor,
resize_to_max_canvas: bool = False,
-) -> Tuple[int, int]:
+) -> tuple[int, int]:
"""
Determines the best canvas possible from a list of possible resolutions to, without distortion,
resize an image to.

@@ -364,7 +364,7 @@ class VariableSizeImageTransform(object):
max_num_chunks: int,
normalize_img: bool = True,
resize_to_max_canvas: bool = False,
-) -> Tuple[Any, Any]:
+) -> tuple[Any, Any]:
"""
Args:
image (PIL.Image): Image to be resized.
@@ -6,8 +6,9 @@

import logging
import math
+from collections.abc import Callable
from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any

import fairscale.nn.model_parallel.initialize as fs_init
import torch

@@ -104,9 +105,9 @@ class ColumnParallelConv2dPatch(torch.nn.Module):
self,
in_channels: int,
out_channels: int,
-kernel_size: Union[int, Tuple[int, int]],
+kernel_size: int | tuple[int, int],
-stride: Union[int, Tuple[int, int]],
+stride: int | tuple[int, int],
-bias: Optional[bool] = False,
+bias: bool | None = False,
) -> None:
super().__init__()
if isinstance(kernel_size, int):

@@ -390,13 +391,13 @@ class VisionEncoder(nn.Module):

def load_hook(
self,
-state_dict: Dict[str, Any],
+state_dict: dict[str, Any],
prefix: str,
-local_metadata: Dict[str, Any],
+local_metadata: dict[str, Any],
strict: bool = True,
-missing_keys: List[str] = None,
+missing_keys: list[str] = None,
-unexpected_keys: List[str] = None,
+unexpected_keys: list[str] = None,
-error_msgs: List[str] = None,
+error_msgs: list[str] = None,
return_state_dict: bool = False,
) -> None:
orig_pos_embed = state_dict.get(prefix + "positional_embedding")

@@ -641,7 +642,7 @@ class FeedForward(nn.Module):
dim: int,
hidden_dim: int,
multiple_of: int,
-ffn_dim_multiplier: Optional[float],
+ffn_dim_multiplier: float | None,
):
"""
Initialize the FeedForward module.

@@ -983,7 +984,7 @@ class CrossAttentionTransformerBlock(torch.nn.Module):
self,
x: torch.Tensor,
xattn_mask: torch.Tensor,
-full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor],
+full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor],
xattn_cache: torch.Tensor,
) -> torch.Tensor:
_attn_out = self.attention(

@@ -1144,7 +1145,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
def _init_fusion_schedule(
self,
num_layers: int,
-) -> List[int]:
+) -> list[int]:
llama_layers = list(range(self.n_llama_layers))

# uniformly spread the layers

@@ -1231,7 +1232,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
text_dtype,
vision_tokens,
cross_attention_masks,
-) -> Tuple[Tensor, Tensor]:
+) -> tuple[Tensor, Tensor]:
assert vision_tokens is not None, "Vision tokens must be provided"
vision_seqlen = vision_tokens.shape[3]
assert vision_tokens.shape[1] == cross_attention_masks.shape[2], (

@@ -1280,11 +1281,11 @@ class CrossAttentionTransformer(torch.nn.Module):

def compute_vision_tokens_masks(
self,
-batch_images: List[List[PIL_Image.Image]],
+batch_images: list[list[PIL_Image.Image]],
-batch_masks: List[List[List[int]]],
+batch_masks: list[list[list[int]]],
total_len: int,
device: torch.device,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
skip_vision_encoder = False

assert len(batch_images) == len(batch_masks), "Images and masks must have the same length"

@@ -1371,11 +1372,11 @@ class CrossAttentionTransformer(torch.nn.Module):


def _stack_images(
-images: List[List[PIL_Image.Image]],
+images: list[list[PIL_Image.Image]],
max_num_chunks: int,
image_res: int,
max_num_images: int,
-) -> Tuple[torch.Tensor, List[int]]:
+) -> tuple[torch.Tensor, list[int]]:
"""
Takes a list of list of images and stacks them into a tensor.
This function is needed since images can be of completely

@@ -1400,8 +1401,8 @@ def _stack_images(


def _pad_masks(
-all_masks: List[List[List[int]]],
+all_masks: list[list[list[int]]],
-all_num_chunks: List[List[int]],
+all_num_chunks: list[list[int]],
total_len: int,
max_num_chunks: int,
) -> torch.Tensor:
@@ -12,7 +12,7 @@
# the top-level of this source tree.

from dataclasses import dataclass
-from typing import Any, Dict, List
+from typing import Any

from jinja2 import Template

@@ -20,7 +20,7 @@ from jinja2 import Template
@dataclass
class PromptTemplate:
template: str
-data: Dict[str, Any]
+data: dict[str, Any]

def render(self):
template = Template(self.template)

@@ -35,5 +35,5 @@ class PromptTemplateGeneratorBase:
def gen(self, *args, **kwargs) -> PromptTemplate:
raise NotImplementedError()

-def data_examples(self) -> List[Any]:
+def data_examples(self) -> list[Any]:
raise NotImplementedError()
@@ -13,7 +13,7 @@

import textwrap
from datetime import datetime
-from typing import Any, List, Optional
+from typing import Any

from llama_stack.apis.inference import (
BuiltinTool,
@@ -39,12 +39,12 @@ class SystemDefaultGenerator(PromptTemplateGeneratorBase):
},
)

-def data_examples(self) -> List[Any]:
+def data_examples(self) -> list[Any]:
return [None]


class BuiltinToolGenerator(PromptTemplateGeneratorBase):
-def _tool_breakdown(self, tools: List[ToolDefinition]):
+def _tool_breakdown(self, tools: list[ToolDefinition]):
builtin_tools, custom_tools = [], []
for dfn in tools:
if isinstance(dfn.tool_name, BuiltinTool):
@@ -54,7 +54,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):

return builtin_tools, custom_tools

-def gen(self, tools: List[ToolDefinition]) -> PromptTemplate:
+def gen(self, tools: list[ToolDefinition]) -> PromptTemplate:
builtin_tools, custom_tools = self._tool_breakdown(tools)
template_str = textwrap.dedent(
"""
@@ -75,7 +75,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
},
)

-def data_examples(self) -> List[List[ToolDefinition]]:
+def data_examples(self) -> list[list[ToolDefinition]]:
return [
# builtin tools
[
@@ -91,7 +91,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):


class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
-def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
+def gen(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
Answer the user's question by making use of the following functions if needed.
@@ -137,7 +137,7 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
{"custom_tools": [t.model_dump() for t in custom_tools]},
)

-def data_examples(self) -> List[List[ToolDefinition]]:
+def data_examples(self) -> list[list[ToolDefinition]]:
return [
[
ToolDefinition(
@@ -161,7 +161,7 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):


class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
-def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
+def gen(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
You have access to the following functions:
@@ -199,7 +199,7 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
{"custom_tools": [t.model_dump() for t in custom_tools]},
)

-def data_examples(self) -> List[List[ToolDefinition]]:
+def data_examples(self) -> list[list[ToolDefinition]]:
return [
[
ToolDefinition(
@@ -238,14 +238,14 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
""".strip("\n")
)

-def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
+def gen(self, custom_tools: list[ToolDefinition], system_prompt: str | None = None) -> PromptTemplate:
system_prompt = system_prompt or self.DEFAULT_PROMPT
return PromptTemplate(
system_prompt,
{"function_description": self._gen_function_description(custom_tools)},
)

-def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
+def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
Here is a list of functions in JSON format that you can invoke.
@@ -291,7 +291,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
{"tools": [t.model_dump() for t in custom_tools]},
).render()

-def data_examples(self) -> List[List[ToolDefinition]]:
+def data_examples(self) -> list[list[ToolDefinition]]:
return [
[
ToolDefinition(

@@ -12,7 +12,6 @@
# the top-level of this source tree.

import textwrap
-from typing import Optional

from .base import PromptTemplate, PromptTemplateGeneratorBase

@@ -21,8 +20,8 @@ class ToolResponseGenerator(PromptTemplateGeneratorBase):
def gen(
self,
status: str,
-stdout: Optional[str] = None,
-stderr: Optional[str] = None,
+stdout: str | None = None,
+stderr: str | None = None,
):
assert status in [
"success",

@@ -6,7 +6,7 @@

# type: ignore
import os
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, cast

import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
@@ -37,9 +37,9 @@ def swiglu_wrapper(
def convert_to_quantized_model(
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
-quantization_mode: Optional[str] = None,
-fp8_activation_scale_ub: Optional[float] = 1200.0,
-device: Optional[torch.device] = None,
+quantization_mode: str | None = None,
+fp8_activation_scale_ub: float | None = 1200.0,
+device: torch.device | None = None,
) -> Transformer | CrossAttentionTransformer:
if quantization_mode == QuantizationMode.fp8_mixed:
return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
@@ -52,8 +52,8 @@ def convert_to_quantized_model(
def convert_to_fp8_quantized_model(
model: Transformer,
checkpoint_dir: str,
-fp8_activation_scale_ub: Optional[float] = 1200.0,
-device: Optional[torch.device] = None,
+fp8_activation_scale_ub: float | None = 1200.0,
+device: torch.device | None = None,
) -> Transformer:
# Move weights to GPU with quantization
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
@@ -122,8 +122,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
# LoRA parameters
-lora_rank: Optional[int] = None,
-lora_scale: Optional[float] = None,
+lora_rank: int | None = None,
+lora_scale: float | None = None,
) -> None:
super().__init__(
in_features,
@@ -134,8 +134,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
precision=precision,
scales_precision=scales_precision,
)
-self.lora_scale: Optional[float] = None
-self.adaptor: Optional[nn.Sequential] = None
+self.lora_scale: float | None = None
+self.adaptor: nn.Sequential | None = None
if lora_rank is not None:
assert lora_scale is not None, "Please specify lora scale for LoRA."
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
@@ -147,13 +147,13 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):

def load_hook(
self,
-state_dict: Dict[str, Any],
+state_dict: dict[str, Any],
prefix: str,
-local_metadata: Dict[str, Any],
+local_metadata: dict[str, Any],
strict: bool,
-missing_keys: List[str],
-unexpected_keys: List[str],
-error_msgs: List[str],
+missing_keys: list[str],
+unexpected_keys: list[str],
+error_msgs: list[str],
) -> None:
"""A hook to load the quantized weights from the state dict."""
if prefix + "zeros" not in state_dict:
@@ -191,13 +191,13 @@ class Int8WeightEmbedding(torch.nn.Embedding):

def load_hook(
self,
-state_dict: Dict[str, Any],
+state_dict: dict[str, Any],
prefix: str,
-local_metadata: Dict[str, Any],
+local_metadata: dict[str, Any],
strict: bool,
-missing_keys: List[str],
-unexpected_keys: List[str],
-error_msgs: List[str],
+missing_keys: list[str],
+unexpected_keys: list[str],
+error_msgs: list[str],
) -> None:
"""A hook to load the quantized embedding weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
@@ -221,13 +221,13 @@ class Int8WeightLinear(torch.nn.Linear):

def load_hook(
self,
-state_dict: Dict[str, Any],
+state_dict: dict[str, Any],
prefix: str,
-local_metadata: Dict[str, Any],
+local_metadata: dict[str, Any],
strict: bool,
-missing_keys: List[str],
-unexpected_keys: List[str],
-error_msgs: List[str],
+missing_keys: list[str],
+unexpected_keys: list[str],
+error_msgs: list[str],
) -> None:
"""A hook to load the quantized linear weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
@@ -238,8 +238,8 @@ class Int8WeightLinear(torch.nn.Linear):
def _prepare_model_int4_weight_int8_dynamic_activation(
model: torch.nn.Module,
group_size: int,
-lora_rank: Optional[int],
-lora_scale: Optional[float],
+lora_rank: int | None,
+lora_scale: float | None,
):
"""Prepare the model for int4 weight and int8 dynamic activation quantization.

@@ -265,7 +265,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
)
del module
setattr(model, module_name, quantized_module)
-elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
+elif isinstance(module, ColumnParallelLinear | RowParallelLinear | nn.Linear):
quantized_module = Int8DynActInt4WeightLinearLoRA(
in_features=module.in_features,
out_features=module.out_features,
@@ -286,7 +286,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
def convert_to_int4_quantized_model(
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
-device: Optional[torch.device] = None,
+device: torch.device | None = None,
) -> Transformer | CrossAttentionTransformer:
"""Convert the model to int4 quantized model."""
model_args = model.params

@@ -5,18 +5,11 @@
# the root directory of this source tree.

import os
+from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger
from pathlib import Path
from typing import (
-AbstractSet,
-Collection,
-Dict,
-Iterator,
-List,
Literal,
-Optional,
-Sequence,
-Union,
cast,
)

@@ -44,7 +37,7 @@ class Tokenizer:
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
"""

-special_tokens: Dict[str, int]
+special_tokens: dict[str, int]

num_reserved_special_tokens = 256

@@ -116,9 +109,9 @@ class Tokenizer:
*,
bos: bool,
eos: bool,
-allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
-disallowed_special: Union[Literal["all"], Collection[str]] = (),
-) -> List[int]:
+allowed_special: Literal["all"] | Set[str] | None = None,
+disallowed_special: Literal["all"] | Collection[str] = (),
+) -> list[int]:
"""
Encodes a string into a list of token IDs.

@@ -151,7 +144,7 @@ class Tokenizer:
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
)
)
-t: List[int] = []
+t: list[int] = []
for substr in substrs:
t.extend(
self.model.encode(
@@ -177,7 +170,7 @@ class Tokenizer:
str: The decoded string.
"""
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
-return self.model.decode(cast(List[int], t))
+return self.model.decode(cast(list[int], t))

@staticmethod
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:

@@ -6,7 +6,6 @@

import json
import re
-from typing import Optional, Tuple

from llama_stack.log import get_logger

@@ -172,7 +171,7 @@ class ToolUtils:
return match is not None

@staticmethod
-def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
+def maybe_extract_builtin_tool_call(message_body: str) -> tuple[str, str] | None:
# Find the first match in the text
match = re.search(BUILTIN_TOOL_PATTERN, message_body)

@@ -185,7 +184,7 @@ class ToolUtils:
return None

@staticmethod
-def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
+def maybe_extract_custom_tool_call(message_body: str) -> tuple[str, str] | None:
# NOTE: Custom function too calls are still experimental
# Sometimes, response is of the form
# {"type": "function", "name": "function_name", "parameters": {...}
@@ -252,7 +251,7 @@ class ToolUtils:
def format_value(value: RecursiveType) -> str:
if isinstance(value, str):
return f'"{value}"'
-elif isinstance(value, (int, float, bool)) or value is None:
+elif isinstance(value, int | float | bool) or value is None:
return str(value)
elif isinstance(value, list):
return f"[{', '.join(format_value(v) for v in value)}]"

@@ -12,7 +12,6 @@
# the top-level of this source tree.

import textwrap
-from typing import List

from llama_stack.models.llama.datatypes import (
BuiltinTool,
@@ -73,7 +72,7 @@ def wolfram_alpha_response():
)


-def usecases() -> List[UseCase | str]:
+def usecases() -> list[UseCase | str]:
return [
textwrap.dedent(
"""

@@ -12,7 +12,6 @@
# the top-level of this source tree.

import textwrap
-from typing import List

from llama_stack.models.llama.datatypes import (
BuiltinTool,
@@ -74,7 +73,7 @@ def wolfram_alpha_response():
)


-def usecases() -> List[UseCase | str]:
+def usecases() -> list[UseCase | str]:
return [
textwrap.dedent(
"""

@@ -5,7 +5,6 @@
# the root directory of this source tree.

from enum import Enum
-from typing import Optional

from pydantic import BaseModel, model_validator

@@ -15,8 +14,8 @@ class QuantizationScheme(Enum):


class QuantizationArgs(BaseModel):
-scheme: Optional[QuantizationScheme] = None
-group_size: Optional[int] = None
+scheme: QuantizationScheme | None = None
+group_size: int | None = None
spinquant: bool = False


@@ -58,32 +57,32 @@ class ModelArgs(BaseModel):
dim: int = -1
n_layers: int = -1
n_heads: int = -1
-n_kv_heads: Optional[int] = None
-head_dim: Optional[int] = None
+n_kv_heads: int | None = None
+head_dim: int | None = None

vocab_size: int = -1
multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
-ffn_dim_multiplier: Optional[float] = None
-ffn_exp: Optional[float] = None
+ffn_dim_multiplier: float | None = None
+ffn_exp: float | None = None
norm_eps: float = 1e-5

-attention_chunk_size: Optional[int] = None
+attention_chunk_size: int | None = None
rope_theta: float = 500000
use_scaled_rope: bool = False
-rope_scaling_factor: Optional[float] = None
-rope_high_freq_factor: Optional[float] = None
+rope_scaling_factor: float | None = None
+rope_high_freq_factor: float | None = None

-nope_layer_interval: Optional[int] = None  # No position encoding in every n layers
+nope_layer_interval: int | None = None  # No position encoding in every n layers
use_qk_norm: bool = False
# Set to True to enable inference-time temperature tuning (useful for very long context)
attn_temperature_tuning: bool = False
floor_scale: float = 8192.0
attn_scale: float = 0.1

-vision_args: Optional[VisionArgs] = None
-moe_args: Optional[MoEArgs] = None
-quantization_args: Optional[QuantizationArgs] = None
-lora_args: Optional[LoRAArgs] = None
+vision_args: VisionArgs | None = None
+moe_args: MoEArgs | None = None
+quantization_args: QuantizationArgs | None = None
+lora_args: LoRAArgs | None = None

max_batch_size: int = 32
max_seq_len: int = 2048

@@ -8,7 +8,6 @@ import io
import json
import uuid
from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple

import torch
from PIL import Image as PIL_Image
@@ -46,10 +45,10 @@ def role_str(role: Role) -> str:
class TransformedImage:
image_tiles: torch.Tensor
# is the aspect ratio needed anywhere?
-aspect_ratio: Tuple[int, int]
+aspect_ratio: tuple[int, int]


-def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
+def convert_image_to_rgb(image: PIL_Image.Image, bg: tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
if image.mode == "RGBA":
image.load()  # for png.split()
new_img = PIL_Image.new("RGB", image.size, bg)
@@ -59,12 +58,12 @@ def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255


class ChatFormat:
-possible_headers: Dict[Role, str]
+possible_headers: dict[Role, str]

def __init__(
self,
tokenizer: Tokenizer,
-vision_args: Optional[VisionArgs] = None,
+vision_args: VisionArgs | None = None,
max_num_chunks: int = 16,
):
self.tokenizer = tokenizer
@@ -81,7 +80,7 @@ class ChatFormat:
vision_args.image_size.width, vision_args.image_size.height
)

-def _encode_header(self, role: str) -> List[int]:
+def _encode_header(self, role: str) -> list[int]:
tokens = []
tokens.append(self.tokenizer.special_tokens["<|header_start|>"])

@@ -98,7 +97,7 @@ class ChatFormat:
def _encode_image(
self,
transformed_image: TransformedImage,
-) -> List[int]:
+) -> list[int]:
assert self.vision_args is not None, "The model is not vision-enabled"

image_tensor = transformed_image.image_tiles
@@ -140,7 +139,7 @@ class ChatFormat:

return tokens

-def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[TransformedImage]]:
+def _encode_content(self, content: RawContent, bos: bool = False) -> tuple[list[int], list[TransformedImage]]:
tokens = []
tranformed_images = []

@@ -189,7 +188,7 @@ class ChatFormat:

def encode_message(
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
-) -> Tuple[List[int], List[TransformedImage]]:
+) -> tuple[list[int], list[TransformedImage]]:
tokens = self._encode_header(message.role)
images = []

@@ -223,7 +222,7 @@ class ChatFormat:

def encode_dialog_prompt(
self,
-messages: List[RawMessage],
+messages: list[RawMessage],
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> LLMInput:
tokens = []
@@ -240,7 +239,7 @@ class ChatFormat:
return self._model_input_from_tokens_images(tokens, images)

# TODO(this should be generic, not only for assistant messages)
-def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
+def decode_assistant_message(self, tokens: list[int], stop_reason: StopReason) -> RawMessage:
content = self.tokenizer.decode(tokens)

return self.decode_assistant_message_from_content(content, stop_reason)
@@ -312,7 +311,7 @@ class ChatFormat:
tool_calls=tool_calls,
)

-def _model_input_from_tokens_images(self, tokens: List[int], images: List[TransformedImage]) -> LLMInput:
+def _model_input_from_tokens_images(self, tokens: list[int], images: list[TransformedImage]) -> LLMInput:
return LLMInput(
tokens=tokens,
images=[x.image_tiles for x in images] if len(images) > 0 else None,

@@ -5,7 +5,6 @@
# the root directory of this source tree.

from dataclasses import dataclass
-from typing import List, Optional, Union

import torch

@@ -30,7 +29,7 @@ class LLMInput:
tokens: torch.Tensor

# images are already pre-processed (resized, tiled, etc.)
-images: Optional[List[torch.Tensor]] = None
+images: list[torch.Tensor] | None = None


@dataclass
@@ -45,8 +44,8 @@ class TransformerInput:
# tokens_position defines the position of the tokens in each batch,
# - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
# - when it is an int, the start position are the same for all batches
-tokens_position: Union[torch.Tensor, int]
-image_embedding: Optional[MaskedEmbedding] = None
+tokens_position: torch.Tensor | int
+image_embedding: MaskedEmbedding | None = None


@dataclass

@@ -11,7 +11,7 @@
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

-from typing import Any, Dict, List
+from typing import Any

from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
@@ -36,13 +36,13 @@ class FeedForward(nn.Module):

def load_hook(
self,
-state_dict: Dict[str, Any],
+state_dict: dict[str, Any],
prefix: str,
-local_metadata: Dict[str, Any],
+local_metadata: dict[str, Any],
strict: bool,
-missing_keys: List[str],
-unexpected_keys: List[str],
-error_msgs: List[str],
+missing_keys: list[str],
+unexpected_keys: list[str],
+error_msgs: list[str],
) -> None:
if prefix + "mlp.fc1_weight" in state_dict:
w1, w3 = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2, dim=0)

@@ -10,8 +10,8 @@ import json
import os
import sys
import time
+from collections.abc import Callable, Generator
from pathlib import Path
-from typing import Callable, Generator, List, Optional

import torch
import torch.nn.functional as F
@@ -38,8 +38,8 @@ class Llama4:
ckpt_dir: str,
max_seq_len: int,
max_batch_size: int,
-world_size: Optional[int] = None,
-quantization_mode: Optional[QuantizationMode] = None,
+world_size: int | None = None,
+quantization_mode: QuantizationMode | None = None,
seed: int = 1,
):
if not torch.distributed.is_initialized():
@@ -63,7 +63,7 @@ class Llama4:
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
-with open(Path(ckpt_dir) / "params.json", "r") as f:
+with open(Path(ckpt_dir) / "params.json") as f:
params = json.loads(f.read())

model_args: ModelArgs = ModelArgs(
@@ -117,15 +117,15 @@ class Llama4:
@torch.inference_mode()
def generate(
self,
-llm_inputs: List[LLMInput],
+llm_inputs: list[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
-max_gen_len: Optional[int] = None,
+max_gen_len: int | None = None,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
-logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-) -> Generator[List[GenerationResult], None, None]:
+logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+) -> Generator[list[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
max_gen_len = self.model.args.max_seq_len - 1

@@ -245,13 +245,13 @@ class Llama4:

def completion(
self,
-contents: List[RawContent],
+contents: list[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
-max_gen_len: Optional[int] = None,
+max_gen_len: int | None = None,
logprobs: bool = False,
echo: bool = False,
-) -> Generator[List[GenerationResult], None, None]:
+) -> Generator[list[GenerationResult], None, None]:
llm_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
llm_inputs=llm_inputs,
@@ -267,13 +267,13 @@ class Llama4:

def chat_completion(
self,
-messages_batch: List[List[RawMessage]],
+messages_batch: list[list[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
-max_gen_len: Optional[int] = None,
+max_gen_len: int | None = None,
logprobs: bool = False,
echo: bool = False,
-) -> Generator[List[GenerationResult], None, None]:
+) -> Generator[list[GenerationResult], None, None]:
llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
llm_inputs=llm_inputs,

@@ -5,7 +5,7 @@
# the root directory of this source tree.

import math
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any

import fairscale.nn.model_parallel.initialize as fs_init
import torch
@@ -89,7 +89,7 @@ def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
@@ -174,13 +174,13 @@ class Attention(nn.Module):

def load_hook(
self,
-state_dict: Dict[str, Any],
+state_dict: dict[str, Any],
prefix: str,
-local_metadata: Dict[str, Any],
+local_metadata: dict[str, Any],
strict: bool,
-missing_keys: List[str],
-unexpected_keys: List[str],
-error_msgs: List[str],
+missing_keys: list[str],
+unexpected_keys: list[str],
+error_msgs: list[str],
) -> None:
if prefix + "wqkv.weight" in state_dict:
wqkv = state_dict.pop(prefix + "wqkv.weight")
@@ -200,7 +200,7 @@ class Attention(nn.Module):
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
-mask: Optional[torch.Tensor] = None,
+mask: torch.Tensor | None = None,
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
@@ -288,13 +288,13 @@ class TransformerBlock(nn.Module):

def load_hook(
self,
-state_dict: Dict[str, Any],
+state_dict: dict[str, Any],
prefix: str,
-local_metadata: Dict[str, Any],
+local_metadata: dict[str, Any],
strict: bool,
-missing_keys: List[str],
-unexpected_keys: List[str],
-error_msgs: List[str],
+missing_keys: list[str],
+unexpected_keys: list[str],
+error_msgs: list[str],
) -> None:
if prefix + "attention.wqkv.layer_norm_weight" in state_dict:
state_dict[prefix + "attention_norm.weight"] = state_dict.pop(prefix + "attention.wqkv.layer_norm_weight")
@@ -318,8 +318,8 @@ class TransformerBlock(nn.Module):
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
-global_attn_mask: Optional[torch.Tensor],
-local_attn_mask: Optional[torch.Tensor],
+global_attn_mask: torch.Tensor | None,
+local_attn_mask: torch.Tensor | None,
):
# The iRoPE architecture uses global attention mask for NoPE layers or
# if chunked local attention is not used
@@ -374,13 +374,13 @@ class Transformer(nn.Module):

def load_hook(
self,
-state_dict: Dict[str, Any],
+state_dict: dict[str, Any],
prefix: str,
-local_metadata: Dict[str, Any],
+local_metadata: dict[str, Any],
strict: bool,
-missing_keys: List[str],
-unexpected_keys: List[str],
-error_msgs: List[str],
+missing_keys: list[str],
+unexpected_keys: list[str],
+error_msgs: list[str],
) -> None:
if prefix + "rope.freqs" in state_dict:
state_dict.pop(prefix + "rope.freqs")

@@ -6,7 +6,7 @@

# ruff: noqa: N806
# pyre-strict
-from typing import Any, Dict, List
+from typing import Any

import fairscale.nn.model_parallel.initialize as fs_init
import torch
@@ -63,13 +63,13 @@ class Experts(nn.Module):

def load_hook(
self,
-state_dict: Dict[str, Any],
+state_dict: dict[str, Any],
prefix: str,
-local_metadata: Dict[str, Any],
+local_metadata: dict[str, Any],
strict: bool,
-missing_keys: List[str],
-unexpected_keys: List[str],
-error_msgs: List[str],
+missing_keys: list[str],
+unexpected_keys: list[str],
+error_msgs: list[str],
) -> None:
self.prefix = prefix
if prefix + "moe_w_in_eD_F" in state_dict:
@@ -158,13 +158,13 @@ class MoE(torch.nn.Module):

def load_hook(
self,
-state_dict: Dict[str, Any],
+state_dict: dict[str, Any],
prefix: str,
-local_metadata: Dict[str, Any],
+local_metadata: dict[str, Any],
strict: bool,
-missing_keys: List[str],
-unexpected_keys: List[str],
-error_msgs: List[str],
+missing_keys: list[str],
+unexpected_keys: list[str],
+error_msgs: list[str],
) -> None:
if prefix + "w_in_shared_FD.weight" in state_dict:
state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight")
@@ -210,5 +210,5 @@ class MoE(torch.nn.Module):


def divide_exact(numerator: int, denominator: int) -> int:
-assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
+assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
return numerator // denominator

@@ -13,7 +13,6 @@

import math
from collections import defaultdict
-from typing import Optional, Set, Tuple

import torch
import torchvision.transforms as tv
@@ -52,7 +51,7 @@ class ResizeNormalizeImageTransform:
return self.tv_transform(image)


-class VariableSizeImageTransform(object):
+class VariableSizeImageTransform:
"""
This class accepts images of any size and dynamically resize, pads and chunks it
based on the image aspect ratio and the number of image chunks we allow.
@@ -100,7 +99,7 @@ class VariableSizeImageTransform(object):
self.resample = tv.InterpolationMode.BILINEAR

@staticmethod
-def get_factors(n: int) -> Set[int]:
+def get_factors(n: int) -> set[int]:
"""
Calculate all factors of a given number, i.e. a dividor that leaves
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
@@ -170,9 +169,9 @@ class VariableSizeImageTransform(object):

@staticmethod
def get_max_res_without_distortion(
-image_size: Tuple[int, int],
-target_size: Tuple[int, int],
-) -> Tuple[int, int]:
+image_size: tuple[int, int],
+target_size: tuple[int, int],
+) -> tuple[int, int]:
"""
Determines the maximum resolution to which an image can be resized to without distorting its
aspect ratio, based on the target resolution.
@@ -223,8 +222,8 @@ class VariableSizeImageTransform(object):
def resize_without_distortion(
self,
image: torch.Tensor,
-target_size: Tuple[int, int],
-max_upscaling_size: Optional[int],
+target_size: tuple[int, int],
+max_upscaling_size: int | None,
) -> torch.Tensor:
"""
Used to resize an image to target_resolution, without distortion.
@@ -289,10 +288,10 @@ class VariableSizeImageTransform(object):

def get_best_fit(
self,
-image_size: Tuple[int, int],
+image_size: tuple[int, int],
possible_resolutions: torch.Tensor,
resize_to_max_canvas: bool = False,
-) -> Tuple[int, int]:
+) -> tuple[int, int]:
"""
Determines the best canvas possible from a list of possible resolutions to, without distortion,
resize an image to.
@@ -392,7 +391,7 @@ class VariableSizeImageTransform(object):
max_num_chunks: int,
normalize_img: bool = True,
resize_to_max_canvas: bool = False,
-) -> Tuple[torch.Tensor, Tuple[int, int]]:
+) -> tuple[torch.Tensor, tuple[int, int]]:
"""
Args:
image (PIL.Image): Image to be resized.

@@ -12,7 +12,6 @@
# the top-level of this source tree.

import textwrap
-from typing import List, Optional

from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
from llama_stack.models.llama.llama3.prompt_templates.base import (
@@ -67,14 +66,14 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
""".strip("\n")
)

-def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
+def gen(self, custom_tools: list[ToolDefinition], system_prompt: str | None = None) -> PromptTemplate:
system_prompt = system_prompt or self.DEFAULT_PROMPT
return PromptTemplate(
system_prompt,
{"function_description": self._gen_function_description(custom_tools)},
)

-def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
+def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
Here is a list of functions in JSON format that you can invoke.
@@ -120,7 +119,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
{"tools": [t.model_dump() for t in custom_tools]},
).render()

-def data_examples(self) -> List[List[ToolDefinition]]:
+def data_examples(self) -> list[list[ToolDefinition]]:
return [
[
ToolDefinition(

@@ -7,7 +7,6 @@
import textwrap
from io import BytesIO
from pathlib import Path
-from typing import List

from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator,
@@ -23,7 +22,7 @@ from ..prompt_format import (
THIS_DIR = Path(__file__).parent


-def usecases(base_model: bool = False) -> List[UseCase | str]:
+def usecases(base_model: bool = False) -> list[UseCase | str]:
with open(THIS_DIR.parent / "resources/small_dog.jpg", "rb") as f:
img_small_dog = f.read()
with open(THIS_DIR.parent / "resources/dog.jpg", "rb") as f:

@@ -6,7 +6,7 @@

import logging
import os
-from typing import Callable, Optional
+from collections.abc import Callable

import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
@@ -45,8 +45,8 @@ def experts_batched_swiglu_wrapper(
def convert_to_quantized_model(
model: Transformer,
checkpoint_dir: str,
-quantization_mode: Optional[str] = None,
-fp8_activation_scale_ub: Optional[float] = 1200.0,
+quantization_mode: str | None = None,
+fp8_activation_scale_ub: float | None = 1200.0,
use_rich_progress: bool = True,
) -> Transformer:
from ...quantize_impls import (
@@ -213,7 +213,7 @@ def logging_callbacks(
)
task_id = progress.add_task("[blue]Converting layers...", total=total_blocks, status="Starting")

-def update_status(message: Optional[str], completed: Optional[int] = None) -> None:
+def update_status(message: str | None, completed: int | None = None) -> None:
if use_rich_progress:
if message is not None:
progress.update(task_id, status=message)

@@ -5,18 +5,11 @@
# the root directory of this source tree.

import os
+from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger
from pathlib import Path
from typing import (
-AbstractSet,
-Collection,
-Dict,
-Iterator,
-List,
Literal,
-Optional,
-Sequence,
-Union,
cast,
)

@@ -114,7 +107,7 @@ class Tokenizer:
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
"""

-special_tokens: Dict[str, int]
+special_tokens: dict[str, int]

num_reserved_special_tokens = 2048

@@ -182,9 +175,9 @@ class Tokenizer:
*,
bos: bool,
eos: bool,
-allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
-disallowed_special: Union[Literal["all"], Collection[str]] = (),
-) -> List[int]:
+allowed_special: Literal["all"] | Set[str] | None = None,
+disallowed_special: Literal["all"] | Collection[str] = (),
+) -> list[int]:
"""
Encodes a string into a list of token IDs.

@@ -217,7 +210,7 @@ class Tokenizer:
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
)
)
-t: List[int] = []
+t: list[int] = []
for substr in substrs:
t.extend(
self.model.encode(
@@ -243,7 +236,7 @@ class Tokenizer:
str: The decoded string.
"""
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
-return self.model.decode(cast(List[int], t))
+return self.model.decode(cast(list[int], t))

@staticmethod
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:

@ -5,7 +5,8 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Any, Callable, Dict, List
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -136,13 +137,13 @@ class VisionEmbeddings(torch.nn.Module):
|
||||||
|
|
||||||
def load_hook(
|
def load_hook(
|
||||||
self,
|
self,
|
||||||
state_dict: Dict[str, Any],
|
state_dict: dict[str, Any],
|
||||||
prefix: str,
|
prefix: str,
|
||||||
local_metadata: Dict[str, Any],
|
local_metadata: dict[str, Any],
|
||||||
strict: bool = True,
|
strict: bool = True,
|
||||||
missing_keys: List[str] = None,
|
missing_keys: list[str] = None,
|
||||||
unexpected_keys: List[str] = None,
|
unexpected_keys: list[str] = None,
|
||||||
error_msgs: List[str] = None,
|
error_msgs: list[str] = None,
|
||||||
return_state_dict: bool = False,
|
return_state_dict: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
original_sd = self.state_dict()
|
original_sd = self.state_dict()
|
||||||
|
@ -163,7 +164,7 @@ class VisionEmbeddings(torch.nn.Module):
|
||||||
# each image is a tensor of shape [num_tiles, C, H, W]
|
# each image is a tensor of shape [num_tiles, C, H, W]
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
image_batch: List[List[torch.Tensor]],
|
image_batch: list[list[torch.Tensor]],
|
||||||
image_mask: torch.Tensor,
|
image_mask: torch.Tensor,
|
||||||
h_ref: torch.Tensor,
|
h_ref: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
|
@ -4,7 +4,8 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import fairscale.nn.model_parallel.initialize as fs_init
|
import fairscale.nn.model_parallel.initialize as fs_init
|
||||||
import torch
|
import torch
|
||||||
|
@ -42,9 +43,9 @@ class ColumnParallelConv2dPatch(torch.nn.Module):
|
||||||
self,
|
self,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
out_channels: int,
|
out_channels: int,
|
||||||
kernel_size: Union[int, Tuple[int, int]],
|
kernel_size: int | tuple[int, int],
|
||||||
stride: Union[int, Tuple[int, int]],
|
stride: int | tuple[int, int],
|
||||||
bias: Optional[bool] = False,
|
bias: bool | None = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if isinstance(kernel_size, int):
|
if isinstance(kernel_size, int):
|
||||||
|
@ -134,15 +135,15 @@ class _TransformerBlock(nn.Module):
|
||||||
def attention(
|
def attention(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
freq_cis: Optional[torch.Tensor] = None,
|
freq_cis: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)
|
return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
mask: Optional[torch.Tensor] = None,
|
mask: torch.Tensor | None = None,
|
||||||
freq_cis: Optional[torch.Tensor] = None,
|
freq_cis: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
_gate_attn = 1 if not self.gated else self.gate_attn.tanh()
|
_gate_attn = 1 if not self.gated else self.gate_attn.tanh()
|
||||||
_gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
|
_gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
|
||||||
|
@ -210,8 +211,8 @@ class PackingIndex:
|
||||||
class VisionEncoder(nn.Module):
|
class VisionEncoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
image_size: Tuple[int, int],
|
image_size: tuple[int, int],
|
||||||
patch_size: Tuple[int, int],
|
patch_size: tuple[int, int],
|
||||||
dim: int,
|
dim: int,
|
||||||
layers: int,
|
layers: int,
|
||||||
heads: int,
|
heads: int,
|
||||||
|
@ -299,13 +300,13 @@ class VisionEncoder(nn.Module):
|
||||||
|
|
||||||
def load_hook(
|
def load_hook(
|
||||||
self,
|
self,
|
||||||
state_dict: Dict[str, Any],
|
state_dict: dict[str, Any],
|
||||||
prefix: str,
|
prefix: str,
|
||||||
local_metadata: Dict[str, Any],
|
local_metadata: dict[str, Any],
|
||||||
strict: bool = True,
|
strict: bool = True,
|
||||||
missing_keys: List[str] = None,
|
missing_keys: list[str] = None,
|
||||||
unexpected_keys: List[str] = None,
|
unexpected_keys: list[str] = None,
|
||||||
error_msgs: List[str] = None,
|
error_msgs: list[str] = None,
|
||||||
return_state_dict: bool = False,
|
return_state_dict: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
orig_pos_embed = state_dict.get(prefix + "positional_embedding")
|
orig_pos_embed = state_dict.get(prefix + "positional_embedding")
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
import json
|
import json
|
||||||
import textwrap
|
import textwrap
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
@ -44,7 +43,7 @@ class TextCompletionContent(BaseModel):
|
||||||
class UseCase(BaseModel):
|
class UseCase(BaseModel):
|
||||||
title: str = ""
|
title: str = ""
|
||||||
description: str = ""
|
description: str = ""
|
||||||
dialogs: List[List[RawMessage] | TextCompletionContent | str] = Field(default_factory=list)
|
dialogs: list[list[RawMessage] | TextCompletionContent | str] = Field(default_factory=list)
|
||||||
notes: str = ""
|
notes: str = ""
|
||||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json
|
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json
|
||||||
max_gen_len: int = 512
|
max_gen_len: int = 512
|
||||||
|
|
|
@ -7,7 +7,6 @@
|
||||||
# type: ignore
|
# type: ignore
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Tuple, Type, Union
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -27,7 +26,7 @@ class Fp8ScaledWeights:
|
||||||
# TODO: Ugly trick so torch allows us to replace parameters
|
# TODO: Ugly trick so torch allows us to replace parameters
|
||||||
# with our custom Fp8Weights instance. Do this properly.
|
# with our custom Fp8Weights instance. Do this properly.
|
||||||
@property
|
@property
|
||||||
def __class__(self) -> Type[nn.parameter.Parameter]:
|
def __class__(self) -> type[nn.parameter.Parameter]:
|
||||||
return nn.Parameter
|
return nn.Parameter
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -51,7 +50,7 @@ class Int4ScaledWeights:
|
||||||
# TODO: Ugly trick so torch allows us to replace parameters
|
# TODO: Ugly trick so torch allows us to replace parameters
|
||||||
# with our custom Int4Weights instance. Do this properly.
|
# with our custom Int4Weights instance. Do this properly.
|
||||||
@property
|
@property
|
||||||
def __class__(self) -> Type[nn.parameter.Parameter]:
|
def __class__(self) -> type[nn.parameter.Parameter]:
|
||||||
return nn.Parameter
|
return nn.Parameter
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -74,7 +73,7 @@ class Int4Weights(
|
||||||
def int4_row_quantize(
|
def int4_row_quantize(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
group_size: int = 128,
|
group_size: int = 128,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
n_bit = 4 # Number of target bits.
|
n_bit = 4 # Number of target bits.
|
||||||
to_quant = x.reshape(-1, group_size).to(torch.float)
|
to_quant = x.reshape(-1, group_size).to(torch.float)
|
||||||
|
|
||||||
|
@ -115,8 +114,8 @@ def pack_int4(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
def bmm_nt(
|
def bmm_nt(
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
w: Union[Fp8RowwiseWeights, Int4Weights],
|
w: Fp8RowwiseWeights | Int4Weights,
|
||||||
num_tokens: Optional[Tensor] = None,
|
num_tokens: Tensor | None = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
if isinstance(w, Fp8ScaledWeights):
|
if isinstance(w, Fp8ScaledWeights):
|
||||||
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub)
|
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub)
|
||||||
|
@ -129,10 +128,10 @@ def bmm_nt(
|
||||||
|
|
||||||
def ffn_swiglu(
|
def ffn_swiglu(
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
w1: Union[Fp8RowwiseWeights, Int4Weights],
|
w1: Fp8RowwiseWeights | Int4Weights,
|
||||||
w3: Union[Fp8RowwiseWeights, Int4Weights],
|
w3: Fp8RowwiseWeights | Int4Weights,
|
||||||
w2: Union[Fp8RowwiseWeights, Int4Weights],
|
w2: Fp8RowwiseWeights | Int4Weights,
|
||||||
num_tokens: Optional[Tensor] = None,
|
num_tokens: Tensor | None = None,
|
||||||
is_memory_bounded: bool = False,
|
is_memory_bounded: bool = False,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
if (isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights)) or (
|
if (isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights)) or (
|
||||||
|
@ -158,7 +157,7 @@ def ffn_swiglu(
|
||||||
def quantize_fp8(
|
def quantize_fp8(
|
||||||
w: Tensor,
|
w: Tensor,
|
||||||
fp8_activation_scale_ub: float,
|
fp8_activation_scale_ub: float,
|
||||||
output_device: Optional[torch.device] = None,
|
output_device: torch.device | None = None,
|
||||||
) -> Fp8RowwiseWeights:
|
) -> Fp8RowwiseWeights:
|
||||||
"""Quantize [n, k] weight tensor.
|
"""Quantize [n, k] weight tensor.
|
||||||
|
|
||||||
|
@ -184,7 +183,7 @@ def quantize_fp8(
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def quantize_int4(
|
def quantize_int4(
|
||||||
w: Tensor,
|
w: Tensor,
|
||||||
output_device: Optional[torch.device] = None,
|
output_device: torch.device | None = None,
|
||||||
) -> Int4Weights:
|
) -> Int4Weights:
|
||||||
"""Quantize [n, k/2] weight tensor.
|
"""Quantize [n, k/2] weight tensor.
|
||||||
|
|
||||||
|
@ -213,7 +212,7 @@ def load_fp8(
|
||||||
w: Tensor,
|
w: Tensor,
|
||||||
w_scale: Tensor,
|
w_scale: Tensor,
|
||||||
fp8_activation_scale_ub: float,
|
fp8_activation_scale_ub: float,
|
||||||
output_device: Optional[torch.device] = None,
|
output_device: torch.device | None = None,
|
||||||
) -> Fp8RowwiseWeights:
|
) -> Fp8RowwiseWeights:
|
||||||
"""Load FP8 [n, k] weight tensor.
|
"""Load FP8 [n, k] weight tensor.
|
||||||
|
|
||||||
|
@ -239,7 +238,7 @@ def load_int4(
|
||||||
w: Tensor,
|
w: Tensor,
|
||||||
scale: Tensor,
|
scale: Tensor,
|
||||||
zero_point: Tensor,
|
zero_point: Tensor,
|
||||||
output_device: Optional[torch.device] = None,
|
output_device: torch.device | None = None,
|
||||||
) -> Int4Weights:
|
) -> Int4Weights:
|
||||||
"""Load INT4 [n, k/2] weight tensor.
|
"""Load INT4 [n, k/2] weight tensor.
|
||||||
|
|
||||||
|
@ -256,9 +255,9 @@ def load_int4(
|
||||||
|
|
||||||
def fc_dynamic(
|
def fc_dynamic(
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
w: Union[Fp8RowwiseWeights, Int4Weights],
|
w: Fp8RowwiseWeights | Int4Weights,
|
||||||
activation_scale_ub: Optional[Tensor] = None,
|
activation_scale_ub: Tensor | None = None,
|
||||||
num_tokens: Optional[Tensor] = None,
|
num_tokens: Tensor | None = None,
|
||||||
is_memory_bounded: bool = False,
|
is_memory_bounded: bool = False,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""
|
"""
|
||||||
|
@ -275,11 +274,11 @@ def fc_dynamic(
|
||||||
|
|
||||||
def ffn_swiglu_dynamic(
|
def ffn_swiglu_dynamic(
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
w1: Union[Fp8RowwiseWeights, Int4Weights],
|
w1: Fp8RowwiseWeights | Int4Weights,
|
||||||
w3: Union[Fp8RowwiseWeights, Int4Weights],
|
w3: Fp8RowwiseWeights | Int4Weights,
|
||||||
w2: Union[Fp8RowwiseWeights, Int4Weights],
|
w2: Fp8RowwiseWeights | Int4Weights,
|
||||||
activation_scale_ub: Optional[Tensor] = None,
|
activation_scale_ub: Tensor | None = None,
|
||||||
num_tokens: Optional[Tensor] = None,
|
num_tokens: Tensor | None = None,
|
||||||
is_memory_bounded: bool = False,
|
is_memory_bounded: bool = False,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
assert x.dim() == 3 or x.dim() == 2
|
assert x.dim() == 3 or x.dim() == 2
|
||||||
|
|
|
@ -6,7 +6,6 @@
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from .sku_types import (
|
from .sku_types import (
|
||||||
CheckpointQuantizationFormat,
|
CheckpointQuantizationFormat,
|
||||||
|
@ -19,14 +18,14 @@ LLAMA2_VOCAB_SIZE = 32000
|
||||||
LLAMA3_VOCAB_SIZE = 128256
|
LLAMA3_VOCAB_SIZE = 128256
|
||||||
|
|
||||||
|
|
||||||
def resolve_model(descriptor: str) -> Optional[Model]:
|
def resolve_model(descriptor: str) -> Model | None:
|
||||||
for m in all_registered_models():
|
for m in all_registered_models():
|
||||||
if descriptor in (m.descriptor(), m.huggingface_repo):
|
if descriptor in (m.descriptor(), m.huggingface_repo):
|
||||||
return m
|
return m
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def all_registered_models() -> List[Model]:
|
def all_registered_models() -> list[Model]:
|
||||||
return (
|
return (
|
||||||
llama2_family()
|
llama2_family()
|
||||||
+ llama3_family()
|
+ llama3_family()
|
||||||
|
@ -38,48 +37,48 @@ def all_registered_models() -> List[Model]:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def llama2_family() -> List[Model]:
|
def llama2_family() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
*llama2_base_models(),
|
*llama2_base_models(),
|
||||||
*llama2_instruct_models(),
|
*llama2_instruct_models(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def llama3_family() -> List[Model]:
|
def llama3_family() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
*llama3_base_models(),
|
*llama3_base_models(),
|
||||||
*llama3_instruct_models(),
|
*llama3_instruct_models(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def llama3_1_family() -> List[Model]:
|
def llama3_1_family() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
*llama3_1_base_models(),
|
*llama3_1_base_models(),
|
||||||
*llama3_1_instruct_models(),
|
*llama3_1_instruct_models(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def llama3_2_family() -> List[Model]:
|
def llama3_2_family() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
*llama3_2_base_models(),
|
*llama3_2_base_models(),
|
||||||
*llama3_2_instruct_models(),
|
*llama3_2_instruct_models(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def llama3_3_family() -> List[Model]:
|
def llama3_3_family() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
*llama3_3_instruct_models(),
|
*llama3_3_instruct_models(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def llama4_family() -> List[Model]:
|
def llama4_family() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
*llama4_base_models(),
|
*llama4_base_models(),
|
||||||
*llama4_instruct_models(),
|
*llama4_instruct_models(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def llama4_base_models() -> List[Model]:
|
def llama4_base_models() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
Model(
|
Model(
|
||||||
core_model_id=CoreModelId.llama4_scout_17b_16e,
|
core_model_id=CoreModelId.llama4_scout_17b_16e,
|
||||||
|
@ -98,7 +97,7 @@ def llama4_base_models() -> List[Model]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def llama4_instruct_models() -> List[Model]:
|
def llama4_instruct_models() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
Model(
|
Model(
|
||||||
core_model_id=CoreModelId.llama4_scout_17b_16e_instruct,
|
core_model_id=CoreModelId.llama4_scout_17b_16e_instruct,
|
||||||
|
@ -126,7 +125,7 @@ def llama4_instruct_models() -> List[Model]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def llama2_base_models() -> List[Model]:
|
def llama2_base_models() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
Model(
|
Model(
|
||||||
core_model_id=CoreModelId.llama2_7b,
|
core_model_id=CoreModelId.llama2_7b,
|
||||||
|
@ -185,7 +184,7 @@ def llama2_base_models() -> List[Model]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def llama3_base_models() -> List[Model]:
|
def llama3_base_models() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
Model(
|
Model(
|
||||||
core_model_id=CoreModelId.llama3_8b,
|
core_model_id=CoreModelId.llama3_8b,
|
||||||
|
@ -226,7 +225,7 @@ def llama3_base_models() -> List[Model]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def llama3_1_base_models() -> List[Model]:
|
def llama3_1_base_models() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
Model(
|
Model(
|
||||||
core_model_id=CoreModelId.llama3_1_8b,
|
core_model_id=CoreModelId.llama3_1_8b,
|
||||||
|
@ -324,7 +323,7 @@ def llama3_1_base_models() -> List[Model]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def llama3_2_base_models() -> List[Model]:
|
def llama3_2_base_models() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
Model(
|
Model(
|
||||||
core_model_id=CoreModelId.llama3_2_1b,
|
core_model_id=CoreModelId.llama3_2_1b,
|
||||||
|
@ -407,7 +406,7 @@ def llama3_2_base_models() -> List[Model]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def llama2_instruct_models() -> List[Model]:
|
def llama2_instruct_models() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
Model(
|
Model(
|
||||||
core_model_id=CoreModelId.llama2_7b_chat,
|
core_model_id=CoreModelId.llama2_7b_chat,
|
||||||
|
@ -466,7 +465,7 @@ def llama2_instruct_models() -> List[Model]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def llama3_instruct_models() -> List[Model]:
|
def llama3_instruct_models() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
Model(
|
Model(
|
||||||
core_model_id=CoreModelId.llama3_8b_instruct,
|
core_model_id=CoreModelId.llama3_8b_instruct,
|
||||||
|
@ -507,7 +506,7 @@ def llama3_instruct_models() -> List[Model]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def llama3_1_instruct_models() -> List[Model]:
|
def llama3_1_instruct_models() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
Model(
|
Model(
|
||||||
core_model_id=CoreModelId.llama3_1_8b_instruct,
|
core_model_id=CoreModelId.llama3_1_8b_instruct,
|
||||||
|
@ -635,7 +634,7 @@ def arch_args_3b() -> dict:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def llama3_2_quantized_models() -> List[Model]:
|
def llama3_2_quantized_models() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
Model(
|
Model(
|
||||||
core_model_id=CoreModelId.llama3_2_1b_instruct,
|
core_model_id=CoreModelId.llama3_2_1b_instruct,
|
||||||
|
@ -704,7 +703,7 @@ def llama3_2_quantized_models() -> List[Model]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def llama3_2_instruct_models() -> List[Model]:
|
def llama3_2_instruct_models() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
Model(
|
Model(
|
||||||
core_model_id=CoreModelId.llama3_2_1b_instruct,
|
core_model_id=CoreModelId.llama3_2_1b_instruct,
|
||||||
|
@ -766,7 +765,7 @@ def llama3_2_instruct_models() -> List[Model]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def llama3_3_instruct_models() -> List[Model]:
|
def llama3_3_instruct_models() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
Model(
|
Model(
|
||||||
core_model_id=CoreModelId.llama3_3_70b_instruct,
|
core_model_id=CoreModelId.llama3_3_70b_instruct,
|
||||||
|
@ -790,7 +789,7 @@ def llama3_3_instruct_models() -> List[Model]:
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def safety_models() -> List[Model]:
|
def safety_models() -> list[Model]:
|
||||||
return [
|
return [
|
||||||
Model(
|
Model(
|
||||||
core_model_id=CoreModelId.llama_guard_4_12b,
|
core_model_id=CoreModelId.llama_guard_4_12b,
|
||||||
|
@ -919,7 +918,7 @@ def safety_models() -> List[Model]:
|
||||||
@dataclass
|
@dataclass
|
||||||
class LlamaDownloadInfo:
|
class LlamaDownloadInfo:
|
||||||
folder: str
|
folder: str
|
||||||
files: List[str]
|
files: list[str]
|
||||||
pth_size: int
|
pth_size: int
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
@ -159,13 +159,13 @@ def model_family(model_id) -> ModelFamily:
|
||||||
class Model(BaseModel):
|
class Model(BaseModel):
|
||||||
core_model_id: CoreModelId
|
core_model_id: CoreModelId
|
||||||
description: str
|
description: str
|
||||||
huggingface_repo: Optional[str] = None
|
huggingface_repo: str | None = None
|
||||||
arch_args: Dict[str, Any]
|
arch_args: dict[str, Any]
|
||||||
variant: str = ""
|
variant: str = ""
|
||||||
|
|
||||||
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
|
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
|
||||||
pth_file_count: int
|
pth_file_count: int
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
# silence pydantic until we remove the `model_` fields
|
# silence pydantic until we remove the `model_` fields
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, List, Optional, Protocol
|
from typing import Any, Protocol
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
@ -65,7 +65,7 @@ class DatasetsProtocolPrivate(Protocol):
|
||||||
|
|
||||||
|
|
||||||
class ScoringFunctionsProtocolPrivate(Protocol):
|
class ScoringFunctionsProtocolPrivate(Protocol):
|
||||||
async def list_scoring_functions(self) -> List[ScoringFn]: ...
|
async def list_scoring_functions(self) -> list[ScoringFn]: ...
|
||||||
|
|
||||||
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ...
|
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ...
|
||||||
|
|
||||||
|
@ -88,24 +88,24 @@ class ProviderSpec(BaseModel):
|
||||||
...,
|
...,
|
||||||
description="Fully-qualified classname of the config for this provider",
|
description="Fully-qualified classname of the config for this provider",
|
||||||
)
|
)
|
||||||
api_dependencies: List[Api] = Field(
|
api_dependencies: list[Api] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
||||||
)
|
)
|
||||||
optional_api_dependencies: List[Api] = Field(
|
optional_api_dependencies: list[Api] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
deprecation_warning: Optional[str] = Field(
|
deprecation_warning: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="If this provider is deprecated, specify the warning message here",
|
description="If this provider is deprecated, specify the warning message here",
|
||||||
)
|
)
|
||||||
deprecation_error: Optional[str] = Field(
|
deprecation_error: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="If this provider is deprecated and does NOT work, specify the error message here",
|
description="If this provider is deprecated and does NOT work, specify the error message here",
|
||||||
)
|
)
|
||||||
|
|
||||||
# used internally by the resolver; this is a hack for now
|
# used internally by the resolver; this is a hack for now
|
||||||
deps__: List[str] = Field(default_factory=list)
|
deps__: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_sample(self) -> bool:
|
def is_sample(self) -> bool:
|
||||||
|
@ -131,25 +131,25 @@ Fully-qualified name of the module to import. The module is expected to have:
|
||||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
pip_packages: List[str] = Field(
|
pip_packages: list[str] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="The pip dependencies needed for this implementation",
|
description="The pip dependencies needed for this implementation",
|
||||||
)
|
)
|
||||||
config_class: str = Field(
|
config_class: str = Field(
|
||||||
description="Fully-qualified classname of the config for this provider",
|
description="Fully-qualified classname of the config for this provider",
|
||||||
)
|
)
|
||||||
provider_data_validator: Optional[str] = Field(
|
provider_data_validator: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class InlineProviderSpec(ProviderSpec):
|
class InlineProviderSpec(ProviderSpec):
|
||||||
pip_packages: List[str] = Field(
|
pip_packages: list[str] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="The pip dependencies needed for this implementation",
|
description="The pip dependencies needed for this implementation",
|
||||||
)
|
)
|
||||||
container_image: Optional[str] = Field(
|
container_image: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="""
|
description="""
|
||||||
The container image to use for this implementation. If one is provided, pip_packages will be ignored.
|
The container image to use for this implementation. If one is provided, pip_packages will be ignored.
|
||||||
|
@ -164,14 +164,14 @@ Fully-qualified name of the module to import. The module is expected to have:
|
||||||
- `get_provider_impl(config, deps)`: returns the local implementation
|
- `get_provider_impl(config, deps)`: returns the local implementation
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
provider_data_validator: Optional[str] = Field(
|
provider_data_validator: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RemoteProviderConfig(BaseModel):
|
class RemoteProviderConfig(BaseModel):
|
||||||
host: str = "localhost"
|
host: str = "localhost"
|
||||||
port: Optional[int] = None
|
port: int | None = None
|
||||||
protocol: str = "http"
|
protocol: str = "http"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -197,7 +197,7 @@ API responses, specify the adapter here.
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def container_image(self) -> Optional[str]:
|
def container_image(self) -> str | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -205,16 +205,16 @@ API responses, specify the adapter here.
|
||||||
return self.adapter.module
|
return self.adapter.module
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pip_packages(self) -> List[str]:
|
def pip_packages(self) -> list[str]:
|
||||||
return self.adapter.pip_packages
|
return self.adapter.pip_packages
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_data_validator(self) -> Optional[str]:
|
def provider_data_validator(self) -> str | None:
|
||||||
return self.adapter.provider_data_validator
|
return self.adapter.provider_data_validator
|
||||||
|
|
||||||
|
|
||||||
def remote_provider_spec(
|
def remote_provider_spec(
|
||||||
api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None
|
api: Api, adapter: AdapterSpec, api_dependencies: list[Api] | None = None
|
||||||
) -> RemoteProviderSpec:
|
) -> RemoteProviderSpec:
|
||||||
return RemoteProviderSpec(
|
return RemoteProviderSpec(
|
||||||
api=api,
|
api=api,
|
||||||
|
|
|
@ -4,14 +4,14 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api
|
from llama_stack.distribution.datatypes import Api
|
||||||
|
|
||||||
from .config import MetaReferenceAgentsImplConfig
|
from .config import MetaReferenceAgentsImplConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
|
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
|
||||||
from .agents import MetaReferenceAgentsImpl
|
from .agents import MetaReferenceAgentsImpl
|
||||||
|
|
||||||
impl = MetaReferenceAgentsImpl(
|
impl = MetaReferenceAgentsImpl(
|
||||||
|
|
|
@ -10,8 +10,8 @@ import re
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
import uuid
|
import uuid
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import AsyncGenerator, List, Optional, Union
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
@ -112,7 +112,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
output_shields=agent_config.output_shields,
|
output_shields=agent_config.output_shields,
|
||||||
)
|
)
|
||||||
|
|
||||||
def turn_to_messages(self, turn: Turn) -> List[Message]:
|
def turn_to_messages(self, turn: Turn) -> list[Message]:
|
||||||
messages = []
|
messages = []
|
||||||
|
|
||||||
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
|
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
|
||||||
|
@ -161,7 +161,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
async def create_session(self, name: str) -> str:
|
async def create_session(self, name: str) -> str:
|
||||||
return await self.storage.create_session(name)
|
return await self.storage.create_session(name)
|
||||||
|
|
||||||
async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
|
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
|
||||||
messages = []
|
messages = []
|
||||||
if self.agent_config.instructions != "":
|
if self.agent_config.instructions != "":
|
||||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||||
|
@ -201,8 +201,8 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
async def _run_turn(
|
async def _run_turn(
|
||||||
self,
|
self,
|
||||||
request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
|
request: AgentTurnCreateRequest | AgentTurnResumeRequest,
|
||||||
turn_id: Optional[str] = None,
|
turn_id: str | None = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
assert request.stream is True, "Non-streaming not supported"
|
assert request.stream is True, "Non-streaming not supported"
|
||||||
|
|
||||||
|
@ -321,10 +321,10 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
self,
|
self,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
turn_id: str,
|
turn_id: str,
|
||||||
input_messages: List[Message],
|
input_messages: list[Message],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
documents: Optional[List[Document]] = None,
|
documents: list[Document] | None = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||||
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
||||||
|
@ -374,8 +374,8 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
async def run_multiple_shields_wrapper(
|
async def run_multiple_shields_wrapper(
|
||||||
self,
|
self,
|
||||||
turn_id: str,
|
turn_id: str,
|
||||||
messages: List[Message],
|
messages: list[Message],
|
||||||
shields: List[str],
|
shields: list[str],
|
||||||
touchpoint: str,
|
touchpoint: str,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
async with tracing.span("run_shields") as span:
|
async with tracing.span("run_shields") as span:
|
||||||
|
@ -443,10 +443,10 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
self,
|
self,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
turn_id: str,
|
turn_id: str,
|
||||||
input_messages: List[Message],
|
input_messages: list[Message],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
documents: Optional[List[Document]] = None,
|
documents: list[Document] | None = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
# if document is passed in a turn, we parse the raw text of the document
|
# if document is passed in a turn, we parse the raw text of the document
|
||||||
# and sent it as a user message
|
# and sent it as a user message
|
||||||
|
@ -760,7 +760,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
async def _initialize_tools(
|
async def _initialize_tools(
|
||||||
self,
|
self,
|
||||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
toolgroups_for_turn: list[AgentToolGroup] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
toolgroup_to_args = {}
|
toolgroup_to_args = {}
|
||||||
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
|
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
|
||||||
|
@ -847,7 +847,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
tool_name_to_args,
|
tool_name_to_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
|
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, str | None]:
|
||||||
"""Parse a toolgroup name into its components.
|
"""Parse a toolgroup name into its components.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -921,7 +921,7 @@ async def get_raw_document_text(document: Document) -> str:
|
||||||
|
|
||||||
def _interpret_content_as_attachment(
|
def _interpret_content_as_attachment(
|
||||||
content: str,
|
content: str,
|
||||||
) -> Optional[Attachment]:
|
) -> Attachment | None:
|
||||||
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
|
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
|
||||||
if match:
|
if match:
|
||||||
snippet = match.group(1)
|
snippet = match.group(1)
|
||||||
|
|
|
@ -8,7 +8,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
import uuid
|
import uuid
|
||||||
from typing import AsyncGenerator, List, Optional, Union
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
from llama_stack.apis.agents import (
|
from llama_stack.apis.agents import (
|
||||||
Agent,
|
Agent,
|
||||||
|
@ -142,16 +142,11 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
messages: List[
|
messages: list[UserMessage | ToolResponseMessage],
|
||||||
Union[
|
toolgroups: list[AgentToolGroup] | None = None,
|
||||||
UserMessage,
|
documents: list[Document] | None = None,
|
||||||
ToolResponseMessage,
|
stream: bool | None = False,
|
||||||
]
|
tool_config: ToolConfig | None = None,
|
||||||
],
|
|
||||||
toolgroups: Optional[List[AgentToolGroup]] = None,
|
|
||||||
documents: Optional[List[Document]] = None,
|
|
||||||
stream: Optional[bool] = False,
|
|
||||||
tool_config: Optional[ToolConfig] = None,
|
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
request = AgentTurnCreateRequest(
|
request = AgentTurnCreateRequest(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
|
@ -180,8 +175,8 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
turn_id: str,
|
turn_id: str,
|
||||||
tool_responses: List[ToolResponse],
|
tool_responses: list[ToolResponse],
|
||||||
stream: Optional[bool] = False,
|
stream: bool | None = False,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
request = AgentTurnResumeRequest(
|
request = AgentTurnResumeRequest(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
|
@ -219,7 +214,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
turn_ids: Optional[List[str]] = None,
|
turn_ids: list[str] | None = None,
|
||||||
) -> Session:
|
) -> Session:
|
||||||
agent = await self._get_agent_impl(agent_id)
|
agent = await self._get_agent_impl(agent_id)
|
||||||
session_info = await agent.storage.get_session_info(session_id)
|
session_info = await agent.storage.get_session_info(session_id)
|
||||||
|
@ -265,13 +260,13 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
|
|
||||||
async def create_openai_response(
|
async def create_openai_response(
|
||||||
self,
|
self,
|
||||||
input: Union[str, List[OpenAIResponseInputMessage]],
|
input: str | list[OpenAIResponseInputMessage],
|
||||||
model: str,
|
model: str,
|
||||||
previous_response_id: Optional[str] = None,
|
previous_response_id: str | None = None,
|
||||||
store: Optional[bool] = True,
|
store: bool | None = True,
|
||||||
stream: Optional[bool] = False,
|
stream: bool | None = False,
|
||||||
temperature: Optional[float] = None,
|
temperature: float | None = None,
|
||||||
tools: Optional[List[OpenAIResponseInputTool]] = None,
|
tools: list[OpenAIResponseInputTool] | None = None,
|
||||||
) -> OpenAIResponseObject:
|
) -> OpenAIResponseObject:
|
||||||
return await self.openai_responses_impl.create_openai_response(
|
return await self.openai_responses_impl.create_openai_response(
|
||||||
input, model, previous_response_id, store, stream, temperature, tools
|
input, model, previous_response_id, store, stream, temperature, tools
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ class MetaReferenceAgentsImplConfig(BaseModel):
|
||||||
persistence_store: KVStoreConfig
|
persistence_store: KVStoreConfig
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
|
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"persistence_store": SqliteKVStoreConfig.sample_run_config(
|
"persistence_store": SqliteKVStoreConfig.sample_run_config(
|
||||||
__distro_dir__=__distro_dir__,
|
__distro_dir__=__distro_dir__,
|
||||||
|
|
|
@ -6,7 +6,8 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from typing import AsyncIterator, List, Optional, Union, cast
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from openai.types.chat import ChatCompletionToolParam
|
from openai.types.chat import ChatCompletionToolParam
|
||||||
|
|
||||||
|
@ -49,15 +50,15 @@ logger = get_logger(name=__name__, category="openai_responses")
|
||||||
OPENAI_RESPONSES_PREFIX = "openai_responses:"
|
OPENAI_RESPONSES_PREFIX = "openai_responses:"
|
||||||
|
|
||||||
|
|
||||||
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> List[OpenAIMessageParam]:
|
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> list[OpenAIMessageParam]:
|
||||||
messages: List[OpenAIMessageParam] = []
|
messages: list[OpenAIMessageParam] = []
|
||||||
for output_message in previous_response.output:
|
for output_message in previous_response.output:
|
||||||
if isinstance(output_message, OpenAIResponseOutputMessage):
|
if isinstance(output_message, OpenAIResponseOutputMessage):
|
||||||
messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
|
messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
async def _openai_choices_to_output_messages(choices: List[OpenAIChoice]) -> List[OpenAIResponseOutputMessage]:
|
async def _openai_choices_to_output_messages(choices: list[OpenAIChoice]) -> list[OpenAIResponseOutputMessage]:
|
||||||
output_messages = []
|
output_messages = []
|
||||||
for choice in choices:
|
for choice in choices:
|
||||||
output_content = ""
|
output_content = ""
|
||||||
|
@ -101,22 +102,22 @@ class OpenAIResponsesImpl:
|
||||||
|
|
||||||
async def create_openai_response(
|
async def create_openai_response(
|
||||||
self,
|
self,
|
||||||
input: Union[str, List[OpenAIResponseInputMessage]],
|
input: str | list[OpenAIResponseInputMessage],
|
||||||
model: str,
|
model: str,
|
||||||
previous_response_id: Optional[str] = None,
|
previous_response_id: str | None = None,
|
||||||
store: Optional[bool] = True,
|
store: bool | None = True,
|
||||||
stream: Optional[bool] = False,
|
stream: bool | None = False,
|
||||||
temperature: Optional[float] = None,
|
temperature: float | None = None,
|
||||||
tools: Optional[List[OpenAIResponseInputTool]] = None,
|
tools: list[OpenAIResponseInputTool] | None = None,
|
||||||
):
|
):
|
||||||
stream = False if stream is None else stream
|
stream = False if stream is None else stream
|
||||||
|
|
||||||
messages: List[OpenAIMessageParam] = []
|
messages: list[OpenAIMessageParam] = []
|
||||||
if previous_response_id:
|
if previous_response_id:
|
||||||
previous_response = await self.get_openai_response(previous_response_id)
|
previous_response = await self.get_openai_response(previous_response_id)
|
||||||
messages.extend(await _previous_response_to_messages(previous_response))
|
messages.extend(await _previous_response_to_messages(previous_response))
|
||||||
# TODO: refactor this user_content parsing out into a separate method
|
# TODO: refactor this user_content parsing out into a separate method
|
||||||
user_content: Union[str, List[OpenAIChatCompletionContentPartParam]] = ""
|
user_content: str | list[OpenAIChatCompletionContentPartParam] = ""
|
||||||
if isinstance(input, list):
|
if isinstance(input, list):
|
||||||
user_content = []
|
user_content = []
|
||||||
for user_input in input:
|
for user_input in input:
|
||||||
|
@ -179,7 +180,7 @@ class OpenAIResponsesImpl:
|
||||||
# dump and reload to map to our pydantic types
|
# dump and reload to map to our pydantic types
|
||||||
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
|
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
|
||||||
|
|
||||||
output_messages: List[OpenAIResponseOutput] = []
|
output_messages: list[OpenAIResponseOutput] = []
|
||||||
if chat_response.choices[0].message.tool_calls:
|
if chat_response.choices[0].message.tool_calls:
|
||||||
output_messages.extend(
|
output_messages.extend(
|
||||||
await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
|
await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
|
||||||
|
@ -215,9 +216,9 @@ class OpenAIResponsesImpl:
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def _convert_response_tools_to_chat_tools(
|
async def _convert_response_tools_to_chat_tools(
|
||||||
self, tools: List[OpenAIResponseInputTool]
|
self, tools: list[OpenAIResponseInputTool]
|
||||||
) -> List[ChatCompletionToolParam]:
|
) -> list[ChatCompletionToolParam]:
|
||||||
chat_tools: List[ChatCompletionToolParam] = []
|
chat_tools: list[ChatCompletionToolParam] = []
|
||||||
for input_tool in tools:
|
for input_tool in tools:
|
||||||
# TODO: Handle other tool types
|
# TODO: Handle other tool types
|
||||||
if input_tool.type == "web_search":
|
if input_tool.type == "web_search":
|
||||||
|
@ -247,10 +248,10 @@ class OpenAIResponsesImpl:
|
||||||
model_id: str,
|
model_id: str,
|
||||||
stream: bool,
|
stream: bool,
|
||||||
chat_response: OpenAIChatCompletion,
|
chat_response: OpenAIChatCompletion,
|
||||||
messages: List[OpenAIMessageParam],
|
messages: list[OpenAIMessageParam],
|
||||||
temperature: float,
|
temperature: float,
|
||||||
) -> List[OpenAIResponseOutput]:
|
) -> list[OpenAIResponseOutput]:
|
||||||
output_messages: List[OpenAIResponseOutput] = []
|
output_messages: list[OpenAIResponseOutput] = []
|
||||||
choice = chat_response.choices[0]
|
choice = chat_response.choices[0]
|
||||||
|
|
||||||
# If the choice is not an assistant message, we don't need to execute any tools
|
# If the choice is not an assistant message, we don't need to execute any tools
|
||||||
|
@ -314,7 +315,7 @@ class OpenAIResponsesImpl:
|
||||||
async def _execute_tool_call(
|
async def _execute_tool_call(
|
||||||
self,
|
self,
|
||||||
function: OpenAIChatCompletionToolCallFunction,
|
function: OpenAIChatCompletionToolCallFunction,
|
||||||
) -> Optional[ToolInvocationResult]:
|
) -> ToolInvocationResult | None:
|
||||||
if not function.name:
|
if not function.name:
|
||||||
return None
|
return None
|
||||||
function_args = json.loads(function.arguments) if function.arguments else {}
|
function_args = json.loads(function.arguments) if function.arguments else {}
|
||||||
|
|
|
@ -8,7 +8,6 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -25,9 +24,9 @@ class AgentSessionInfo(BaseModel):
|
||||||
session_id: str
|
session_id: str
|
||||||
session_name: str
|
session_name: str
|
||||||
# TODO: is this used anywhere?
|
# TODO: is this used anywhere?
|
||||||
vector_db_id: Optional[str] = None
|
vector_db_id: str | None = None
|
||||||
started_at: datetime
|
started_at: datetime
|
||||||
access_attributes: Optional[AccessAttributes] = None
|
access_attributes: AccessAttributes | None = None
|
||||||
|
|
||||||
|
|
||||||
class AgentPersistence:
|
class AgentPersistence:
|
||||||
|
@ -55,7 +54,7 @@ class AgentPersistence:
|
||||||
)
|
)
|
||||||
return session_id
|
return session_id
|
||||||
|
|
||||||
async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:
|
async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
|
||||||
value = await self.kvstore.get(
|
value = await self.kvstore.get(
|
||||||
key=f"session:{self.agent_id}:{session_id}",
|
key=f"session:{self.agent_id}:{session_id}",
|
||||||
)
|
)
|
||||||
|
@ -78,7 +77,7 @@ class AgentPersistence:
|
||||||
|
|
||||||
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
|
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
|
||||||
|
|
||||||
async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]:
|
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
||||||
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
||||||
session_info = await self.get_session_info(session_id)
|
session_info = await self.get_session_info(session_id)
|
||||||
if not session_info:
|
if not session_info:
|
||||||
|
@ -106,7 +105,7 @@ class AgentPersistence:
|
||||||
value=turn.model_dump_json(),
|
value=turn.model_dump_json(),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_session_turns(self, session_id: str) -> List[Turn]:
|
async def get_session_turns(self, session_id: str) -> list[Turn]:
|
||||||
if not await self.get_session_if_accessible(session_id):
|
if not await self.get_session_if_accessible(session_id):
|
||||||
raise ValueError(f"Session {session_id} not found or access denied")
|
raise ValueError(f"Session {session_id} not found or access denied")
|
||||||
|
|
||||||
|
@ -125,7 +124,7 @@ class AgentPersistence:
|
||||||
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
||||||
return turns
|
return turns
|
||||||
|
|
||||||
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
|
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
|
||||||
if not await self.get_session_if_accessible(session_id):
|
if not await self.get_session_if_accessible(session_id):
|
||||||
raise ValueError(f"Session {session_id} not found or access denied")
|
raise ValueError(f"Session {session_id} not found or access denied")
|
||||||
|
|
||||||
|
@ -145,7 +144,7 @@ class AgentPersistence:
|
||||||
value=step.model_dump_json(),
|
value=step.model_dump_json(),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
|
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> ToolExecutionStep | None:
|
||||||
if not await self.get_session_if_accessible(session_id):
|
if not await self.get_session_if_accessible(session_id):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -163,7 +162,7 @@ class AgentPersistence:
|
||||||
value=str(num_infer_iters),
|
value=str(num_infer_iters),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
|
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> int | None:
|
||||||
if not await self.get_session_if_accessible(session_id):
|
if not await self.get_session_if_accessible(session_id):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,6 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from llama_stack.apis.inference import Message
|
from llama_stack.apis.inference import Message
|
||||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||||
|
@ -25,14 +24,14 @@ class ShieldRunnerMixin:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
safety_api: Safety,
|
safety_api: Safety,
|
||||||
input_shields: List[str] = None,
|
input_shields: list[str] = None,
|
||||||
output_shields: List[str] = None,
|
output_shields: list[str] = None,
|
||||||
):
|
):
|
||||||
self.safety_api = safety_api
|
self.safety_api = safety_api
|
||||||
self.input_shields = input_shields
|
self.input_shields = input_shields
|
||||||
self.output_shields = output_shields
|
self.output_shields = output_shields
|
||||||
|
|
||||||
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
|
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
|
||||||
async def run_shield_with_span(identifier: str):
|
async def run_shield_with_span(identifier: str):
|
||||||
async with tracing.span(f"run_shield_{identifier}"):
|
async with tracing.span(f"run_shield_{identifier}"):
|
||||||
return await self.safety_api.run_shield(
|
return await self.safety_api.run_shield(
|
||||||
|
|
Some files were not shown because too many files have changed in this diff.