forked from phoenix-oss/llama-stack-mirror
chore: enable pyupgrade fixes (#1806)
# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth of note can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py` where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
ffe3d0b2cd
commit
9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
|
||||
from .agents import MetaReferenceAgentsImpl
|
||||
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
|
|
|
@ -10,8 +10,8 @@ import re
|
|||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -112,7 +112,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
output_shields=agent_config.output_shields,
|
||||
)
|
||||
|
||||
def turn_to_messages(self, turn: Turn) -> List[Message]:
|
||||
def turn_to_messages(self, turn: Turn) -> list[Message]:
|
||||
messages = []
|
||||
|
||||
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
|
||||
|
@ -161,7 +161,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def create_session(self, name: str) -> str:
|
||||
return await self.storage.create_session(name)
|
||||
|
||||
async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
|
||||
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
|
||||
messages = []
|
||||
if self.agent_config.instructions != "":
|
||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||
|
@ -201,8 +201,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def _run_turn(
|
||||
self,
|
||||
request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
|
||||
turn_id: Optional[str] = None,
|
||||
request: AgentTurnCreateRequest | AgentTurnResumeRequest,
|
||||
turn_id: str | None = None,
|
||||
) -> AsyncGenerator:
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
|
||||
|
@ -321,10 +321,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
input_messages: List[Message],
|
||||
input_messages: list[Message],
|
||||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
documents: list[Document] | None = None,
|
||||
) -> AsyncGenerator:
|
||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
||||
|
@ -374,8 +374,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def run_multiple_shields_wrapper(
|
||||
self,
|
||||
turn_id: str,
|
||||
messages: List[Message],
|
||||
shields: List[str],
|
||||
messages: list[Message],
|
||||
shields: list[str],
|
||||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
async with tracing.span("run_shields") as span:
|
||||
|
@ -443,10 +443,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
input_messages: List[Message],
|
||||
input_messages: list[Message],
|
||||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
documents: list[Document] | None = None,
|
||||
) -> AsyncGenerator:
|
||||
# if document is passed in a turn, we parse the raw text of the document
|
||||
# and sent it as a user message
|
||||
|
@ -760,7 +760,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def _initialize_tools(
|
||||
self,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
toolgroups_for_turn: list[AgentToolGroup] | None = None,
|
||||
) -> None:
|
||||
toolgroup_to_args = {}
|
||||
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
|
||||
|
@ -847,7 +847,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_name_to_args,
|
||||
)
|
||||
|
||||
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
|
||||
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, str | None]:
|
||||
"""Parse a toolgroup name into its components.
|
||||
|
||||
Args:
|
||||
|
@ -921,7 +921,7 @@ async def get_raw_document_text(document: Document) -> str:
|
|||
|
||||
def _interpret_content_as_attachment(
|
||||
content: str,
|
||||
) -> Optional[Attachment]:
|
||||
) -> Attachment | None:
|
||||
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
|
||||
if match:
|
||||
snippet = match.group(1)
|
||||
|
|
|
@ -8,7 +8,7 @@ import json
|
|||
import logging
|
||||
import shutil
|
||||
import uuid
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
Agent,
|
||||
|
@ -142,16 +142,11 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
messages: List[
|
||||
Union[
|
||||
UserMessage,
|
||||
ToolResponseMessage,
|
||||
]
|
||||
],
|
||||
toolgroups: Optional[List[AgentToolGroup]] = None,
|
||||
documents: Optional[List[Document]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
messages: list[UserMessage | ToolResponseMessage],
|
||||
toolgroups: list[AgentToolGroup] | None = None,
|
||||
documents: list[Document] | None = None,
|
||||
stream: bool | None = False,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=agent_id,
|
||||
|
@ -180,8 +175,8 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: List[ToolResponse],
|
||||
stream: Optional[bool] = False,
|
||||
tool_responses: list[ToolResponse],
|
||||
stream: bool | None = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnResumeRequest(
|
||||
agent_id=agent_id,
|
||||
|
@ -219,7 +214,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
turn_ids: list[str] | None = None,
|
||||
) -> Session:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
|
@ -265,13 +260,13 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: Union[str, List[OpenAIResponseInputMessage]],
|
||||
input: str | list[OpenAIResponseInputMessage],
|
||||
model: str,
|
||||
previous_response_id: Optional[str] = None,
|
||||
store: Optional[bool] = True,
|
||||
stream: Optional[bool] = False,
|
||||
temperature: Optional[float] = None,
|
||||
tools: Optional[List[OpenAIResponseInputTool]] = None,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
) -> OpenAIResponseObject:
|
||||
return await self.openai_responses_impl.create_openai_response(
|
||||
input, model, previous_response_id, store, stream, temperature, tools
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -16,7 +16,7 @@ class MetaReferenceAgentsImplConfig(BaseModel):
|
|||
persistence_store: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"persistence_store": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
|
|
|
@ -6,7 +6,8 @@
|
|||
|
||||
import json
|
||||
import uuid
|
||||
from typing import AsyncIterator, List, Optional, Union, cast
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import cast
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
|
@ -49,15 +50,15 @@ logger = get_logger(name=__name__, category="openai_responses")
|
|||
OPENAI_RESPONSES_PREFIX = "openai_responses:"
|
||||
|
||||
|
||||
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> List[OpenAIMessageParam]:
|
||||
messages: List[OpenAIMessageParam] = []
|
||||
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> list[OpenAIMessageParam]:
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
for output_message in previous_response.output:
|
||||
if isinstance(output_message, OpenAIResponseOutputMessage):
|
||||
messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
|
||||
return messages
|
||||
|
||||
|
||||
async def _openai_choices_to_output_messages(choices: List[OpenAIChoice]) -> List[OpenAIResponseOutputMessage]:
|
||||
async def _openai_choices_to_output_messages(choices: list[OpenAIChoice]) -> list[OpenAIResponseOutputMessage]:
|
||||
output_messages = []
|
||||
for choice in choices:
|
||||
output_content = ""
|
||||
|
@ -101,22 +102,22 @@ class OpenAIResponsesImpl:
|
|||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: Union[str, List[OpenAIResponseInputMessage]],
|
||||
input: str | list[OpenAIResponseInputMessage],
|
||||
model: str,
|
||||
previous_response_id: Optional[str] = None,
|
||||
store: Optional[bool] = True,
|
||||
stream: Optional[bool] = False,
|
||||
temperature: Optional[float] = None,
|
||||
tools: Optional[List[OpenAIResponseInputTool]] = None,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
):
|
||||
stream = False if stream is None else stream
|
||||
|
||||
messages: List[OpenAIMessageParam] = []
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
if previous_response_id:
|
||||
previous_response = await self.get_openai_response(previous_response_id)
|
||||
messages.extend(await _previous_response_to_messages(previous_response))
|
||||
# TODO: refactor this user_content parsing out into a separate method
|
||||
user_content: Union[str, List[OpenAIChatCompletionContentPartParam]] = ""
|
||||
user_content: str | list[OpenAIChatCompletionContentPartParam] = ""
|
||||
if isinstance(input, list):
|
||||
user_content = []
|
||||
for user_input in input:
|
||||
|
@ -179,7 +180,7 @@ class OpenAIResponsesImpl:
|
|||
# dump and reload to map to our pydantic types
|
||||
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
|
||||
|
||||
output_messages: List[OpenAIResponseOutput] = []
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
if chat_response.choices[0].message.tool_calls:
|
||||
output_messages.extend(
|
||||
await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
|
||||
|
@ -215,9 +216,9 @@ class OpenAIResponsesImpl:
|
|||
return response
|
||||
|
||||
async def _convert_response_tools_to_chat_tools(
|
||||
self, tools: List[OpenAIResponseInputTool]
|
||||
) -> List[ChatCompletionToolParam]:
|
||||
chat_tools: List[ChatCompletionToolParam] = []
|
||||
self, tools: list[OpenAIResponseInputTool]
|
||||
) -> list[ChatCompletionToolParam]:
|
||||
chat_tools: list[ChatCompletionToolParam] = []
|
||||
for input_tool in tools:
|
||||
# TODO: Handle other tool types
|
||||
if input_tool.type == "web_search":
|
||||
|
@ -247,10 +248,10 @@ class OpenAIResponsesImpl:
|
|||
model_id: str,
|
||||
stream: bool,
|
||||
chat_response: OpenAIChatCompletion,
|
||||
messages: List[OpenAIMessageParam],
|
||||
messages: list[OpenAIMessageParam],
|
||||
temperature: float,
|
||||
) -> List[OpenAIResponseOutput]:
|
||||
output_messages: List[OpenAIResponseOutput] = []
|
||||
) -> list[OpenAIResponseOutput]:
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
choice = chat_response.choices[0]
|
||||
|
||||
# If the choice is not an assistant message, we don't need to execute any tools
|
||||
|
@ -314,7 +315,7 @@ class OpenAIResponsesImpl:
|
|||
async def _execute_tool_call(
|
||||
self,
|
||||
function: OpenAIChatCompletionToolCallFunction,
|
||||
) -> Optional[ToolInvocationResult]:
|
||||
) -> ToolInvocationResult | None:
|
||||
if not function.name:
|
||||
return None
|
||||
function_args = json.loads(function.arguments) if function.arguments else {}
|
||||
|
|
|
@ -8,7 +8,6 @@ import json
|
|||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -25,9 +24,9 @@ class AgentSessionInfo(BaseModel):
|
|||
session_id: str
|
||||
session_name: str
|
||||
# TODO: is this used anywhere?
|
||||
vector_db_id: Optional[str] = None
|
||||
vector_db_id: str | None = None
|
||||
started_at: datetime
|
||||
access_attributes: Optional[AccessAttributes] = None
|
||||
access_attributes: AccessAttributes | None = None
|
||||
|
||||
|
||||
class AgentPersistence:
|
||||
|
@ -55,7 +54,7 @@ class AgentPersistence:
|
|||
)
|
||||
return session_id
|
||||
|
||||
async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:
|
||||
async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
|
||||
value = await self.kvstore.get(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
)
|
||||
|
@ -78,7 +77,7 @@ class AgentPersistence:
|
|||
|
||||
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
|
||||
|
||||
async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]:
|
||||
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
||||
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
||||
session_info = await self.get_session_info(session_id)
|
||||
if not session_info:
|
||||
|
@ -106,7 +105,7 @@ class AgentPersistence:
|
|||
value=turn.model_dump_json(),
|
||||
)
|
||||
|
||||
async def get_session_turns(self, session_id: str) -> List[Turn]:
|
||||
async def get_session_turns(self, session_id: str) -> list[Turn]:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise ValueError(f"Session {session_id} not found or access denied")
|
||||
|
||||
|
@ -125,7 +124,7 @@ class AgentPersistence:
|
|||
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
||||
return turns
|
||||
|
||||
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
|
||||
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise ValueError(f"Session {session_id} not found or access denied")
|
||||
|
||||
|
@ -145,7 +144,7 @@ class AgentPersistence:
|
|||
value=step.model_dump_json(),
|
||||
)
|
||||
|
||||
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
|
||||
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> ToolExecutionStep | None:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
return None
|
||||
|
||||
|
@ -163,7 +162,7 @@ class AgentPersistence:
|
|||
value=str(num_infer_iters),
|
||||
)
|
||||
|
||||
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
|
||||
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> int | None:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
return None
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
|
@ -25,14 +24,14 @@ class ShieldRunnerMixin:
|
|||
def __init__(
|
||||
self,
|
||||
safety_api: Safety,
|
||||
input_shields: List[str] = None,
|
||||
output_shields: List[str] = None,
|
||||
input_shields: list[str] = None,
|
||||
output_shields: list[str] = None,
|
||||
):
|
||||
self.safety_api = safety_api
|
||||
self.input_shields = input_shields
|
||||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
|
||||
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
|
||||
async def run_shield_with_span(identifier: str):
|
||||
async with tracing.span(f"run_shield_{identifier}"):
|
||||
return await self.safety_api.run_shield(
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from .config import LocalFSDatasetIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: LocalFSDatasetIOConfig,
|
||||
_deps: Dict[str, Any],
|
||||
_deps: dict[str, Any],
|
||||
):
|
||||
from .datasetio import LocalFSDatasetIOImpl
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -17,7 +17,7 @@ class LocalFSDatasetIOConfig(BaseModel):
|
|||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import pandas
|
||||
|
||||
|
@ -92,8 +92,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
start_index: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
start_index: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> PaginatedResponse:
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
dataset_impl = PandasDataframeDataset(dataset_def)
|
||||
|
@ -102,7 +102,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
records = dataset_impl.df.to_dict("records")
|
||||
return paginate_records(records, start_index, limit)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
dataset_impl = PandasDataframeDataset(dataset_def)
|
||||
await dataset_impl.load()
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
|
@ -12,7 +12,7 @@ from .config import MetaReferenceEvalConfig
|
|||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceEvalConfig,
|
||||
deps: Dict[Api, Any],
|
||||
deps: dict[Api, Any],
|
||||
):
|
||||
from .eval import MetaReferenceEvalImpl
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -17,7 +17,7 @@ class MetaReferenceEvalConfig(BaseModel):
|
|||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
@ -105,8 +105,8 @@ class MetaReferenceEvalImpl(
|
|||
return Job(job_id=job_id, status=JobStatus.completed)
|
||||
|
||||
async def _run_agent_generation(
|
||||
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> List[Dict[str, Any]]:
|
||||
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> list[dict[str, Any]]:
|
||||
candidate = benchmark_config.eval_candidate
|
||||
create_response = await self.agents_api.create_agent(candidate.config)
|
||||
agent_id = create_response.agent_id
|
||||
|
@ -148,8 +148,8 @@ class MetaReferenceEvalImpl(
|
|||
return generations
|
||||
|
||||
async def _run_model_generation(
|
||||
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> List[Dict[str, Any]]:
|
||||
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> list[dict[str, Any]]:
|
||||
candidate = benchmark_config.eval_candidate
|
||||
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
|
||||
|
||||
|
@ -185,8 +185,8 @@ class MetaReferenceEvalImpl(
|
|||
async def evaluate_rows(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
input_rows: list[dict[str, Any]],
|
||||
scoring_functions: list[str],
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
candidate = benchmark_config.eval_candidate
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceInferenceConfig,
|
||||
_deps: Dict[str, Any],
|
||||
_deps: dict[str, Any],
|
||||
):
|
||||
from .inference import MetaReferenceInferenceImpl
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
@ -17,11 +17,11 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
# the actual inference model id is dtermined by the moddel id in the request
|
||||
# Note: you need to register the model before using it for inference
|
||||
# models in the resouce list in the run.yaml config will be registered automatically
|
||||
model: Optional[str] = None
|
||||
torch_seed: Optional[int] = None
|
||||
model: str | None = None
|
||||
torch_seed: int | None = None
|
||||
max_seq_len: int = 4096
|
||||
max_batch_size: int = 1
|
||||
model_parallel_size: Optional[int] = None
|
||||
model_parallel_size: int | None = None
|
||||
|
||||
# when this is False, we assume that the distributed process group is setup by someone
|
||||
# outside of this code (e.g., when run inside `torchrun`). that is useful for clients
|
||||
|
@ -30,9 +30,9 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
|
||||
# By default, the implementation will look at ~/.llama/checkpoints/<model> but you
|
||||
# can override by specifying the directory explicitly
|
||||
checkpoint_dir: Optional[str] = None
|
||||
checkpoint_dir: str | None = None
|
||||
|
||||
quantization: Optional[QuantizationConfig] = None
|
||||
quantization: QuantizationConfig | None = None
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
|
@ -55,7 +55,7 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
max_batch_size: str = "${env.MAX_BATCH_SIZE:1}",
|
||||
max_seq_len: str = "${env.MAX_SEQ_LEN:4096}",
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"model": model,
|
||||
"checkpoint_dir": checkpoint_dir,
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
@ -39,7 +40,7 @@ Tokenizer = Llama4Tokenizer | Llama3Tokenizer
|
|||
class LogitsProcessor:
|
||||
def __init__(self, token_enforcer: TokenEnforcer):
|
||||
self.token_enforcer = token_enforcer
|
||||
self.mask: Optional[torch.Tensor] = None
|
||||
self.mask: torch.Tensor | None = None
|
||||
|
||||
def __call__(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
token_sequence = tokens[0, :].tolist()
|
||||
|
@ -58,7 +59,7 @@ class LogitsProcessor:
|
|||
def get_logits_processor(
|
||||
tokenizer: Tokenizer,
|
||||
vocab_size: int,
|
||||
response_format: Optional[ResponseFormat],
|
||||
response_format: ResponseFormat | None,
|
||||
) -> Optional["LogitsProcessor"]:
|
||||
if response_format is None:
|
||||
return None
|
||||
|
@ -76,7 +77,7 @@ def get_logits_processor(
|
|||
return LogitsProcessor(token_enforcer)
|
||||
|
||||
|
||||
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> List[Tuple[int, str, bool]]:
|
||||
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> list[tuple[int, str, bool]]:
|
||||
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
|
||||
regular_tokens = []
|
||||
|
||||
|
@ -158,7 +159,7 @@ class LlamaGenerator:
|
|||
|
||||
def completion(
|
||||
self,
|
||||
request_batch: List[CompletionRequestWithRawContent],
|
||||
request_batch: list[CompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
first_request = request_batch[0]
|
||||
sampling_params = first_request.sampling_params or SamplingParams()
|
||||
|
@ -167,7 +168,7 @@ class LlamaGenerator:
|
|||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
for result in self.inner_generator.generate(
|
||||
yield from self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
|
@ -179,12 +180,11 @@ class LlamaGenerator:
|
|||
self.args.vocab_size,
|
||||
first_request.response_format,
|
||||
),
|
||||
):
|
||||
yield result
|
||||
)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
request_batch: List[ChatCompletionRequestWithRawContent],
|
||||
request_batch: list[ChatCompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
first_request = request_batch[0]
|
||||
sampling_params = first_request.sampling_params or SamplingParams()
|
||||
|
@ -193,7 +193,7 @@ class LlamaGenerator:
|
|||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
for result in self.inner_generator.generate(
|
||||
yield from self.inner_generator.generate(
|
||||
llm_inputs=[
|
||||
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
|
||||
for request in request_batch
|
||||
|
@ -208,5 +208,4 @@ class LlamaGenerator:
|
|||
self.args.vocab_size,
|
||||
first_request.response_format,
|
||||
),
|
||||
):
|
||||
yield result
|
||||
)
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
@ -184,11 +184,11 @@ class MetaReferenceInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | CompletionResponseStreamChunk:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
|
@ -215,11 +215,11 @@ class MetaReferenceInferenceImpl(
|
|||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: List[InterleavedContent],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchCompletionResponse:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
|
@ -291,14 +291,14 @@ class MetaReferenceInferenceImpl(
|
|||
for x in impl():
|
||||
yield x
|
||||
|
||||
async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
|
||||
async def _nonstream_completion(self, request_batch: list[CompletionRequest]) -> list[CompletionResponse]:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
first_request = request_batch[0]
|
||||
|
||||
class ItemState(BaseModel):
|
||||
tokens: List[int] = []
|
||||
logprobs: List[TokenLogProbs] = []
|
||||
tokens: list[int] = []
|
||||
logprobs: list[TokenLogProbs] = []
|
||||
stop_reason: StopReason | None = None
|
||||
finished: bool = False
|
||||
|
||||
|
@ -349,15 +349,15 @@ class MetaReferenceInferenceImpl(
|
|||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
|
@ -395,13 +395,13 @@ class MetaReferenceInferenceImpl(
|
|||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: List[List[Message]],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
messages_batch: list[list[Message]],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> BatchChatCompletionResponse:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
|
@ -436,15 +436,15 @@ class MetaReferenceInferenceImpl(
|
|||
return BatchChatCompletionResponse(batch=results)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request_batch: List[ChatCompletionRequest]
|
||||
) -> List[ChatCompletionResponse]:
|
||||
self, request_batch: list[ChatCompletionRequest]
|
||||
) -> list[ChatCompletionResponse]:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
first_request = request_batch[0]
|
||||
|
||||
class ItemState(BaseModel):
|
||||
tokens: List[int] = []
|
||||
logprobs: List[TokenLogProbs] = []
|
||||
tokens: list[int] = []
|
||||
logprobs: list[TokenLogProbs] = []
|
||||
stop_reason: StopReason | None = None
|
||||
finished: bool = False
|
||||
|
||||
|
|
|
@ -4,9 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Callable, Generator
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Generator, List
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
||||
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
||||
|
@ -82,7 +83,7 @@ class LlamaModelParallelGenerator:
|
|||
|
||||
def completion(
|
||||
self,
|
||||
request_batch: List[CompletionRequestWithRawContent],
|
||||
request_batch: list[CompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
req_obj = deepcopy(request_batch)
|
||||
gen = self.group.run_inference(("completion", req_obj))
|
||||
|
@ -90,7 +91,7 @@ class LlamaModelParallelGenerator:
|
|||
|
||||
def chat_completion(
|
||||
self,
|
||||
request_batch: List[ChatCompletionRequestWithRawContent],
|
||||
request_batch: list[ChatCompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
req_obj = deepcopy(request_batch)
|
||||
gen = self.group.run_inference(("chat_completion", req_obj))
|
||||
|
|
|
@ -18,8 +18,9 @@ import os
|
|||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable, Generator
|
||||
from enum import Enum
|
||||
from typing import Callable, Generator, List, Literal, Optional, Tuple, Union
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
|
@ -30,7 +31,6 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.models.llama.datatypes import GenerationResult
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
|
@ -69,15 +69,15 @@ class CancelSentinel(BaseModel):
|
|||
|
||||
class TaskRequest(BaseModel):
|
||||
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
|
||||
task: Tuple[
|
||||
task: tuple[
|
||||
str,
|
||||
List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
|
||||
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
|
||||
]
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
|
||||
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
|
||||
result: List[GenerationResult]
|
||||
result: list[GenerationResult]
|
||||
|
||||
|
||||
class ExceptionResponse(BaseModel):
|
||||
|
@ -85,15 +85,9 @@ class ExceptionResponse(BaseModel):
|
|||
error: str
|
||||
|
||||
|
||||
ProcessingMessage = Union[
|
||||
ReadyRequest,
|
||||
ReadyResponse,
|
||||
EndSentinel,
|
||||
CancelSentinel,
|
||||
TaskRequest,
|
||||
TaskResponse,
|
||||
ExceptionResponse,
|
||||
]
|
||||
ProcessingMessage = (
|
||||
ReadyRequest | ReadyResponse | EndSentinel | CancelSentinel | TaskRequest | TaskResponse | ExceptionResponse
|
||||
)
|
||||
|
||||
|
||||
class ProcessingMessageWrapper(BaseModel):
|
||||
|
@ -203,7 +197,7 @@ def maybe_get_work(sock: zmq.Socket):
|
|||
return client_id, message
|
||||
|
||||
|
||||
def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
|
||||
def maybe_parse_message(maybe_json: str | None) -> ProcessingMessage | None:
|
||||
if maybe_json is None:
|
||||
return None
|
||||
try:
|
||||
|
@ -334,9 +328,9 @@ class ModelParallelProcessGroup:
|
|||
|
||||
def run_inference(
|
||||
self,
|
||||
req: Tuple[
|
||||
req: tuple[
|
||||
str,
|
||||
List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
|
||||
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
|
||||
],
|
||||
) -> Generator:
|
||||
assert not self.running, "inference already running"
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.providers.inline.inference.sentence_transformers.config import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
|
@ -13,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import
|
|||
|
||||
async def get_provider_impl(
|
||||
config: SentenceTransformersInferenceConfig,
|
||||
_deps: Dict[str, Any],
|
||||
_deps: dict[str, Any],
|
||||
):
|
||||
from .sentence_transformers import SentenceTransformersInferenceImpl
|
||||
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SentenceTransformersInferenceConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionResponse,
|
||||
|
@ -60,46 +60,46 @@ class SentenceTransformersInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
content: str,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, AsyncGenerator]:
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | AsyncGenerator:
|
||||
raise ValueError("Sentence transformers don't support completion")
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
raise ValueError("Sentence transformers don't support chat completion")
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: List[InterleavedContent],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
|
||||
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: List[List[Message]],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
messages_batch: list[list[Message]],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from .config import VLLMConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]):
|
||||
async def get_provider_impl(config: VLLMConfig, _deps: dict[str, Any]):
|
||||
from .vllm import VLLMInferenceImpl
|
||||
|
||||
impl = VLLMInferenceImpl(config)
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
@ -42,7 +42,7 @@ class VLLMConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
|
||||
"max_tokens": "${env.MAX_TOKENS:4096}",
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import vllm
|
||||
|
||||
|
@ -55,8 +54,8 @@ def _merge_context_into_content(message: Message) -> Message: # type: ignore
|
|||
|
||||
|
||||
def _llama_stack_tools_to_openai_tools(
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
) -> list[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
|
||||
"""
|
||||
Convert the list of available tools from Llama Stack's format to vLLM's
|
||||
version of OpenAI's format.
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
|
||||
# These vLLM modules contain names that overlap with Llama Stack names, so we import
|
||||
# fully-qualified names
|
||||
|
@ -100,7 +100,7 @@ def _random_uuid_str() -> str:
|
|||
|
||||
|
||||
def _response_format_to_guided_decoding_params(
|
||||
response_format: Optional[ResponseFormat], # type: ignore
|
||||
response_format: ResponseFormat | None, # type: ignore
|
||||
) -> vllm.sampling_params.GuidedDecodingParams:
|
||||
"""
|
||||
Translate constrained decoding parameters from Llama Stack's format to vLLM's format.
|
||||
|
@ -131,9 +131,9 @@ def _response_format_to_guided_decoding_params(
|
|||
|
||||
|
||||
def _convert_sampling_params(
|
||||
sampling_params: Optional[SamplingParams],
|
||||
response_format: Optional[ResponseFormat], # type: ignore
|
||||
log_prob_config: Optional[LogProbConfig],
|
||||
sampling_params: SamplingParams | None,
|
||||
response_format: ResponseFormat | None, # type: ignore
|
||||
log_prob_config: LogProbConfig | None,
|
||||
) -> vllm.SamplingParams:
|
||||
"""Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
|
||||
format."""
|
||||
|
@ -370,11 +370,11 @@ class VLLMInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
|
||||
|
@ -403,25 +403,25 @@ class VLLMInferenceImpl(
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message], # type: ignore
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None, # type: ignore
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
messages: list[Message], # type: ignore
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None, # type: ignore
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
|
||||
sampling_params = sampling_params or SamplingParams()
|
||||
if model_id not in self.model_ids:
|
||||
|
@ -605,7 +605,7 @@ class VLLMInferenceImpl(
|
|||
|
||||
async def _chat_completion_for_meta_llama(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
"""
|
||||
Subroutine that routes chat completions for Meta Llama models through Llama Stack's
|
||||
chat template instead of using vLLM's version of that template. The Llama Stack version
|
||||
|
@ -701,7 +701,7 @@ class VLLMInferenceImpl(
|
|||
# Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
|
||||
# those chunks and output them at the end.
|
||||
# This data structure holds the current set of partial tool calls.
|
||||
index_to_tool_call: Dict[int, Dict] = dict()
|
||||
index_to_tool_call: dict[int, dict] = dict()
|
||||
|
||||
# The Llama Stack event stream must always start with a start event. Use an empty one to
|
||||
# simplify logic below
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
|
@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
|
|||
|
||||
async def get_provider_impl(
|
||||
config: TorchtunePostTrainingConfig,
|
||||
deps: Dict[Api, Any],
|
||||
deps: dict[Api, Any],
|
||||
):
|
||||
from .post_training import TorchtunePostTrainingImpl
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import json
|
|||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
@ -34,7 +34,7 @@ class TorchtuneCheckpointer:
|
|||
model_id: str,
|
||||
training_algorithm: str,
|
||||
checkpoint_dir: str,
|
||||
checkpoint_files: List[str],
|
||||
checkpoint_files: list[str],
|
||||
output_dir: str,
|
||||
model_type: str,
|
||||
):
|
||||
|
@ -54,11 +54,11 @@ class TorchtuneCheckpointer:
|
|||
# get ckpt paths
|
||||
self._checkpoint_path = Path.joinpath(self._checkpoint_dir, self._checkpoint_file)
|
||||
|
||||
def load_checkpoint(self) -> Dict[str, Any]:
|
||||
def load_checkpoint(self) -> dict[str, Any]:
|
||||
"""
|
||||
Load Meta checkpoint from file. Currently only loading from a single file is supported.
|
||||
"""
|
||||
state_dict: Dict[str, Any] = {}
|
||||
state_dict: dict[str, Any] = {}
|
||||
model_state_dict = safe_torch_load(self._checkpoint_path)
|
||||
if self._model_type == ModelType.LLAMA3_VISION:
|
||||
from torchtune.models.llama3_2_vision._convert_weights import (
|
||||
|
@ -82,7 +82,7 @@ class TorchtuneCheckpointer:
|
|||
|
||||
def save_checkpoint(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
epoch: int,
|
||||
adapter_only: bool = False,
|
||||
checkpoint_format: str | None = None,
|
||||
|
@ -100,7 +100,7 @@ class TorchtuneCheckpointer:
|
|||
def _save_meta_format_checkpoint(
|
||||
self,
|
||||
model_file_path: Path,
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
adapter_only: bool = False,
|
||||
) -> None:
|
||||
model_file_path.mkdir(parents=True, exist_ok=True)
|
||||
|
@ -168,7 +168,7 @@ class TorchtuneCheckpointer:
|
|||
def _save_hf_format_checkpoint(
|
||||
self,
|
||||
model_file_path: Path,
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
) -> None:
|
||||
# the config.json file contains model params needed for state dict conversion
|
||||
config = json.loads(Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text())
|
||||
|
@ -179,7 +179,7 @@ class TorchtuneCheckpointer:
|
|||
repo_id_path = Path.joinpath(self._checkpoint_dir.parent, REPO_ID_FNAME).with_suffix(".json")
|
||||
self.repo_id = None
|
||||
if repo_id_path.exists():
|
||||
with open(repo_id_path, "r") as json_file:
|
||||
with open(repo_id_path) as json_file:
|
||||
data = json.load(json_file)
|
||||
self.repo_id = data.get("repo_id")
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Callable, Dict
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
|
@ -35,7 +35,7 @@ class ModelConfig(BaseModel):
|
|||
checkpoint_type: str
|
||||
|
||||
|
||||
MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
||||
MODEL_CONFIGS: dict[str, ModelConfig] = {
|
||||
"Llama3.2-3B-Instruct": ModelConfig(
|
||||
model_definition=lora_llama3_2_3b,
|
||||
tokenizer_type=llama3_tokenizer,
|
||||
|
@ -48,7 +48,7 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
|||
),
|
||||
}
|
||||
|
||||
DATA_FORMATS: Dict[str, Transform] = {
|
||||
DATA_FORMATS: dict[str, Transform] = {
|
||||
"instruct": InputOutputToMessages,
|
||||
"dialog": ShareGPTToMessages,
|
||||
}
|
||||
|
|
|
@ -4,17 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TorchtunePostTrainingConfig(BaseModel):
|
||||
torch_seed: Optional[int] = None
|
||||
checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
|
||||
torch_seed: int | None = None
|
||||
checkpoint_format: Literal["meta", "huggingface"] | None = "meta"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"checkpoint_format": "meta",
|
||||
}
|
||||
|
|
|
@ -11,7 +11,8 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Any, Mapping
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
|
||||
|
|
|
@@ -10,7 +10,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Dict, List, Mapping
+from collections.abc import Mapping
+from typing import Any

 import numpy as np
 from torch.utils.data import Dataset
@@ -27,7 +28,7 @@ from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapte
 class SFTDataset(Dataset):
     def __init__(
         self,
-        rows: List[Dict[str, Any]],
+        rows: list[dict[str, Any]],
         message_transform: Transform,
         model_transform: Transform,
         dataset_type: str,
@@ -40,11 +41,11 @@ class SFTDataset(Dataset):
     def __len__(self):
         return len(self._rows)

-    def __getitem__(self, index: int) -> Dict[str, Any]:
+    def __getitem__(self, index: int) -> dict[str, Any]:
         sample = self._rows[index]
         return self._prepare_sample(sample)

-    def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
+    def _prepare_sample(self, sample: Mapping[str, Any]) -> dict[str, Any]:
         if self._dataset_type == "instruct":
             sample = llama_stack_instruct_to_torchtune_instruct(sample)
         elif self._dataset_type == "dialog":

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
@@ -64,7 +64,7 @@ class TorchtunePostTrainingImpl:
         )

     @staticmethod
-    def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
+    def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
         return JobArtifact(
             type=TrainingArtifactType.RESOURCES_STATS.value,
             name=TrainingArtifactType.RESOURCES_STATS.value,
@@ -75,11 +75,11 @@ class TorchtunePostTrainingImpl:
         self,
         job_uuid: str,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
         model: str,
-        checkpoint_dir: Optional[str],
-        algorithm_config: Optional[AlgorithmConfig],
+        checkpoint_dir: str | None,
+        algorithm_config: AlgorithmConfig | None,
     ) -> PostTrainingJob:
         if isinstance(algorithm_config, LoraFinetuningConfig):

@@ -121,8 +121,8 @@ class TorchtunePostTrainingImpl:
         finetuned_model: str,
         algorithm_config: DPOAlignmentConfig,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
     ) -> PostTrainingJob: ...

     async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
@@ -144,7 +144,7 @@ class TorchtunePostTrainingImpl:
         return data[0] if data else None

     @webmethod(route="/post-training/job/status")
-    async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
+    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
         job = self._scheduler.get_job(job_uuid)

         match job.status:
@@ -175,6 +175,6 @@ class TorchtunePostTrainingImpl:
         self._scheduler.cancel(job_uuid)

     @webmethod(route="/post-training/job/artifacts")
-    async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
+    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
         job = self._scheduler.get_job(job_uuid)
         return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))

@@ -11,7 +11,7 @@ import time
 from datetime import datetime, timezone
 from functools import partial
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any

 import torch
 from torch import nn
@@ -80,10 +80,10 @@ class LoraFinetuningSingleDevice:
         config: TorchtunePostTrainingConfig,
         job_uuid: str,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
         model: str,
-        checkpoint_dir: Optional[str],
+        checkpoint_dir: str | None,
         algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
@@ -156,7 +156,7 @@ class LoraFinetuningSingleDevice:
         self.datasets_api = datasets_api

     async def load_checkpoint(self):
-        def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
+        def get_checkpoint_files(checkpoint_dir: str) -> list[str]:
             try:
                 # List all files in the given directory
                 files = os.listdir(checkpoint_dir)
@@ -250,8 +250,8 @@ class LoraFinetuningSingleDevice:
         self,
         enable_activation_checkpointing: bool,
         enable_activation_offloading: bool,
-        base_model_state_dict: Dict[str, Any],
-        lora_weights_state_dict: Optional[Dict[str, Any]] = None,
+        base_model_state_dict: dict[str, Any],
+        lora_weights_state_dict: dict[str, Any] | None = None,
     ) -> nn.Module:
         self._lora_rank = self.algorithm_config.rank
         self._lora_alpha = self.algorithm_config.alpha
@@ -335,7 +335,7 @@ class LoraFinetuningSingleDevice:
         tokenizer: Llama3Tokenizer,
         shuffle: bool,
         batch_size: int,
-    ) -> Tuple[DistributedSampler, DataLoader]:
+    ) -> tuple[DistributedSampler, DataLoader]:
         async def fetch_rows(dataset_id: str):
             return await self.datasetio_api.iterrows(
                 dataset_id=dataset_id,
@@ -430,7 +430,7 @@ class LoraFinetuningSingleDevice:
             checkpoint_format=self._checkpoint_format,
         )

-    async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+    async def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
         # Shape [b, s], needed for the loss not the model
         labels = batch.pop("labels")
         # run model
@@ -452,7 +452,7 @@ class LoraFinetuningSingleDevice:

         return loss

-    async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
+    async def train(self) -> tuple[dict[str, Any], list[Checkpoint]]:
         """
         The core training loop.
         """
@@ -464,7 +464,7 @@ class LoraFinetuningSingleDevice:

         # training artifacts
         checkpoints = []
-        memory_stats: Dict[str, Any] = {}
+        memory_stats: dict[str, Any] = {}

         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
@@ -565,7 +565,7 @@ class LoraFinetuningSingleDevice:

         return (memory_stats, checkpoints)

-    async def validation(self) -> Tuple[float, float]:
+    async def validation(self) -> tuple[float, float]:
         total_loss = 0.0
         total_tokens = 0
         log.info("Starting validation...")

@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from .config import CodeScannerConfig


-async def get_provider_impl(config: CodeScannerConfig, deps: Dict[str, Any]):
+async def get_provider_impl(config: CodeScannerConfig, deps: dict[str, Any]):
     from .code_scanner import MetaReferenceCodeScannerSafetyImpl

     impl = MetaReferenceCodeScannerSafetyImpl(config, deps)

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import logging
-from typing import Any, Dict, List
+from typing import Any

 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import (
@@ -48,8 +48,8 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
     async def run_shield(
         self,
         shield_id: str,
-        messages: List[Message],
-        params: Dict[str, Any] = None,
+        messages: list[Message],
+        params: dict[str, Any] = None,
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:

@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from pydantic import BaseModel


 class CodeScannerConfig(BaseModel):
     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {}

@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from .config import LlamaGuardConfig


-async def get_provider_impl(config: LlamaGuardConfig, deps: Dict[str, Any]):
+async def get_provider_impl(config: LlamaGuardConfig, deps: dict[str, Any]):
     from .llama_guard import LlamaGuardSafetyImpl

     assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"

@@ -4,16 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List
+from typing import Any

 from pydantic import BaseModel


 class LlamaGuardConfig(BaseModel):
-    excluded_categories: List[str] = []
+    excluded_categories: list[str] = []

     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {
             "excluded_categories": [],
         }

@@ -6,7 +6,7 @@

 import re
 from string import Template
-from typing import Any, Dict, List, Optional
+from typing import Any

 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
 from llama_stack.apis.inference import (
@@ -149,8 +149,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
     async def run_shield(
         self,
         shield_id: str,
-        messages: List[Message],
-        params: Dict[str, Any] = None,
+        messages: list[Message],
+        params: dict[str, Any] = None,
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
@@ -177,7 +177,7 @@ class LlamaGuardShield:
         self,
         model: str,
         inference_api: Inference,
-        excluded_categories: Optional[List[str]] = None,
+        excluded_categories: list[str] | None = None,
     ):
         if excluded_categories is None:
             excluded_categories = []
@@ -193,7 +193,7 @@ class LlamaGuardShield:
         self.inference_api = inference_api
         self.excluded_categories = excluded_categories

-    def check_unsafe_response(self, response: str) -> Optional[str]:
+    def check_unsafe_response(self, response: str) -> str | None:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
             # extracts the unsafe code
@@ -202,7 +202,7 @@ class LlamaGuardShield:

         return None

-    def get_safety_categories(self) -> List[str]:
+    def get_safety_categories(self) -> list[str]:
         excluded_categories = self.excluded_categories
         if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
             excluded_categories = []
@@ -218,7 +218,7 @@ class LlamaGuardShield:

         return final_categories

-    def validate_messages(self, messages: List[Message]) -> None:
+    def validate_messages(self, messages: list[Message]) -> None:
         if len(messages) == 0:
             raise ValueError("Messages must not be empty")
         if messages[0].role != Role.user.value:
@@ -229,7 +229,7 @@ class LlamaGuardShield:

         return messages

-    async def run(self, messages: List[Message]) -> RunShieldResponse:
+    async def run(self, messages: list[Message]) -> RunShieldResponse:
         messages = self.validate_messages(messages)

         if self.model == CoreModelId.llama_guard_3_11b_vision.value:
@@ -247,10 +247,10 @@ class LlamaGuardShield:
         content = content.strip()
         return self.get_shield_response(content)

-    def build_text_shield_input(self, messages: List[Message]) -> UserMessage:
+    def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
         return UserMessage(content=self.build_prompt(messages))

-    def build_vision_shield_input(self, messages: List[Message]) -> UserMessage:
+    def build_vision_shield_input(self, messages: list[Message]) -> UserMessage:
         conversation = []
         most_recent_img = None

@@ -284,7 +284,7 @@ class LlamaGuardShield:

         return UserMessage(content=prompt)

-    def build_prompt(self, messages: List[Message]) -> str:
+    def build_prompt(self, messages: list[Message]) -> str:
         categories = self.get_safety_categories()
         categories_str = "\n".join(categories)
         conversations_str = "\n\n".join(

@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

-from .config import PromptGuardConfig  # noqa: F401
+from .config import PromptGuardConfig


-async def get_provider_impl(config: PromptGuardConfig, deps: Dict[str, Any]):
+async def get_provider_impl(config: PromptGuardConfig, deps: dict[str, Any]):
     from .prompt_guard import PromptGuardSafetyImpl

     impl = PromptGuardSafetyImpl(config, deps)

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict
+from typing import Any

 from pydantic import BaseModel, field_validator

@@ -26,7 +26,7 @@ class PromptGuardConfig(BaseModel):
         return v

     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {
             "guard_type": "injection",
         }

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import logging
-from typing import Any, Dict, List
+from typing import Any

 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
@@ -49,8 +49,8 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
     async def run_shield(
         self,
         shield_id: str,
-        messages: List[Message],
-        params: Dict[str, Any] = None,
+        messages: list[Message],
+        params: dict[str, Any] = None,
     ) -> RunShieldResponse:
         shield = await self.shield_store.get_shield(shield_id)
         if not shield:
@@ -81,7 +81,7 @@ class PromptGuardShield:
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
         self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)

-    async def run(self, messages: List[Message]) -> RunShieldResponse:
+    async def run(self, messages: list[Message]) -> RunShieldResponse:
         message = messages[-1]
         text = interleaved_content_as_str(message.content)

@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict
+from typing import Any

 from llama_stack.distribution.datatypes import Api

@@ -12,7 +12,7 @@ from .config import BasicScoringConfig

 async def get_provider_impl(
     config: BasicScoringConfig,
-    deps: Dict[Api, Any],
+    deps: dict[Api, Any],
 ):
     from .scoring import BasicScoringImpl

@@ -3,12 +3,12 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict
+from typing import Any

 from pydantic import BaseModel


 class BasicScoringConfig(BaseModel):
     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {}

@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, List, Optional
+from typing import Any

 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
@@ -66,7 +66,7 @@ class BasicScoringImpl(

     async def shutdown(self) -> None: ...

-    async def list_scoring_functions(self) -> List[ScoringFn]:
+    async def list_scoring_functions(self) -> list[ScoringFn]:
         scoring_fn_defs_list = [
             fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
         ]
@@ -82,7 +82,7 @@ class BasicScoringImpl(
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
+        scoring_functions: dict[str, ScoringFnParams | None] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
@@ -107,8 +107,8 @@ class BasicScoringImpl(

     async def score(
         self,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
+        input_rows: list[dict[str, Any]],
+        scoring_functions: dict[str, ScoringFnParams | None] = None,
     ) -> ScoreResponse:
         res = {}
         for scoring_fn_id in scoring_functions.keys():

@@ -6,7 +6,7 @@

 import json
 import re
-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
@@ -17,7 +17,7 @@ from ..utils.bfcl.checker import ast_checker, is_empty_output
 from .fn_defs.bfcl import bfcl


-def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
+def postprocess(x: dict[str, Any], test_category: str) -> dict[str, Any]:
     contain_func_call = False
     error = None
     error_type = None
@@ -52,11 +52,11 @@ def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
     }


-def gen_valid(x: Dict[str, Any]) -> Dict[str, float]:
+def gen_valid(x: dict[str, Any]) -> dict[str, float]:
     return {"valid": x["valid"]}


-def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]:
+def gen_relevance_acc(x: dict[str, Any]) -> dict[str, float]:
     # This function serves for both relevance and irrelevance tests, which share the exact opposite logic.
     # If `test_category` is "irrelevance", the model is expected to output no function call.
     # No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`).
@@ -78,9 +78,9 @@ class BFCLScoringFn(RegisteredBaseScoringFn):

     async def score_row(
         self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = "bfcl",
-        scoring_params: Optional[ScoringFnParams] = None,
+        input_row: dict[str, Any],
+        scoring_fn_identifier: str | None = "bfcl",
+        scoring_params: ScoringFnParams | None = None,
     ) -> ScoringResultRow:
         test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
         score_result = postprocess(input_row, test_category)

@@ -6,7 +6,7 @@

 import json
 import re
-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
@@ -228,9 +228,9 @@ class DocVQAScoringFn(RegisteredBaseScoringFn):

     async def score_row(
         self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = "docvqa",
-        scoring_params: Optional[ScoringFnParams] = None,
+        input_row: dict[str, Any],
+        scoring_fn_identifier: str | None = "docvqa",
+        scoring_params: ScoringFnParams | None = None,
     ) -> ScoringResultRow:
         expected_answers = json.loads(input_row["expected_answer"])
         generated_answer = input_row["generated_answer"]

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
@@ -26,9 +26,9 @@ class EqualityScoringFn(RegisteredBaseScoringFn):

     async def score_row(
         self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = "equality",
-        scoring_params: Optional[ScoringFnParams] = None,
+        input_row: dict[str, Any],
+        scoring_fn_identifier: str | None = "equality",
+        scoring_params: ScoringFnParams | None = None,
     ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert "generated_answer" in input_row, "Generated answer not found in input row."

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
@@ -28,9 +28,9 @@ class IfEvalScoringFn(RegisteredBaseScoringFn):

     async def score_row(
         self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = None,
-        scoring_params: Optional[ScoringFnParams] = None,
+        input_row: dict[str, Any],
+        scoring_fn_identifier: str | None = None,
+        scoring_params: ScoringFnParams | None = None,
     ) -> ScoringResultRow:
         from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST

@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
@@ -28,9 +28,9 @@ class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):

     async def score_row(
         self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = None,
-        scoring_params: Optional[ScoringFnParams] = None,
+        input_row: dict[str, Any],
+        scoring_fn_identifier: str | None = None,
+        scoring_params: ScoringFnParams | None = None,
     ) -> ScoringResultRow:
         assert scoring_fn_identifier is not None, "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import re
-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
@@ -28,9 +28,9 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):

     async def score_row(
         self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = None,
-        scoring_params: Optional[ScoringFnParams] = None,
+        input_row: dict[str, Any],
+        scoring_fn_identifier: str | None = None,
+        scoring_params: ScoringFnParams | None = None,
     ) -> ScoringResultRow:
         assert scoring_fn_identifier is not None, "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
@@ -26,9 +26,9 @@ class SubsetOfScoringFn(RegisteredBaseScoringFn):

     async def score_row(
         self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = "subset_of",
-        scoring_params: Optional[ScoringFnParams] = None,
+        input_row: dict[str, Any],
+        scoring_fn_identifier: str | None = "subset_of",
+        scoring_params: ScoringFnParams | None = None,
     ) -> ScoringResultRow:
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]

@@ -11,8 +11,8 @@ import logging
 import random
 import re
 import string
+from collections.abc import Iterable, Sequence
 from types import MappingProxyType
-from typing import Dict, Iterable, List, Optional, Sequence, Union

 import emoji
 import langdetect
@@ -1673,12 +1673,11 @@ def split_chinese_japanese_hindi(lines: str) -> Iterable[str]:
     The separator for hindi is '।'
     """
     for line in lines.splitlines():
-        for sent in re.findall(
+        yield from re.findall(
             r"[^!?。\.\!\?\!\?\.\n।]+[!?。\.\!\?\!\?\.\n।]?",
             line.strip(),
             flags=re.U,
-        ):
-            yield sent
+        )


 def count_words_cjk(text: str) -> int:
@@ -1707,7 +1706,7 @@ def count_words_cjk(text: str) -> int:
     return non_asian_words_cnt + asian_chars_cnt + emoji_cnt


-@functools.lru_cache(maxsize=None)
+@functools.cache
 def _get_sentence_tokenizer():
     return nltk.data.load("nltk:tokenizers/punkt/english.pickle")

@@ -1719,8 +1718,8 @@ def count_sentences(text):
     return len(tokenized_sentences)


-def get_langid(text: str, lid_path: Optional[str] = None) -> str:
-    line_langs: List[str] = []
+def get_langid(text: str, lid_path: str | None = None) -> str:
+    line_langs: list[str] = []
     lines = [line.strip() for line in text.split("\n") if len(line.strip()) >= 4]

     for line in lines:
@@ -1741,7 +1740,7 @@ def generate_keywords(num_keywords):


 """Library of instructions"""
-_InstructionArgsDtype = Optional[Dict[str, Union[int, str, Sequence[str]]]]
+_InstructionArgsDtype = dict[str, int | str | Sequence[str]] | None

 _LANGUAGES = LANGUAGE_CODES

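Two of the non-typing rewrites above are the `yield from` refactor in `split_chinese_japanese_hindi()` and the `@functools.cache` decorator replacing `lru_cache(maxsize=None)`. A minimal, self-contained sketch of both idioms (the helper names below are hypothetical, not from this file):

```python
import functools
from collections.abc import Iterable


@functools.cache  # same behavior as functools.lru_cache(maxsize=None), Python >= 3.9
def squares_upto(n: int) -> tuple[int, ...]:
    return tuple(i * i for i in range(n))


def flatten(rows: list[list[int]]) -> Iterable[int]:
    for row in rows:
        # "yield from row" replaces "for item in row: yield item"
        yield from row


assert list(flatten([[1, 2], [3]])) == [1, 2, 3]
assert squares_upto(4) == (0, 1, 4, 9)
```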
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import re
-from typing import Sequence
+from collections.abc import Sequence

 from llama_stack.providers.utils.scoring.basic_scoring_utils import time_limit

@@ -323,7 +323,7 @@ def _fix_a_slash_b(string: str) -> str:
     try:
         ia = int(a)
         ib = int(b)
-        assert string == "{}/{}".format(ia, ib)
+        assert string == f"{ia}/{ib}"
         new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}"
         return new_string
     except (ValueError, AssertionError):

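The second hunk above swaps `str.format()` for an f-string; the two spellings produce identical output, which a one-line check (illustrative only) makes explicit:

```python
ia, ib = 3, 4
assert "{}/{}".format(ia, ib) == f"{ia}/{ib}" == "3/4"
```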
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict
+from typing import Any

 from pydantic import BaseModel

@@ -18,7 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):

 async def get_provider_impl(
     config: BraintrustScoringConfig,
-    deps: Dict[Api, Any],
+    deps: dict[Api, Any],
 ):
     from .braintrust import BraintrustScoringImpl

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import os
-from typing import Any, Dict, List, Optional
+from typing import Any

 from autoevals.llm import Factuality
 from autoevals.ragas import (
@@ -132,7 +132,7 @@ class BraintrustScoringImpl(

     async def shutdown(self) -> None: ...

-    async def list_scoring_functions(self) -> List[ScoringFn]:
+    async def list_scoring_functions(self) -> list[ScoringFn]:
         scoring_fn_defs_list = list(self.supported_fn_defs_registry.values())
         for f in scoring_fn_defs_list:
             assert f.identifier.startswith("braintrust"), (
@@ -159,7 +159,7 @@ class BraintrustScoringImpl(
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: Dict[str, Optional[ScoringFnParams]],
+        scoring_functions: dict[str, ScoringFnParams | None],
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         await self.set_api_key()
@@ -181,9 +181,7 @@ class BraintrustScoringImpl(
             results=res.results,
         )

-    async def score_row(
-        self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
-    ) -> ScoringResultRow:
+    async def score_row(self, input_row: dict[str, Any], scoring_fn_identifier: str | None = None) -> ScoringResultRow:
         validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
         await self.set_api_key()
         assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
@@ -203,8 +201,8 @@ class BraintrustScoringImpl(

     async def score(
         self,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: Dict[str, Optional[ScoringFnParams]],
+        input_rows: list[dict[str, Any]],
+        scoring_functions: dict[str, ScoringFnParams | None],
     ) -> ScoreResponse:
         await self.set_api_key()
         res = {}

@@ -3,19 +3,19 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import BaseModel, Field


 class BraintrustScoringConfig(BaseModel):
-    openai_api_key: Optional[str] = Field(
+    openai_api_key: str | None = Field(
         default=None,
         description="The OpenAI API Key",
     )

     @classmethod
-    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+    def sample_run_config(cls, **kwargs) -> dict[str, Any]:
         return {
             "openai_api_key": "${env.OPENAI_API_KEY:}",
         }

@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict
+from typing import Any

 from llama_stack.distribution.datatypes import Api

@@ -12,7 +12,7 @@ from .config import LlmAsJudgeScoringConfig

 async def get_provider_impl(
     config: LlmAsJudgeScoringConfig,
-    deps: Dict[Api, Any],
+    deps: dict[Api, Any],
 ):
     from .scoring import LlmAsJudgeScoringImpl

@@ -3,12 +3,12 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict
+from typing import Any

 from pydantic import BaseModel


 class LlmAsJudgeScoringConfig(BaseModel):
     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {}

@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, List, Optional
+from typing import Any

 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
@@ -50,7 +50,7 @@ class LlmAsJudgeScoringImpl(

     async def shutdown(self) -> None: ...

-    async def list_scoring_functions(self) -> List[ScoringFn]:
+    async def list_scoring_functions(self) -> list[ScoringFn]:
         scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs()

         for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs():
@@ -66,7 +66,7 @@ class LlmAsJudgeScoringImpl(
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
+        scoring_functions: dict[str, ScoringFnParams | None] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
@@ -91,8 +91,8 @@ class LlmAsJudgeScoringImpl(

     async def score(
         self,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
+        input_rows: list[dict[str, Any]],
+        scoring_functions: dict[str, ScoringFnParams | None] = None,
     ) -> ScoreResponse:
         res = {}
         for scoring_fn_id in scoring_functions.keys():

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import re
-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.apis.inference.inference import Inference, UserMessage
 from llama_stack.apis.scoring import ScoringResultRow
@@ -30,9 +30,9 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):

     async def score_row(
         self,
-        input_row: Dict[str, Any],
-        scoring_fn_identifier: Optional[str] = None,
-        scoring_params: Optional[ScoringFnParams] = None,
+        input_row: dict[str, Any],
+        scoring_fn_identifier: str | None = None,
+        scoring_params: ScoringFnParams | None = None,
     ) -> ScoringResultRow:
         assert scoring_fn_identifier is not None, "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from llama_stack.distribution.datatypes import Api

@@ -13,7 +13,7 @@ from .config import TelemetryConfig, TelemetrySink
 __all__ = ["TelemetryConfig", "TelemetrySink"]


-async def get_provider_impl(config: TelemetryConfig, deps: Dict[Api, Any]):
+async def get_provider_impl(config: TelemetryConfig, deps: dict[Api, Any]):
     from .telemetry import TelemetryAdapter

     impl = TelemetryAdapter(config, deps)

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, List
+from typing import Any

 from pydantic import BaseModel, Field, field_validator

@@ -33,7 +33,7 @@ class TelemetryConfig(BaseModel):
         default="",
         description="The service name to use for telemetry",
     )
-    sinks: List[TelemetrySink] = Field(
+    sinks: list[TelemetrySink] = Field(
         default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
         description="List of telemetry sinks to enable (possible values: otel, sqlite, console)",
     )
@@ -50,7 +50,7 @@ class TelemetryConfig(BaseModel):
         return v

     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> dict[str, Any]:
         return {
             "service_name": "${env.OTEL_SERVICE_NAME:}",
             "sinks": "${env.TELEMETRY_SINKS:console,sqlite}",

@@ -78,7 +78,7 @@ class ConsoleSpanProcessor(SpanProcessor):

         severity = event.attributes.get("severity", "info")
         message = event.attributes.get("message", event.name)
-        if isinstance(message, (dict, list)):
+        if isinstance(message, dict | list):
             message = json.dumps(message, indent=2)

         severity_colors = {

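The hunk above relies on the fact that, from Python 3.10 on, `isinstance()` accepts `X | Y` unions directly, so `dict | list` is interchangeable with the `(dict, list)` tuple. A small standalone illustration (the sample `message` value is made up):

```python
import json

message = {"event": "span_end", "ok": True}

# isinstance() with a union type behaves like the tuple form on Python >= 3.10
if isinstance(message, dict | list):
    message = json.dumps(message, indent=2)

print(message)
```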
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import threading
-from typing import Any, Dict, List, Optional
+from typing import Any

 from opentelemetry import metrics, trace
 from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@@ -60,7 +60,7 @@ def is_tracing_enabled(tracer):


 class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
-    def __init__(self, config: TelemetryConfig, deps: Dict[Api, Any]) -> None:
+    def __init__(self, config: TelemetryConfig, deps: dict[Api, Any]) -> None:
         self.config = config
         self.datasetio_api = deps.get(Api.datasetio)
         self.meter = None
@@ -231,10 +231,10 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):

     async def query_traces(
         self,
-        attribute_filters: Optional[List[QueryCondition]] = None,
-        limit: Optional[int] = 100,
-        offset: Optional[int] = 0,
-        order_by: Optional[List[str]] = None,
+        attribute_filters: list[QueryCondition] | None = None,
+        limit: int | None = 100,
+        offset: int | None = 0,
+        order_by: list[str] | None = None,
     ) -> QueryTracesResponse:
         return QueryTracesResponse(
             data=await self.trace_store.query_traces(
@@ -254,8 +254,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
     async def get_span_tree(
         self,
         span_id: str,
-        attributes_to_return: Optional[List[str]] = None,
-        max_depth: Optional[int] = None,
+        attributes_to_return: list[str] | None = None,
+        max_depth: int | None = None,
     ) -> QuerySpanTreeResponse:
         return QuerySpanTreeResponse(
             data=await self.trace_store.get_span_tree(

@@ -4,14 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from .config import CodeInterpreterToolConfig

 __all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]


-async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: Dict[str, Any]):
+async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: dict[str, Any]):
     from .code_interpreter import CodeInterpreterToolRuntimeImpl

     impl = CodeInterpreterToolRuntimeImpl(config)

@@ -18,7 +18,6 @@ from dataclasses import dataclass
 from datetime import datetime
 from io import BytesIO
 from pathlib import Path
-from typing import List

 from PIL import Image

@@ -45,7 +44,7 @@ except:
 """


-def generate_bwrap_command(bind_dirs: List[str]) -> str:
+def generate_bwrap_command(bind_dirs: list[str]) -> str:
     """
     Generate the bwrap command string for binding all
     directories in the current directory read-only.
@@ -71,7 +70,7 @@ class CodeExecutionContext:

 @dataclass
 class CodeExecutionRequest:
-    scripts: List[str]
+    scripts: list[str]
     only_last_cell_stdouterr: bool = True
     only_last_cell_fail: bool = True
     seed: int = 0

@@ -9,7 +9,7 @@ import asyncio
 import logging
 import os
 import tempfile
-from typing import Any, Dict, Optional
+from typing import Any

 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.tools import (
@@ -46,7 +46,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         return

     async def list_runtime_tools(
-        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
     ) -> ListToolDefsResponse:
         return ListToolDefsResponse(
             data=[
@@ -64,7 +64,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
             ]
         )

-    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
+    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
         script = kwargs["code"]
         # Use environment variable to control bwrap usage
         force_disable_bwrap = os.environ.get("DISABLE_CODE_SANDBOX", "").lower() in ("1", "true", "yes")

@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from pydantic import BaseModel


 class CodeInterpreterToolConfig(BaseModel):
     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {}

@@ -15,7 +15,7 @@ def get_code_env_prefix() -> str:
     global CODE_ENV_PREFIX

     if CODE_ENV_PREFIX is None:
-        with open(CODE_ENV_PREFIX_FILE, "r") as f:
+        with open(CODE_ENV_PREFIX_FILE) as f:
             CODE_ENV_PREFIX = f.read()

     return CODE_ENV_PREFIX

@@ -4,14 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from llama_stack.providers.datatypes import Api

 from .config import RagToolRuntimeConfig


-async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[Api, Any]):
+async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
     from .memory import MemoryToolRuntimeImpl

     impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])

@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from pydantic import BaseModel


 class RagToolRuntimeConfig(BaseModel):
     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {}

@@ -8,7 +8,7 @@ import asyncio
 import logging
 import secrets
 import string
-from typing import Any, Dict, List, Optional
+from typing import Any

 from pydantic import TypeAdapter

@@ -74,7 +74,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):

     async def insert(
         self,
-        documents: List[RAGDocument],
+        documents: list[RAGDocument],
         vector_db_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
@@ -101,8 +101,8 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     async def query(
         self,
         content: InterleavedContent,
-        vector_db_ids: List[str],
-        query_config: Optional[RAGQueryConfig] = None,
+        vector_db_ids: list[str],
+        query_config: RAGQueryConfig | None = None,
     ) -> RAGQueryResult:
         if not vector_db_ids:
             return RAGQueryResult(content=None)
@@ -123,7 +123,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
             )
             for vector_db_id in vector_db_ids
         ]
-        results: List[QueryChunksResponse] = await asyncio.gather(*tasks)
+        results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
         chunks = [c for r in results for c in r.chunks]
         scores = [s for r in results for s in r.scores]

@@ -168,7 +168,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         )

     async def list_runtime_tools(
-        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
     ) -> ListToolDefsResponse:
         # Parameters are not listed since these methods are not yet invoked automatically
         # by the LLM. The method is only implemented so things like /tools can list without
@@ -193,7 +193,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
             ]
         )

-    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
+    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
         vector_db_ids = kwargs.get("vector_db_ids", [])
         query_config = kwargs.get("query_config")
         if query_config:

@@ -4,14 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from llama_stack.providers.datatypes import Api

 from .config import ChromaVectorIOConfig


-async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, Any]):
+async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
     from llama_stack.providers.remote.vector_io.chroma.chroma import (
         ChromaVectorIOAdapter,
     )

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from pydantic import BaseModel

@@ -13,5 +13,5 @@ class ChromaVectorIOConfig(BaseModel):
     db_path: str

     @classmethod
-    def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> dict[str, Any]:
         return {"db_path": db_path}

@@ -4,14 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from llama_stack.providers.datatypes import Api

 from .config import FaissVectorIOConfig


-async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, Any]):
+async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
     from .faiss import FaissVectorIOAdapter

     assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from pydantic import BaseModel

@@ -20,7 +20,7 @@ class FaissVectorIOConfig(BaseModel):
     kvstore: KVStoreConfig

     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {
             "kvstore": SqliteKVStoreConfig.sample_run_config(
                 __distro_dir__=__distro_dir__,

@@ -9,7 +9,7 @@ import base64
 import io
 import json
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any

 import faiss
 import numpy as np
@@ -84,7 +84,7 @@ class FaissIndex(EmbeddingIndex):

         await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")

-    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+    async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         # Add dimension check
         embedding_dim = embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
         if embedding_dim != self.index.d:
@@ -159,7 +159,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
             inference_api=self.inference_api,
         )

-    async def list_vector_dbs(self) -> List[VectorDB]:
+    async def list_vector_dbs(self) -> list[VectorDB]:
         return [i.vector_db for i in self.cache.values()]

     async def unregister_vector_db(self, vector_db_id: str) -> None:
@@ -176,8 +176,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     async def insert_chunks(
         self,
         vector_db_id: str,
-        chunks: List[Chunk],
-        ttl_seconds: Optional[int] = None,
+        chunks: list[Chunk],
+        ttl_seconds: int | None = None,
     ) -> None:
         index = self.cache.get(vector_db_id)
         if index is None:
@@ -189,7 +189,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         self,
         vector_db_id: str,
         query: InterleavedContent,
-        params: Optional[Dict[str, Any]] = None,
+        params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse:
         index = self.cache.get(vector_db_id)
         if index is None:

@@ -4,14 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from llama_stack.providers.datatypes import Api

 from .config import MilvusVectorIOConfig


-async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, Any]):
+async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
     from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter

     impl = MilvusVectorIOAdapter(config, deps[Api.inference])

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from pydantic import BaseModel

@@ -16,5 +16,5 @@ class MilvusVectorIOConfig(BaseModel):
     db_path: str

     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {"db_path": "${env.MILVUS_DB_PATH}"}

@@ -4,14 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Dict
-
 from llama_stack.providers.datatypes import Api, ProviderSpec

 from .config import QdrantVectorIOConfig


-async def get_adapter_impl(config: QdrantVectorIOConfig, deps: Dict[Api, ProviderSpec]):
+async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter

     impl = QdrantVectorIOAdapter(config, deps[Api.inference])

@@ -5,7 +5,7 @@
 # the root directory of this source tree.


-from typing import Any, Dict
+from typing import Any

 from pydantic import BaseModel

@@ -17,7 +17,7 @@ class QdrantVectorIOConfig(BaseModel):
     path: str

     @classmethod
-    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
         return {
             "path": "${env.QDRANT_PATH:~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
         }

@@ -4,14 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from llama_stack.providers.datatypes import Api

 from .config import SQLiteVectorIOConfig


-async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, Any]):
+async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
     from .sqlite_vec import SQLiteVecVectorIOAdapter

     assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict
+from typing import Any

 from pydantic import BaseModel

@@ -13,7 +13,7 @@ class SQLiteVectorIOConfig(BaseModel):
     db_path: str

     @classmethod
-    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
         return {
             "db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + "sqlite_vec.db",
         }

@@ -10,7 +10,7 @@ import logging
 import sqlite3
 import struct
 import uuid
-from typing import Any, Dict, List, Optional
+from typing import Any

 import numpy as np
 import sqlite_vec
@@ -25,7 +25,7 @@ from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, Vect
 logger = logging.getLogger(__name__)


-def serialize_vector(vector: List[float]) -> bytes:
+def serialize_vector(vector: list[float]) -> bytes:
     """Serialize a list of floats into a compact binary representation."""
     return struct.pack(f"{len(vector)}f", *vector)

@@ -98,7 +98,7 @@ class SQLiteVecIndex(EmbeddingIndex):

         await asyncio.to_thread(_drop_tables)

-    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray, batch_size: int = 500):
+    async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, batch_size: int = 500):
         """
         Add new chunks along with their embeddings using batch inserts.
         For each chunk, we insert its JSON into the metadata table and then insert its
@@ -209,7 +209,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(self, config, inference_api: Inference) -> None:
         self.config = config
         self.inference_api = inference_api
-        self.cache: Dict[str, VectorDBWithIndex] = {}
+        self.cache: dict[str, VectorDBWithIndex] = {}

     async def initialize(self) -> None:
         def _setup_connection():
@@ -264,7 +264,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
         self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)

-    async def list_vector_dbs(self) -> List[VectorDB]:
+    async def list_vector_dbs(self) -> list[VectorDB]:
         return [v.vector_db for v in self.cache.values()]

     async def unregister_vector_db(self, vector_db_id: str) -> None:
@@ -286,7 +286,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):

         await asyncio.to_thread(_delete_vector_db_from_registry)

-    async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None:
+    async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
         if vector_db_id not in self.cache:
             raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
         # The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
@@ -294,7 +294,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         await self.cache[vector_db_id].insert_chunks(chunks)

     async def query_chunks(
-        self, vector_db_id: str, query: Any, params: Optional[Dict[str, Any]] = None
+        self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
     ) -> QueryChunksResponse:
         if vector_db_id not in self.cache:
             raise ValueError(f"Vector DB {vector_db_id} not found")
@@ -303,5 +303,5 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):

 def generate_chunk_id(document_id: str, chunk_text: str) -> str:
     """Generate a unique chunk ID using a hash of document ID and chunk text."""
-    hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
+    hash_input = f"{document_id}:{chunk_text}".encode()
     return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))

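The final hunk drops the explicit `"utf-8"` argument from `str.encode()`, which is already the default encoding, so the hashed bytes are unchanged. A small standalone check of that equivalence, using made-up sample values rather than real chunk data:

```python
import hashlib
import uuid

document_id, chunk_text = "doc-1", "hello world"

# str.encode() defaults to UTF-8, so both spellings hash the same bytes
assert "x".encode() == "x".encode("utf-8")

hash_input = f"{document_id}:{chunk_text}".encode()
print(str(uuid.UUID(hashlib.md5(hash_input).hexdigest())))
```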