mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-21 03:59:42 +00:00
Merge branch 'main' into nvidia-e2e-notebook
This commit is contained in:
commit
b1d941e1f0
447 changed files with 6462 additions and 64778 deletions
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Protocol
|
||||
from typing import Any, Protocol
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -22,6 +22,27 @@ from llama_stack.schema_utils import json_schema_type
|
|||
|
||||
|
||||
class ModelsProtocolPrivate(Protocol):
|
||||
"""
|
||||
Protocol for model management.
|
||||
|
||||
This allows users to register their preferred model identifiers.
|
||||
|
||||
Model registration requires -
|
||||
- a provider, used to route the registration request
|
||||
- a model identifier, user's intended name for the model during inference
|
||||
- a provider model identifier, a model identifier supported by the provider
|
||||
|
||||
Providers will only accept registration for provider model ids they support.
|
||||
|
||||
Example,
|
||||
register: provider x my-model-id x provider-model-id
|
||||
-> Error if provider does not support provider-model-id
|
||||
-> Error if my-model-id is already registered
|
||||
-> Success if provider supports provider-model-id
|
||||
inference: my-model-id x ...
|
||||
-> Provider uses provider-model-id for inference
|
||||
"""
|
||||
|
||||
async def register_model(self, model: Model) -> Model: ...
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None: ...
|
||||
|
@ -44,7 +65,7 @@ class DatasetsProtocolPrivate(Protocol):
|
|||
|
||||
|
||||
class ScoringFunctionsProtocolPrivate(Protocol):
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]: ...
|
||||
async def list_scoring_functions(self) -> list[ScoringFn]: ...
|
||||
|
||||
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ...
|
||||
|
||||
|
@ -67,24 +88,24 @@ class ProviderSpec(BaseModel):
|
|||
...,
|
||||
description="Fully-qualified classname of the config for this provider",
|
||||
)
|
||||
api_dependencies: List[Api] = Field(
|
||||
api_dependencies: list[Api] = Field(
|
||||
default_factory=list,
|
||||
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
||||
)
|
||||
optional_api_dependencies: List[Api] = Field(
|
||||
optional_api_dependencies: list[Api] = Field(
|
||||
default_factory=list,
|
||||
)
|
||||
deprecation_warning: Optional[str] = Field(
|
||||
deprecation_warning: str | None = Field(
|
||||
default=None,
|
||||
description="If this provider is deprecated, specify the warning message here",
|
||||
)
|
||||
deprecation_error: Optional[str] = Field(
|
||||
deprecation_error: str | None = Field(
|
||||
default=None,
|
||||
description="If this provider is deprecated and does NOT work, specify the error message here",
|
||||
)
|
||||
|
||||
# used internally by the resolver; this is a hack for now
|
||||
deps__: List[str] = Field(default_factory=list)
|
||||
deps__: list[str] = Field(default_factory=list)
|
||||
|
||||
@property
|
||||
def is_sample(self) -> bool:
|
||||
|
@ -110,25 +131,25 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
""",
|
||||
)
|
||||
pip_packages: List[str] = Field(
|
||||
pip_packages: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
config_class: str = Field(
|
||||
description="Fully-qualified classname of the config for this provider",
|
||||
)
|
||||
provider_data_validator: Optional[str] = Field(
|
||||
provider_data_validator: str | None = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InlineProviderSpec(ProviderSpec):
|
||||
pip_packages: List[str] = Field(
|
||||
pip_packages: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
container_image: Optional[str] = Field(
|
||||
container_image: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
The container image to use for this implementation. If one is provided, pip_packages will be ignored.
|
||||
|
@ -143,14 +164,14 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
- `get_provider_impl(config, deps)`: returns the local implementation
|
||||
""",
|
||||
)
|
||||
provider_data_validator: Optional[str] = Field(
|
||||
provider_data_validator: str | None = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class RemoteProviderConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: Optional[int] = None
|
||||
port: int | None = None
|
||||
protocol: str = "http"
|
||||
|
||||
@property
|
||||
|
@ -176,7 +197,7 @@ API responses, specify the adapter here.
|
|||
)
|
||||
|
||||
@property
|
||||
def container_image(self) -> Optional[str]:
|
||||
def container_image(self) -> str | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
|
@ -184,16 +205,16 @@ API responses, specify the adapter here.
|
|||
return self.adapter.module
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> List[str]:
|
||||
def pip_packages(self) -> list[str]:
|
||||
return self.adapter.pip_packages
|
||||
|
||||
@property
|
||||
def provider_data_validator(self) -> Optional[str]:
|
||||
def provider_data_validator(self) -> str | None:
|
||||
return self.adapter.provider_data_validator
|
||||
|
||||
|
||||
def remote_provider_spec(
|
||||
api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None
|
||||
api: Api, adapter: AdapterSpec, api_dependencies: list[Api] | None = None
|
||||
) -> RemoteProviderSpec:
|
||||
return RemoteProviderSpec(
|
||||
api=api,
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Api, Any]):
|
||||
from .agents import MetaReferenceAgentsImpl
|
||||
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
|
|
|
@ -10,8 +10,8 @@ import re
|
|||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -112,7 +112,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
output_shields=agent_config.output_shields,
|
||||
)
|
||||
|
||||
def turn_to_messages(self, turn: Turn) -> List[Message]:
|
||||
def turn_to_messages(self, turn: Turn) -> list[Message]:
|
||||
messages = []
|
||||
|
||||
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
|
||||
|
@ -161,7 +161,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def create_session(self, name: str) -> str:
|
||||
return await self.storage.create_session(name)
|
||||
|
||||
async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
|
||||
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
|
||||
messages = []
|
||||
if self.agent_config.instructions != "":
|
||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||
|
@ -201,8 +201,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def _run_turn(
|
||||
self,
|
||||
request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
|
||||
turn_id: Optional[str] = None,
|
||||
request: AgentTurnCreateRequest | AgentTurnResumeRequest,
|
||||
turn_id: str | None = None,
|
||||
) -> AsyncGenerator:
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
|
||||
|
@ -321,10 +321,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
input_messages: List[Message],
|
||||
input_messages: list[Message],
|
||||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
documents: list[Document] | None = None,
|
||||
) -> AsyncGenerator:
|
||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
||||
|
@ -374,8 +374,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def run_multiple_shields_wrapper(
|
||||
self,
|
||||
turn_id: str,
|
||||
messages: List[Message],
|
||||
shields: List[str],
|
||||
messages: list[Message],
|
||||
shields: list[str],
|
||||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
async with tracing.span("run_shields") as span:
|
||||
|
@ -443,10 +443,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
input_messages: List[Message],
|
||||
input_messages: list[Message],
|
||||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
documents: Optional[List[Document]] = None,
|
||||
documents: list[Document] | None = None,
|
||||
) -> AsyncGenerator:
|
||||
# if document is passed in a turn, we parse the raw text of the document
|
||||
# and sent it as a user message
|
||||
|
@ -760,7 +760,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def _initialize_tools(
|
||||
self,
|
||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||
toolgroups_for_turn: list[AgentToolGroup] | None = None,
|
||||
) -> None:
|
||||
toolgroup_to_args = {}
|
||||
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
|
||||
|
@ -847,7 +847,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_name_to_args,
|
||||
)
|
||||
|
||||
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
|
||||
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, str | None]:
|
||||
"""Parse a toolgroup name into its components.
|
||||
|
||||
Args:
|
||||
|
@ -921,7 +921,7 @@ async def get_raw_document_text(document: Document) -> str:
|
|||
|
||||
def _interpret_content_as_attachment(
|
||||
content: str,
|
||||
) -> Optional[Attachment]:
|
||||
) -> Attachment | None:
|
||||
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
|
||||
if match:
|
||||
snippet = match.group(1)
|
||||
|
|
|
@ -6,9 +6,8 @@
|
|||
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
import uuid
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
Agent,
|
||||
|
@ -78,10 +77,6 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
tool_runtime_api=self.tool_runtime_api,
|
||||
)
|
||||
|
||||
# check if "bwrap" is available
|
||||
if not shutil.which("bwrap"):
|
||||
logger.warning("Warning: `bwrap` is not available. Code interpreter tool will not work correctly.")
|
||||
|
||||
async def create_agent(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
|
@ -142,16 +137,11 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
messages: List[
|
||||
Union[
|
||||
UserMessage,
|
||||
ToolResponseMessage,
|
||||
]
|
||||
],
|
||||
toolgroups: Optional[List[AgentToolGroup]] = None,
|
||||
documents: Optional[List[Document]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
messages: list[UserMessage | ToolResponseMessage],
|
||||
toolgroups: list[AgentToolGroup] | None = None,
|
||||
documents: list[Document] | None = None,
|
||||
stream: bool | None = False,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=agent_id,
|
||||
|
@ -180,8 +170,8 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: List[ToolResponse],
|
||||
stream: Optional[bool] = False,
|
||||
tool_responses: list[ToolResponse],
|
||||
stream: bool | None = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnResumeRequest(
|
||||
agent_id=agent_id,
|
||||
|
@ -219,7 +209,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
turn_ids: list[str] | None = None,
|
||||
) -> Session:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
|
@ -265,13 +255,14 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: Union[str, List[OpenAIResponseInputMessage]],
|
||||
input: str | list[OpenAIResponseInputMessage],
|
||||
model: str,
|
||||
previous_response_id: Optional[str] = None,
|
||||
store: Optional[bool] = True,
|
||||
stream: Optional[bool] = False,
|
||||
tools: Optional[List[OpenAIResponseInputTool]] = None,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
) -> OpenAIResponseObject:
|
||||
return await self.openai_responses_impl.create_openai_response(
|
||||
input, model, previous_response_id, store, stream, tools
|
||||
input, model, previous_response_id, store, stream, temperature, tools
|
||||
)
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -16,7 +16,7 @@ class MetaReferenceAgentsImplConfig(BaseModel):
|
|||
persistence_store: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"persistence_store": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
|
|
|
@ -6,7 +6,8 @@
|
|||
|
||||
import json
|
||||
import uuid
|
||||
from typing import AsyncIterator, List, Optional, Union, cast
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import cast
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
|
@ -49,15 +50,15 @@ logger = get_logger(name=__name__, category="openai_responses")
|
|||
OPENAI_RESPONSES_PREFIX = "openai_responses:"
|
||||
|
||||
|
||||
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> List[OpenAIMessageParam]:
|
||||
messages: List[OpenAIMessageParam] = []
|
||||
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> list[OpenAIMessageParam]:
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
for output_message in previous_response.output:
|
||||
if isinstance(output_message, OpenAIResponseOutputMessage):
|
||||
messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
|
||||
return messages
|
||||
|
||||
|
||||
async def _openai_choices_to_output_messages(choices: List[OpenAIChoice]) -> List[OpenAIResponseOutputMessage]:
|
||||
async def _openai_choices_to_output_messages(choices: list[OpenAIChoice]) -> list[OpenAIResponseOutputMessage]:
|
||||
output_messages = []
|
||||
for choice in choices:
|
||||
output_content = ""
|
||||
|
@ -101,21 +102,22 @@ class OpenAIResponsesImpl:
|
|||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: Union[str, List[OpenAIResponseInputMessage]],
|
||||
input: str | list[OpenAIResponseInputMessage],
|
||||
model: str,
|
||||
previous_response_id: Optional[str] = None,
|
||||
store: Optional[bool] = True,
|
||||
stream: Optional[bool] = False,
|
||||
tools: Optional[List[OpenAIResponseInputTool]] = None,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
stream: bool | None = False,
|
||||
temperature: float | None = None,
|
||||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
):
|
||||
stream = False if stream is None else stream
|
||||
|
||||
messages: List[OpenAIMessageParam] = []
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
if previous_response_id:
|
||||
previous_response = await self.get_openai_response(previous_response_id)
|
||||
messages.extend(await _previous_response_to_messages(previous_response))
|
||||
# TODO: refactor this user_content parsing out into a separate method
|
||||
user_content: Union[str, List[OpenAIChatCompletionContentPartParam]] = ""
|
||||
user_content: str | list[OpenAIChatCompletionContentPartParam] = ""
|
||||
if isinstance(input, list):
|
||||
user_content = []
|
||||
for user_input in input:
|
||||
|
@ -141,6 +143,7 @@ class OpenAIResponsesImpl:
|
|||
messages=messages,
|
||||
tools=chat_tools,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
@ -177,10 +180,10 @@ class OpenAIResponsesImpl:
|
|||
# dump and reload to map to our pydantic types
|
||||
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
|
||||
|
||||
output_messages: List[OpenAIResponseOutput] = []
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
if chat_response.choices[0].message.tool_calls:
|
||||
output_messages.extend(
|
||||
await self._execute_tool_and_return_final_output(model, stream, chat_response, messages)
|
||||
await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
|
||||
)
|
||||
else:
|
||||
output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
|
||||
|
@ -213,9 +216,9 @@ class OpenAIResponsesImpl:
|
|||
return response
|
||||
|
||||
async def _convert_response_tools_to_chat_tools(
|
||||
self, tools: List[OpenAIResponseInputTool]
|
||||
) -> List[ChatCompletionToolParam]:
|
||||
chat_tools: List[ChatCompletionToolParam] = []
|
||||
self, tools: list[OpenAIResponseInputTool]
|
||||
) -> list[ChatCompletionToolParam]:
|
||||
chat_tools: list[ChatCompletionToolParam] = []
|
||||
for input_tool in tools:
|
||||
# TODO: Handle other tool types
|
||||
if input_tool.type == "web_search":
|
||||
|
@ -241,9 +244,14 @@ class OpenAIResponsesImpl:
|
|||
return chat_tools
|
||||
|
||||
async def _execute_tool_and_return_final_output(
|
||||
self, model_id: str, stream: bool, chat_response: OpenAIChatCompletion, messages: List[OpenAIMessageParam]
|
||||
) -> List[OpenAIResponseOutput]:
|
||||
output_messages: List[OpenAIResponseOutput] = []
|
||||
self,
|
||||
model_id: str,
|
||||
stream: bool,
|
||||
chat_response: OpenAIChatCompletion,
|
||||
messages: list[OpenAIMessageParam],
|
||||
temperature: float,
|
||||
) -> list[OpenAIResponseOutput]:
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
choice = chat_response.choices[0]
|
||||
|
||||
# If the choice is not an assistant message, we don't need to execute any tools
|
||||
|
@ -295,6 +303,7 @@ class OpenAIResponsesImpl:
|
|||
model=model_id,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
)
|
||||
# type cast to appease mypy
|
||||
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
|
||||
|
@ -306,7 +315,7 @@ class OpenAIResponsesImpl:
|
|||
async def _execute_tool_call(
|
||||
self,
|
||||
function: OpenAIChatCompletionToolCallFunction,
|
||||
) -> Optional[ToolInvocationResult]:
|
||||
) -> ToolInvocationResult | None:
|
||||
if not function.name:
|
||||
return None
|
||||
function_args = json.loads(function.arguments) if function.arguments else {}
|
||||
|
|
|
@ -8,7 +8,6 @@ import json
|
|||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -25,9 +24,9 @@ class AgentSessionInfo(BaseModel):
|
|||
session_id: str
|
||||
session_name: str
|
||||
# TODO: is this used anywhere?
|
||||
vector_db_id: Optional[str] = None
|
||||
vector_db_id: str | None = None
|
||||
started_at: datetime
|
||||
access_attributes: Optional[AccessAttributes] = None
|
||||
access_attributes: AccessAttributes | None = None
|
||||
|
||||
|
||||
class AgentPersistence:
|
||||
|
@ -55,7 +54,7 @@ class AgentPersistence:
|
|||
)
|
||||
return session_id
|
||||
|
||||
async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:
|
||||
async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
|
||||
value = await self.kvstore.get(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
)
|
||||
|
@ -78,7 +77,7 @@ class AgentPersistence:
|
|||
|
||||
return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())
|
||||
|
||||
async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]:
|
||||
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
||||
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
||||
session_info = await self.get_session_info(session_id)
|
||||
if not session_info:
|
||||
|
@ -106,7 +105,7 @@ class AgentPersistence:
|
|||
value=turn.model_dump_json(),
|
||||
)
|
||||
|
||||
async def get_session_turns(self, session_id: str) -> List[Turn]:
|
||||
async def get_session_turns(self, session_id: str) -> list[Turn]:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise ValueError(f"Session {session_id} not found or access denied")
|
||||
|
||||
|
@ -125,7 +124,7 @@ class AgentPersistence:
|
|||
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
||||
return turns
|
||||
|
||||
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
|
||||
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise ValueError(f"Session {session_id} not found or access denied")
|
||||
|
||||
|
@ -145,7 +144,7 @@ class AgentPersistence:
|
|||
value=step.model_dump_json(),
|
||||
)
|
||||
|
||||
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
|
||||
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> ToolExecutionStep | None:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
return None
|
||||
|
||||
|
@ -163,7 +162,7 @@ class AgentPersistence:
|
|||
value=str(num_infer_iters),
|
||||
)
|
||||
|
||||
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
|
||||
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> int | None:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
return None
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
|
@ -25,14 +24,14 @@ class ShieldRunnerMixin:
|
|||
def __init__(
|
||||
self,
|
||||
safety_api: Safety,
|
||||
input_shields: List[str] = None,
|
||||
output_shields: List[str] = None,
|
||||
input_shields: list[str] = None,
|
||||
output_shields: list[str] = None,
|
||||
):
|
||||
self.safety_api = safety_api
|
||||
self.input_shields = input_shields
|
||||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
|
||||
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
|
||||
async def run_shield_with_span(identifier: str):
|
||||
async with tracing.span(f"run_shield_{identifier}"):
|
||||
return await self.safety_api.run_shield(
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from .config import LocalFSDatasetIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: LocalFSDatasetIOConfig,
|
||||
_deps: Dict[str, Any],
|
||||
_deps: dict[str, Any],
|
||||
):
|
||||
from .datasetio import LocalFSDatasetIOImpl
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -17,7 +17,7 @@ class LocalFSDatasetIOConfig(BaseModel):
|
|||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import pandas
|
||||
|
||||
|
@ -92,8 +92,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
start_index: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
start_index: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> PaginatedResponse:
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
dataset_impl = PandasDataframeDataset(dataset_def)
|
||||
|
@ -102,7 +102,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
records = dataset_impl.df.to_dict("records")
|
||||
return paginate_records(records, start_index, limit)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
dataset_impl = PandasDataframeDataset(dataset_def)
|
||||
await dataset_impl.load()
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
|
@ -12,7 +12,7 @@ from .config import MetaReferenceEvalConfig
|
|||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceEvalConfig,
|
||||
deps: Dict[Api, Any],
|
||||
deps: dict[Api, Any],
|
||||
):
|
||||
from .eval import MetaReferenceEvalImpl
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -17,7 +17,7 @@ class MetaReferenceEvalConfig(BaseModel):
|
|||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
@ -105,8 +105,8 @@ class MetaReferenceEvalImpl(
|
|||
return Job(job_id=job_id, status=JobStatus.completed)
|
||||
|
||||
async def _run_agent_generation(
|
||||
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> List[Dict[str, Any]]:
|
||||
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> list[dict[str, Any]]:
|
||||
candidate = benchmark_config.eval_candidate
|
||||
create_response = await self.agents_api.create_agent(candidate.config)
|
||||
agent_id = create_response.agent_id
|
||||
|
@ -148,8 +148,8 @@ class MetaReferenceEvalImpl(
|
|||
return generations
|
||||
|
||||
async def _run_model_generation(
|
||||
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> List[Dict[str, Any]]:
|
||||
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> list[dict[str, Any]]:
|
||||
candidate = benchmark_config.eval_candidate
|
||||
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
|
||||
|
||||
|
@ -185,8 +185,8 @@ class MetaReferenceEvalImpl(
|
|||
async def evaluate_rows(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
input_rows: list[dict[str, Any]],
|
||||
scoring_functions: list[str],
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
candidate = benchmark_config.eval_candidate
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceInferenceConfig,
|
||||
_deps: Dict[str, Any],
|
||||
_deps: dict[str, Any],
|
||||
):
|
||||
from .inference import MetaReferenceInferenceImpl
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
@ -17,11 +17,11 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
# the actual inference model id is dtermined by the moddel id in the request
|
||||
# Note: you need to register the model before using it for inference
|
||||
# models in the resouce list in the run.yaml config will be registered automatically
|
||||
model: Optional[str] = None
|
||||
torch_seed: Optional[int] = None
|
||||
model: str | None = None
|
||||
torch_seed: int | None = None
|
||||
max_seq_len: int = 4096
|
||||
max_batch_size: int = 1
|
||||
model_parallel_size: Optional[int] = None
|
||||
model_parallel_size: int | None = None
|
||||
|
||||
# when this is False, we assume that the distributed process group is setup by someone
|
||||
# outside of this code (e.g., when run inside `torchrun`). that is useful for clients
|
||||
|
@ -30,9 +30,9 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
|
||||
# By default, the implementation will look at ~/.llama/checkpoints/<model> but you
|
||||
# can override by specifying the directory explicitly
|
||||
checkpoint_dir: Optional[str] = None
|
||||
checkpoint_dir: str | None = None
|
||||
|
||||
quantization: Optional[QuantizationConfig] = None
|
||||
quantization: QuantizationConfig | None = None
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
|
@ -55,7 +55,7 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
max_batch_size: str = "${env.MAX_BATCH_SIZE:1}",
|
||||
max_seq_len: str = "${env.MAX_SEQ_LEN:4096}",
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"model": model,
|
||||
"checkpoint_dir": checkpoint_dir,
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
@ -39,7 +40,7 @@ Tokenizer = Llama4Tokenizer | Llama3Tokenizer
|
|||
class LogitsProcessor:
|
||||
def __init__(self, token_enforcer: TokenEnforcer):
|
||||
self.token_enforcer = token_enforcer
|
||||
self.mask: Optional[torch.Tensor] = None
|
||||
self.mask: torch.Tensor | None = None
|
||||
|
||||
def __call__(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
token_sequence = tokens[0, :].tolist()
|
||||
|
@ -58,7 +59,7 @@ class LogitsProcessor:
|
|||
def get_logits_processor(
|
||||
tokenizer: Tokenizer,
|
||||
vocab_size: int,
|
||||
response_format: Optional[ResponseFormat],
|
||||
response_format: ResponseFormat | None,
|
||||
) -> Optional["LogitsProcessor"]:
|
||||
if response_format is None:
|
||||
return None
|
||||
|
@ -76,7 +77,7 @@ def get_logits_processor(
|
|||
return LogitsProcessor(token_enforcer)
|
||||
|
||||
|
||||
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> List[Tuple[int, str, bool]]:
|
||||
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> list[tuple[int, str, bool]]:
|
||||
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
|
||||
regular_tokens = []
|
||||
|
||||
|
@ -158,7 +159,7 @@ class LlamaGenerator:
|
|||
|
||||
def completion(
|
||||
self,
|
||||
request_batch: List[CompletionRequestWithRawContent],
|
||||
request_batch: list[CompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
first_request = request_batch[0]
|
||||
sampling_params = first_request.sampling_params or SamplingParams()
|
||||
|
@ -167,7 +168,7 @@ class LlamaGenerator:
|
|||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
for result in self.inner_generator.generate(
|
||||
yield from self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
|
@ -179,12 +180,11 @@ class LlamaGenerator:
|
|||
self.args.vocab_size,
|
||||
first_request.response_format,
|
||||
),
|
||||
):
|
||||
yield result
|
||||
)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
request_batch: List[ChatCompletionRequestWithRawContent],
|
||||
request_batch: list[ChatCompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
first_request = request_batch[0]
|
||||
sampling_params = first_request.sampling_params or SamplingParams()
|
||||
|
@ -193,7 +193,7 @@ class LlamaGenerator:
|
|||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
for result in self.inner_generator.generate(
|
||||
yield from self.inner_generator.generate(
|
||||
llm_inputs=[
|
||||
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
|
||||
for request in request_batch
|
||||
|
@ -208,5 +208,4 @@ class LlamaGenerator:
|
|||
self.args.vocab_size,
|
||||
first_request.response_format,
|
||||
),
|
||||
):
|
||||
yield result
|
||||
)
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
@ -184,11 +184,11 @@ class MetaReferenceInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | CompletionResponseStreamChunk:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
|
@ -215,11 +215,11 @@ class MetaReferenceInferenceImpl(
|
|||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: List[InterleavedContent],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchCompletionResponse:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
|
@ -291,14 +291,14 @@ class MetaReferenceInferenceImpl(
|
|||
for x in impl():
|
||||
yield x
|
||||
|
||||
async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
|
||||
async def _nonstream_completion(self, request_batch: list[CompletionRequest]) -> list[CompletionResponse]:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
first_request = request_batch[0]
|
||||
|
||||
class ItemState(BaseModel):
|
||||
tokens: List[int] = []
|
||||
logprobs: List[TokenLogProbs] = []
|
||||
tokens: list[int] = []
|
||||
logprobs: list[TokenLogProbs] = []
|
||||
stop_reason: StopReason | None = None
|
||||
finished: bool = False
|
||||
|
||||
|
@ -349,15 +349,15 @@ class MetaReferenceInferenceImpl(
|
|||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
|
@ -395,13 +395,13 @@ class MetaReferenceInferenceImpl(
|
|||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: List[List[Message]],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
messages_batch: list[list[Message]],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> BatchChatCompletionResponse:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
|
@ -436,15 +436,15 @@ class MetaReferenceInferenceImpl(
|
|||
return BatchChatCompletionResponse(batch=results)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request_batch: List[ChatCompletionRequest]
|
||||
) -> List[ChatCompletionResponse]:
|
||||
self, request_batch: list[ChatCompletionRequest]
|
||||
) -> list[ChatCompletionResponse]:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
first_request = request_batch[0]
|
||||
|
||||
class ItemState(BaseModel):
|
||||
tokens: List[int] = []
|
||||
logprobs: List[TokenLogProbs] = []
|
||||
tokens: list[int] = []
|
||||
logprobs: list[TokenLogProbs] = []
|
||||
stop_reason: StopReason | None = None
|
||||
finished: bool = False
|
||||
|
||||
|
|
|
@ -4,9 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Callable, Generator
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Generator, List
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
||||
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
||||
|
@ -82,7 +83,7 @@ class LlamaModelParallelGenerator:
|
|||
|
||||
def completion(
|
||||
self,
|
||||
request_batch: List[CompletionRequestWithRawContent],
|
||||
request_batch: list[CompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
req_obj = deepcopy(request_batch)
|
||||
gen = self.group.run_inference(("completion", req_obj))
|
||||
|
@ -90,7 +91,7 @@ class LlamaModelParallelGenerator:
|
|||
|
||||
def chat_completion(
|
||||
self,
|
||||
request_batch: List[ChatCompletionRequestWithRawContent],
|
||||
request_batch: list[ChatCompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
req_obj = deepcopy(request_batch)
|
||||
gen = self.group.run_inference(("chat_completion", req_obj))
|
||||
|
|
|
@ -18,8 +18,9 @@ import os
|
|||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable, Generator
|
||||
from enum import Enum
|
||||
from typing import Callable, Generator, List, Literal, Optional, Tuple, Union
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
|
@ -30,7 +31,6 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.models.llama.datatypes import GenerationResult
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
|
@ -69,15 +69,15 @@ class CancelSentinel(BaseModel):
|
|||
|
||||
class TaskRequest(BaseModel):
|
||||
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
|
||||
task: Tuple[
|
||||
task: tuple[
|
||||
str,
|
||||
List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
|
||||
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
|
||||
]
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
|
||||
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
|
||||
result: List[GenerationResult]
|
||||
result: list[GenerationResult]
|
||||
|
||||
|
||||
class ExceptionResponse(BaseModel):
|
||||
|
@ -85,15 +85,9 @@ class ExceptionResponse(BaseModel):
|
|||
error: str
|
||||
|
||||
|
||||
ProcessingMessage = Union[
|
||||
ReadyRequest,
|
||||
ReadyResponse,
|
||||
EndSentinel,
|
||||
CancelSentinel,
|
||||
TaskRequest,
|
||||
TaskResponse,
|
||||
ExceptionResponse,
|
||||
]
|
||||
ProcessingMessage = (
|
||||
ReadyRequest | ReadyResponse | EndSentinel | CancelSentinel | TaskRequest | TaskResponse | ExceptionResponse
|
||||
)
|
||||
|
||||
|
||||
class ProcessingMessageWrapper(BaseModel):
|
||||
|
@ -203,7 +197,7 @@ def maybe_get_work(sock: zmq.Socket):
|
|||
return client_id, message
|
||||
|
||||
|
||||
def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage]:
|
||||
def maybe_parse_message(maybe_json: str | None) -> ProcessingMessage | None:
|
||||
if maybe_json is None:
|
||||
return None
|
||||
try:
|
||||
|
@ -334,9 +328,9 @@ class ModelParallelProcessGroup:
|
|||
|
||||
def run_inference(
|
||||
self,
|
||||
req: Tuple[
|
||||
req: tuple[
|
||||
str,
|
||||
List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
|
||||
list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
|
||||
],
|
||||
) -> Generator:
|
||||
assert not self.running, "inference already running"
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.providers.inline.inference.sentence_transformers.config import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
|
@ -13,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import
|
|||
|
||||
async def get_provider_impl(
|
||||
config: SentenceTransformersInferenceConfig,
|
||||
_deps: Dict[str, Any],
|
||||
_deps: dict[str, Any],
|
||||
):
|
||||
from .sentence_transformers import SentenceTransformersInferenceImpl
|
||||
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SentenceTransformersInferenceConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionResponse,
|
||||
|
@ -60,46 +60,46 @@ class SentenceTransformersInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
content: str,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, AsyncGenerator]:
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | AsyncGenerator:
|
||||
raise ValueError("Sentence transformers don't support completion")
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
raise ValueError("Sentence transformers don't support chat completion")
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: List[InterleavedContent],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
|
||||
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: List[List[Message]],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
messages_batch: list[list[Message]],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from .config import VLLMConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]):
|
||||
async def get_provider_impl(config: VLLMConfig, _deps: dict[str, Any]):
|
||||
from .vllm import VLLMInferenceImpl
|
||||
|
||||
impl = VLLMInferenceImpl(config)
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
@ -42,7 +42,7 @@ class VLLMConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
|
||||
"max_tokens": "${env.MAX_TOKENS:4096}",
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import vllm
|
||||
|
||||
|
@ -55,8 +54,8 @@ def _merge_context_into_content(message: Message) -> Message: # type: ignore
|
|||
|
||||
|
||||
def _llama_stack_tools_to_openai_tools(
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
) -> list[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
|
||||
"""
|
||||
Convert the list of available tools from Llama Stack's format to vLLM's
|
||||
version of OpenAI's format.
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
|
||||
# These vLLM modules contain names that overlap with Llama Stack names, so we import
|
||||
# fully-qualified names
|
||||
|
@ -100,7 +100,7 @@ def _random_uuid_str() -> str:
|
|||
|
||||
|
||||
def _response_format_to_guided_decoding_params(
|
||||
response_format: Optional[ResponseFormat], # type: ignore
|
||||
response_format: ResponseFormat | None, # type: ignore
|
||||
) -> vllm.sampling_params.GuidedDecodingParams:
|
||||
"""
|
||||
Translate constrained decoding parameters from Llama Stack's format to vLLM's format.
|
||||
|
@ -131,9 +131,9 @@ def _response_format_to_guided_decoding_params(
|
|||
|
||||
|
||||
def _convert_sampling_params(
|
||||
sampling_params: Optional[SamplingParams],
|
||||
response_format: Optional[ResponseFormat], # type: ignore
|
||||
log_prob_config: Optional[LogProbConfig],
|
||||
sampling_params: SamplingParams | None,
|
||||
response_format: ResponseFormat | None, # type: ignore
|
||||
log_prob_config: LogProbConfig | None,
|
||||
) -> vllm.SamplingParams:
|
||||
"""Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
|
||||
format."""
|
||||
|
@ -370,11 +370,11 @@ class VLLMInferenceImpl(
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
|
||||
|
@ -403,25 +403,25 @@ class VLLMInferenceImpl(
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message], # type: ignore
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None, # type: ignore
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
messages: list[Message], # type: ignore
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None, # type: ignore
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
|
||||
sampling_params = sampling_params or SamplingParams()
|
||||
if model_id not in self.model_ids:
|
||||
|
@ -605,7 +605,7 @@ class VLLMInferenceImpl(
|
|||
|
||||
async def _chat_completion_for_meta_llama(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
"""
|
||||
Subroutine that routes chat completions for Meta Llama models through Llama Stack's
|
||||
chat template instead of using vLLM's version of that template. The Llama Stack version
|
||||
|
@ -701,7 +701,7 @@ class VLLMInferenceImpl(
|
|||
# Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
|
||||
# those chunks and output them at the end.
|
||||
# This data structure holds the current set of partial tool calls.
|
||||
index_to_tool_call: Dict[int, Dict] = dict()
|
||||
index_to_tool_call: dict[int, dict] = dict()
|
||||
|
||||
# The Llama Stack event stream must always start with a start event. Use an empty one to
|
||||
# simplify logic below
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
|
@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
|
|||
|
||||
async def get_provider_impl(
|
||||
config: TorchtunePostTrainingConfig,
|
||||
deps: Dict[Api, Any],
|
||||
deps: dict[Api, Any],
|
||||
):
|
||||
from .post_training import TorchtunePostTrainingImpl
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import json
|
|||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
@ -34,7 +34,7 @@ class TorchtuneCheckpointer:
|
|||
model_id: str,
|
||||
training_algorithm: str,
|
||||
checkpoint_dir: str,
|
||||
checkpoint_files: List[str],
|
||||
checkpoint_files: list[str],
|
||||
output_dir: str,
|
||||
model_type: str,
|
||||
):
|
||||
|
@ -54,11 +54,11 @@ class TorchtuneCheckpointer:
|
|||
# get ckpt paths
|
||||
self._checkpoint_path = Path.joinpath(self._checkpoint_dir, self._checkpoint_file)
|
||||
|
||||
def load_checkpoint(self) -> Dict[str, Any]:
|
||||
def load_checkpoint(self) -> dict[str, Any]:
|
||||
"""
|
||||
Load Meta checkpoint from file. Currently only loading from a single file is supported.
|
||||
"""
|
||||
state_dict: Dict[str, Any] = {}
|
||||
state_dict: dict[str, Any] = {}
|
||||
model_state_dict = safe_torch_load(self._checkpoint_path)
|
||||
if self._model_type == ModelType.LLAMA3_VISION:
|
||||
from torchtune.models.llama3_2_vision._convert_weights import (
|
||||
|
@ -82,7 +82,7 @@ class TorchtuneCheckpointer:
|
|||
|
||||
def save_checkpoint(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
epoch: int,
|
||||
adapter_only: bool = False,
|
||||
checkpoint_format: str | None = None,
|
||||
|
@ -100,7 +100,7 @@ class TorchtuneCheckpointer:
|
|||
def _save_meta_format_checkpoint(
|
||||
self,
|
||||
model_file_path: Path,
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
adapter_only: bool = False,
|
||||
) -> None:
|
||||
model_file_path.mkdir(parents=True, exist_ok=True)
|
||||
|
@ -168,7 +168,7 @@ class TorchtuneCheckpointer:
|
|||
def _save_hf_format_checkpoint(
|
||||
self,
|
||||
model_file_path: Path,
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
) -> None:
|
||||
# the config.json file contains model params needed for state dict conversion
|
||||
config = json.loads(Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text())
|
||||
|
@ -179,7 +179,7 @@ class TorchtuneCheckpointer:
|
|||
repo_id_path = Path.joinpath(self._checkpoint_dir.parent, REPO_ID_FNAME).with_suffix(".json")
|
||||
self.repo_id = None
|
||||
if repo_id_path.exists():
|
||||
with open(repo_id_path, "r") as json_file:
|
||||
with open(repo_id_path) as json_file:
|
||||
data = json.load(json_file)
|
||||
self.repo_id = data.get("repo_id")
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Callable, Dict
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
|
@ -35,7 +35,7 @@ class ModelConfig(BaseModel):
|
|||
checkpoint_type: str
|
||||
|
||||
|
||||
MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
||||
MODEL_CONFIGS: dict[str, ModelConfig] = {
|
||||
"Llama3.2-3B-Instruct": ModelConfig(
|
||||
model_definition=lora_llama3_2_3b,
|
||||
tokenizer_type=llama3_tokenizer,
|
||||
|
@ -48,7 +48,7 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
|||
),
|
||||
}
|
||||
|
||||
DATA_FORMATS: Dict[str, Transform] = {
|
||||
DATA_FORMATS: dict[str, Transform] = {
|
||||
"instruct": InputOutputToMessages,
|
||||
"dialog": ShareGPTToMessages,
|
||||
}
|
||||
|
|
|
@ -4,17 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TorchtunePostTrainingConfig(BaseModel):
|
||||
torch_seed: Optional[int] = None
|
||||
checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
|
||||
torch_seed: int | None = None
|
||||
checkpoint_format: Literal["meta", "huggingface"] | None = "meta"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"checkpoint_format": "meta",
|
||||
}
|
||||
|
|
|
@ -11,7 +11,8 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Any, Mapping
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
|
||||
|
|
|
@ -10,7 +10,8 @@
|
|||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Mapping
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
@ -27,7 +28,7 @@ from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapte
|
|||
class SFTDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
rows: List[Dict[str, Any]],
|
||||
rows: list[dict[str, Any]],
|
||||
message_transform: Transform,
|
||||
model_transform: Transform,
|
||||
dataset_type: str,
|
||||
|
@ -40,11 +41,11 @@ class SFTDataset(Dataset):
|
|||
def __len__(self):
|
||||
return len(self._rows)
|
||||
|
||||
def __getitem__(self, index: int) -> Dict[str, Any]:
|
||||
def __getitem__(self, index: int) -> dict[str, Any]:
|
||||
sample = self._rows[index]
|
||||
return self._prepare_sample(sample)
|
||||
|
||||
def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
|
||||
def _prepare_sample(self, sample: Mapping[str, Any]) -> dict[str, Any]:
|
||||
if self._dataset_type == "instruct":
|
||||
sample = llama_stack_instruct_to_torchtune_instruct(sample)
|
||||
elif self._dataset_type == "dialog":
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
|
@ -64,7 +64,7 @@ class TorchtunePostTrainingImpl:
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
|
||||
def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
|
||||
return JobArtifact(
|
||||
type=TrainingArtifactType.RESOURCES_STATS.value,
|
||||
name=TrainingArtifactType.RESOURCES_STATS.value,
|
||||
|
@ -75,11 +75,11 @@ class TorchtunePostTrainingImpl:
|
|||
self,
|
||||
job_uuid: str,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: Dict[str, Any],
|
||||
logger_config: Dict[str, Any],
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
model: str,
|
||||
checkpoint_dir: Optional[str],
|
||||
algorithm_config: Optional[AlgorithmConfig],
|
||||
checkpoint_dir: str | None,
|
||||
algorithm_config: AlgorithmConfig | None,
|
||||
) -> PostTrainingJob:
|
||||
if isinstance(algorithm_config, LoraFinetuningConfig):
|
||||
|
||||
|
@ -121,8 +121,8 @@ class TorchtunePostTrainingImpl:
|
|||
finetuned_model: str,
|
||||
algorithm_config: DPOAlignmentConfig,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: Dict[str, Any],
|
||||
logger_config: Dict[str, Any],
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
|
@ -144,7 +144,7 @@ class TorchtunePostTrainingImpl:
|
|||
return data[0] if data else None
|
||||
|
||||
@webmethod(route="/post-training/job/status")
|
||||
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
|
||||
match job.status:
|
||||
|
@ -175,6 +175,6 @@ class TorchtunePostTrainingImpl:
|
|||
self._scheduler.cancel(job_uuid)
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts")
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
|
||||
|
|
|
@ -11,7 +11,7 @@ import time
|
|||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
@ -80,10 +80,10 @@ class LoraFinetuningSingleDevice:
|
|||
config: TorchtunePostTrainingConfig,
|
||||
job_uuid: str,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: Dict[str, Any],
|
||||
logger_config: Dict[str, Any],
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
model: str,
|
||||
checkpoint_dir: Optional[str],
|
||||
checkpoint_dir: str | None,
|
||||
algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
|
@ -156,7 +156,7 @@ class LoraFinetuningSingleDevice:
|
|||
self.datasets_api = datasets_api
|
||||
|
||||
async def load_checkpoint(self):
|
||||
def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
|
||||
def get_checkpoint_files(checkpoint_dir: str) -> list[str]:
|
||||
try:
|
||||
# List all files in the given directory
|
||||
files = os.listdir(checkpoint_dir)
|
||||
|
@ -250,8 +250,8 @@ class LoraFinetuningSingleDevice:
|
|||
self,
|
||||
enable_activation_checkpointing: bool,
|
||||
enable_activation_offloading: bool,
|
||||
base_model_state_dict: Dict[str, Any],
|
||||
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
|
||||
base_model_state_dict: dict[str, Any],
|
||||
lora_weights_state_dict: dict[str, Any] | None = None,
|
||||
) -> nn.Module:
|
||||
self._lora_rank = self.algorithm_config.rank
|
||||
self._lora_alpha = self.algorithm_config.alpha
|
||||
|
@ -335,7 +335,7 @@ class LoraFinetuningSingleDevice:
|
|||
tokenizer: Llama3Tokenizer,
|
||||
shuffle: bool,
|
||||
batch_size: int,
|
||||
) -> Tuple[DistributedSampler, DataLoader]:
|
||||
) -> tuple[DistributedSampler, DataLoader]:
|
||||
async def fetch_rows(dataset_id: str):
|
||||
return await self.datasetio_api.iterrows(
|
||||
dataset_id=dataset_id,
|
||||
|
@ -430,7 +430,7 @@ class LoraFinetuningSingleDevice:
|
|||
checkpoint_format=self._checkpoint_format,
|
||||
)
|
||||
|
||||
async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
async def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
# Shape [b, s], needed for the loss not the model
|
||||
labels = batch.pop("labels")
|
||||
# run model
|
||||
|
@ -452,7 +452,7 @@ class LoraFinetuningSingleDevice:
|
|||
|
||||
return loss
|
||||
|
||||
async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
|
||||
async def train(self) -> tuple[dict[str, Any], list[Checkpoint]]:
|
||||
"""
|
||||
The core training loop.
|
||||
"""
|
||||
|
@ -464,7 +464,7 @@ class LoraFinetuningSingleDevice:
|
|||
|
||||
# training artifacts
|
||||
checkpoints = []
|
||||
memory_stats: Dict[str, Any] = {}
|
||||
memory_stats: dict[str, Any] = {}
|
||||
|
||||
# self.epochs_run should be non-zero when we're resuming from a checkpoint
|
||||
for curr_epoch in range(self.epochs_run, self.total_epochs):
|
||||
|
@ -565,7 +565,7 @@ class LoraFinetuningSingleDevice:
|
|||
|
||||
return (memory_stats, checkpoints)
|
||||
|
||||
async def validation(self) -> Tuple[float, float]:
|
||||
async def validation(self) -> tuple[float, float]:
|
||||
total_loss = 0.0
|
||||
total_tokens = 0
|
||||
log.info("Starting validation...")
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeScannerConfig, deps: Dict[str, Any]):
|
||||
async def get_provider_impl(config: CodeScannerConfig, deps: dict[str, Any]):
|
||||
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
|
||||
|
||||
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import (
|
||||
|
@ -48,8 +48,8 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
|||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
messages: list[Message],
|
||||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeScannerConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from .config import LlamaGuardConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: LlamaGuardConfig, deps: Dict[str, Any]):
|
||||
async def get_provider_impl(config: LlamaGuardConfig, deps: dict[str, Any]):
|
||||
from .llama_guard import LlamaGuardSafetyImpl
|
||||
|
||||
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
@ -4,16 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LlamaGuardConfig(BaseModel):
|
||||
excluded_categories: List[str] = []
|
||||
excluded_categories: list[str] = []
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"excluded_categories": [],
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import re
|
||||
from string import Template
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
|
@ -149,8 +149,8 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
messages: list[Message],
|
||||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
|
@ -177,7 +177,7 @@ class LlamaGuardShield:
|
|||
self,
|
||||
model: str,
|
||||
inference_api: Inference,
|
||||
excluded_categories: Optional[List[str]] = None,
|
||||
excluded_categories: list[str] | None = None,
|
||||
):
|
||||
if excluded_categories is None:
|
||||
excluded_categories = []
|
||||
|
@ -193,7 +193,7 @@ class LlamaGuardShield:
|
|||
self.inference_api = inference_api
|
||||
self.excluded_categories = excluded_categories
|
||||
|
||||
def check_unsafe_response(self, response: str) -> Optional[str]:
|
||||
def check_unsafe_response(self, response: str) -> str | None:
|
||||
match = re.match(r"^unsafe\n(.*)$", response)
|
||||
if match:
|
||||
# extracts the unsafe code
|
||||
|
@ -202,7 +202,7 @@ class LlamaGuardShield:
|
|||
|
||||
return None
|
||||
|
||||
def get_safety_categories(self) -> List[str]:
|
||||
def get_safety_categories(self) -> list[str]:
|
||||
excluded_categories = self.excluded_categories
|
||||
if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
|
||||
excluded_categories = []
|
||||
|
@ -218,7 +218,7 @@ class LlamaGuardShield:
|
|||
|
||||
return final_categories
|
||||
|
||||
def validate_messages(self, messages: List[Message]) -> None:
|
||||
def validate_messages(self, messages: list[Message]) -> None:
|
||||
if len(messages) == 0:
|
||||
raise ValueError("Messages must not be empty")
|
||||
if messages[0].role != Role.user.value:
|
||||
|
@ -229,7 +229,7 @@ class LlamaGuardShield:
|
|||
|
||||
return messages
|
||||
|
||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||
async def run(self, messages: list[Message]) -> RunShieldResponse:
|
||||
messages = self.validate_messages(messages)
|
||||
|
||||
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
|
||||
|
@ -247,10 +247,10 @@ class LlamaGuardShield:
|
|||
content = content.strip()
|
||||
return self.get_shield_response(content)
|
||||
|
||||
def build_text_shield_input(self, messages: List[Message]) -> UserMessage:
|
||||
def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
|
||||
return UserMessage(content=self.build_prompt(messages))
|
||||
|
||||
def build_vision_shield_input(self, messages: List[Message]) -> UserMessage:
|
||||
def build_vision_shield_input(self, messages: list[Message]) -> UserMessage:
|
||||
conversation = []
|
||||
most_recent_img = None
|
||||
|
||||
|
@ -284,7 +284,7 @@ class LlamaGuardShield:
|
|||
|
||||
return UserMessage(content=prompt)
|
||||
|
||||
def build_prompt(self, messages: List[Message]) -> str:
|
||||
def build_prompt(self, messages: list[Message]) -> str:
|
||||
categories = self.get_safety_categories()
|
||||
categories_str = "\n".join(categories)
|
||||
conversations_str = "\n\n".join(
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from .config import PromptGuardConfig # noqa: F401
|
||||
from .config import PromptGuardConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: PromptGuardConfig, deps: Dict[str, Any]):
|
||||
async def get_provider_impl(config: PromptGuardConfig, deps: dict[str, Any]):
|
||||
from .prompt_guard import PromptGuardSafetyImpl
|
||||
|
||||
impl = PromptGuardSafetyImpl(config, deps)
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
@ -26,7 +26,7 @@ class PromptGuardConfig(BaseModel):
|
|||
return v
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"guard_type": "injection",
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
@ -49,8 +49,8 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
messages: list[Message],
|
||||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
|
@ -81,7 +81,7 @@ class PromptGuardShield:
|
|||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=self.device)
|
||||
|
||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||
async def run(self, messages: list[Message]) -> RunShieldResponse:
|
||||
message = messages[-1]
|
||||
text = interleaved_content_as_str(message.content)
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
|
@ -12,7 +12,7 @@ from .config import BasicScoringConfig
|
|||
|
||||
async def get_provider_impl(
|
||||
config: BasicScoringConfig,
|
||||
deps: Dict[Api, Any],
|
||||
deps: dict[Api, Any],
|
||||
):
|
||||
from .scoring import BasicScoringImpl
|
||||
|
||||
|
|
|
@ -3,12 +3,12 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BasicScoringConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
|
@ -66,7 +66,7 @@ class BasicScoringImpl(
|
|||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||
async def list_scoring_functions(self) -> list[ScoringFn]:
|
||||
scoring_fn_defs_list = [
|
||||
fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
|
||||
]
|
||||
|
@ -82,7 +82,7 @@ class BasicScoringImpl(
|
|||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
scoring_functions: dict[str, ScoringFnParams | None] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
|
@ -107,8 +107,8 @@ class BasicScoringImpl(
|
|||
|
||||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
input_rows: list[dict[str, Any]],
|
||||
scoring_functions: dict[str, ScoringFnParams | None] = None,
|
||||
) -> ScoreResponse:
|
||||
res = {}
|
||||
for scoring_fn_id in scoring_functions.keys():
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
|
@ -17,7 +17,7 @@ from ..utils.bfcl.checker import ast_checker, is_empty_output
|
|||
from .fn_defs.bfcl import bfcl
|
||||
|
||||
|
||||
def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
|
||||
def postprocess(x: dict[str, Any], test_category: str) -> dict[str, Any]:
|
||||
contain_func_call = False
|
||||
error = None
|
||||
error_type = None
|
||||
|
@ -52,11 +52,11 @@ def postprocess(x: Dict[str, Any], test_category: str) -> Dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def gen_valid(x: Dict[str, Any]) -> Dict[str, float]:
|
||||
def gen_valid(x: dict[str, Any]) -> dict[str, float]:
|
||||
return {"valid": x["valid"]}
|
||||
|
||||
|
||||
def gen_relevance_acc(x: Dict[str, Any]) -> Dict[str, float]:
|
||||
def gen_relevance_acc(x: dict[str, Any]) -> dict[str, float]:
|
||||
# This function serves for both relevance and irrelevance tests, which share the exact opposite logic.
|
||||
# If `test_category` is "irrelevance", the model is expected to output no function call.
|
||||
# No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`).
|
||||
|
@ -78,9 +78,9 @@ class BFCLScoringFn(RegisteredBaseScoringFn):
|
|||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "bfcl",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = "bfcl",
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
|
||||
score_result = postprocess(input_row, test_category)
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
|
@ -228,9 +228,9 @@ class DocVQAScoringFn(RegisteredBaseScoringFn):
|
|||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "docvqa",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = "docvqa",
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
expected_answers = json.loads(input_row["expected_answer"])
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
|
@ -26,9 +26,9 @@ class EqualityScoringFn(RegisteredBaseScoringFn):
|
|||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "equality",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = "equality",
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
assert "expected_answer" in input_row, "Expected answer not found in input row."
|
||||
assert "generated_answer" in input_row, "Generated answer not found in input row."
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
|
@ -28,9 +28,9 @@ class IfEvalScoringFn(RegisteredBaseScoringFn):
|
|||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||
|
@ -28,9 +28,9 @@ class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
|
|||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||
|
@ -28,9 +28,9 @@ class RegexParserScoringFn(RegisteredBaseScoringFn):
|
|||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
|
@ -26,9 +26,9 @@ class SubsetOfScoringFn(RegisteredBaseScoringFn):
|
|||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = "subset_of",
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = "subset_of",
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
|
|
@ -11,8 +11,8 @@ import logging
|
|||
import random
|
||||
import re
|
||||
import string
|
||||
from collections.abc import Iterable, Sequence
|
||||
from types import MappingProxyType
|
||||
from typing import Dict, Iterable, List, Optional, Sequence, Union
|
||||
|
||||
import emoji
|
||||
import langdetect
|
||||
|
@ -1673,12 +1673,11 @@ def split_chinese_japanese_hindi(lines: str) -> Iterable[str]:
|
|||
The separator for hindi is '।'
|
||||
"""
|
||||
for line in lines.splitlines():
|
||||
for sent in re.findall(
|
||||
yield from re.findall(
|
||||
r"[^!?。\.\!\?\!\?\.\n।]+[!?。\.\!\?\!\?\.\n।]?",
|
||||
line.strip(),
|
||||
flags=re.U,
|
||||
):
|
||||
yield sent
|
||||
)
|
||||
|
||||
|
||||
def count_words_cjk(text: str) -> int:
|
||||
|
@ -1707,7 +1706,7 @@ def count_words_cjk(text: str) -> int:
|
|||
return non_asian_words_cnt + asian_chars_cnt + emoji_cnt
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@functools.cache
|
||||
def _get_sentence_tokenizer():
|
||||
return nltk.data.load("nltk:tokenizers/punkt/english.pickle")
|
||||
|
||||
|
@ -1719,8 +1718,8 @@ def count_sentences(text):
|
|||
return len(tokenized_sentences)
|
||||
|
||||
|
||||
def get_langid(text: str, lid_path: Optional[str] = None) -> str:
|
||||
line_langs: List[str] = []
|
||||
def get_langid(text: str, lid_path: str | None = None) -> str:
|
||||
line_langs: list[str] = []
|
||||
lines = [line.strip() for line in text.split("\n") if len(line.strip()) >= 4]
|
||||
|
||||
for line in lines:
|
||||
|
@ -1741,7 +1740,7 @@ def generate_keywords(num_keywords):
|
|||
|
||||
|
||||
"""Library of instructions"""
|
||||
_InstructionArgsDtype = Optional[Dict[str, Union[int, str, Sequence[str]]]]
|
||||
_InstructionArgsDtype = dict[str, int | str | Sequence[str]] | None
|
||||
|
||||
_LANGUAGES = LANGUAGE_CODES
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import re
|
||||
from typing import Sequence
|
||||
from collections.abc import Sequence
|
||||
|
||||
from llama_stack.providers.utils.scoring.basic_scoring_utils import time_limit
|
||||
|
||||
|
@ -323,7 +323,7 @@ def _fix_a_slash_b(string: str) -> str:
|
|||
try:
|
||||
ia = int(a)
|
||||
ib = int(b)
|
||||
assert string == "{}/{}".format(ia, ib)
|
||||
assert string == f"{ia}/{ib}"
|
||||
new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}"
|
||||
return new_string
|
||||
except (ValueError, AssertionError):
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -18,7 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):
|
|||
|
||||
async def get_provider_impl(
|
||||
config: BraintrustScoringConfig,
|
||||
deps: Dict[Api, Any],
|
||||
deps: dict[Api, Any],
|
||||
):
|
||||
from .braintrust import BraintrustScoringImpl
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from autoevals.llm import Factuality
|
||||
from autoevals.ragas import (
|
||||
|
@ -132,7 +132,7 @@ class BraintrustScoringImpl(
|
|||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||
async def list_scoring_functions(self) -> list[ScoringFn]:
|
||||
scoring_fn_defs_list = list(self.supported_fn_defs_registry.values())
|
||||
for f in scoring_fn_defs_list:
|
||||
assert f.identifier.startswith("braintrust"), (
|
||||
|
@ -159,7 +159,7 @@ class BraintrustScoringImpl(
|
|||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
||||
scoring_functions: dict[str, ScoringFnParams | None],
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
await self.set_api_key()
|
||||
|
@ -181,9 +181,7 @@ class BraintrustScoringImpl(
|
|||
results=res.results,
|
||||
)
|
||||
|
||||
async def score_row(
|
||||
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
|
||||
) -> ScoringResultRow:
|
||||
async def score_row(self, input_row: dict[str, Any], scoring_fn_identifier: str | None = None) -> ScoringResultRow:
|
||||
validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
|
||||
await self.set_api_key()
|
||||
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
|
||||
|
@ -203,8 +201,8 @@ class BraintrustScoringImpl(
|
|||
|
||||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]],
|
||||
input_rows: list[dict[str, Any]],
|
||||
scoring_functions: dict[str, ScoringFnParams | None],
|
||||
) -> ScoreResponse:
|
||||
await self.set_api_key()
|
||||
res = {}
|
||||
|
|
|
@ -3,19 +3,19 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BraintrustScoringConfig(BaseModel):
|
||||
openai_api_key: Optional[str] = Field(
|
||||
openai_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="The OpenAI API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_api_key": "${env.OPENAI_API_KEY:}",
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
|
@ -12,7 +12,7 @@ from .config import LlmAsJudgeScoringConfig
|
|||
|
||||
async def get_provider_impl(
|
||||
config: LlmAsJudgeScoringConfig,
|
||||
deps: Dict[Api, Any],
|
||||
deps: dict[Api, Any],
|
||||
):
|
||||
from .scoring import LlmAsJudgeScoringImpl
|
||||
|
||||
|
|
|
@ -3,12 +3,12 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LlmAsJudgeScoringConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
|
@ -50,7 +50,7 @@ class LlmAsJudgeScoringImpl(
|
|||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||
async def list_scoring_functions(self) -> list[ScoringFn]:
|
||||
scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs()
|
||||
|
||||
for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs():
|
||||
|
@ -66,7 +66,7 @@ class LlmAsJudgeScoringImpl(
|
|||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
scoring_functions: dict[str, ScoringFnParams | None] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
|
@ -91,8 +91,8 @@ class LlmAsJudgeScoringImpl(
|
|||
|
||||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
input_rows: list[dict[str, Any]],
|
||||
scoring_functions: dict[str, ScoringFnParams | None] = None,
|
||||
) -> ScoreResponse:
|
||||
res = {}
|
||||
for scoring_fn_id in scoring_functions.keys():
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference.inference import Inference, UserMessage
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
|
@ -30,9 +30,9 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
|
|||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
|
@ -13,7 +13,7 @@ from .config import TelemetryConfig, TelemetrySink
|
|||
__all__ = ["TelemetryConfig", "TelemetrySink"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: TelemetryConfig, deps: Dict[Api, Any]):
|
||||
async def get_provider_impl(config: TelemetryConfig, deps: dict[Api, Any]):
|
||||
from .telemetry import TelemetryAdapter
|
||||
|
||||
impl = TelemetryAdapter(config, deps)
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
@ -33,7 +33,7 @@ class TelemetryConfig(BaseModel):
|
|||
default="",
|
||||
description="The service name to use for telemetry",
|
||||
)
|
||||
sinks: List[TelemetrySink] = Field(
|
||||
sinks: list[TelemetrySink] = Field(
|
||||
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
|
||||
description="List of telemetry sinks to enable (possible values: otel, sqlite, console)",
|
||||
)
|
||||
|
@ -50,7 +50,7 @@ class TelemetryConfig(BaseModel):
|
|||
return v
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> dict[str, Any]:
|
||||
return {
|
||||
"service_name": "${env.OTEL_SERVICE_NAME:}",
|
||||
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
|
||||
|
|
|
@ -78,7 +78,7 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
|
||||
severity = event.attributes.get("severity", "info")
|
||||
message = event.attributes.get("message", event.name)
|
||||
if isinstance(message, (dict, list)):
|
||||
if isinstance(message, dict | list):
|
||||
message = json.dumps(message, indent=2)
|
||||
|
||||
severity_colors = {
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||
|
@ -60,7 +60,7 @@ def is_tracing_enabled(tracer):
|
|||
|
||||
|
||||
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
||||
def __init__(self, config: TelemetryConfig, deps: Dict[Api, Any]) -> None:
|
||||
def __init__(self, config: TelemetryConfig, deps: dict[Api, Any]) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = deps.get(Api.datasetio)
|
||||
self.meter = None
|
||||
|
@ -231,10 +231,10 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
attribute_filters: list[QueryCondition] | None = None,
|
||||
limit: int | None = 100,
|
||||
offset: int | None = 0,
|
||||
order_by: list[str] | None = None,
|
||||
) -> QueryTracesResponse:
|
||||
return QueryTracesResponse(
|
||||
data=await self.trace_store.query_traces(
|
||||
|
@ -254,8 +254,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
attributes_to_return: list[str] | None = None,
|
||||
max_depth: int | None = None,
|
||||
) -> QuerySpanTreeResponse:
|
||||
return QuerySpanTreeResponse(
|
||||
data=await self.trace_store.get_span_tree(
|
||||
|
|
|
@ -1,19 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import CodeInterpreterToolConfig
|
||||
|
||||
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: Dict[str, Any]):
|
||||
from .code_interpreter import CodeInterpreterToolRuntimeImpl
|
||||
|
||||
impl = CodeInterpreterToolRuntimeImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,131 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import errno
|
||||
|
||||
# Disabling potentially dangerous functions
|
||||
import os as _os
|
||||
from functools import partial
|
||||
|
||||
os_funcs_to_disable = [
|
||||
"kill",
|
||||
"system",
|
||||
"putenv",
|
||||
"remove",
|
||||
"removedirs",
|
||||
"rmdir",
|
||||
"fchdir",
|
||||
"setuid",
|
||||
"fork",
|
||||
"forkpty",
|
||||
"killpg",
|
||||
"rename",
|
||||
"renames",
|
||||
"truncate",
|
||||
"replace",
|
||||
# "unlink", # Commenting as this was blocking matpltlib from rendering plots correctly
|
||||
"fchmod",
|
||||
"fchown",
|
||||
"chmod",
|
||||
"chown",
|
||||
"chroot",
|
||||
"fchdir",
|
||||
"lchflags",
|
||||
"lchmod",
|
||||
"lchown",
|
||||
"chdir",
|
||||
]
|
||||
|
||||
|
||||
def call_not_allowed(*args, **kwargs):
|
||||
raise OSError(errno.EPERM, "Call are not permitted in this environment")
|
||||
|
||||
|
||||
for func_name in os_funcs_to_disable:
|
||||
if hasattr(_os, func_name):
|
||||
setattr(_os, func_name, partial(call_not_allowed, _func_name=f"os.{func_name}"))
|
||||
|
||||
import shutil as _shutil
|
||||
|
||||
for func_name in ["rmtree", "move", "chown"]:
|
||||
if hasattr(_shutil, func_name):
|
||||
setattr(
|
||||
_shutil,
|
||||
func_name,
|
||||
partial(call_not_allowed, _func_name=f"shutil.{func_name}"),
|
||||
)
|
||||
|
||||
import subprocess as _subprocess
|
||||
|
||||
|
||||
def popen_not_allowed(*args, **kwargs):
|
||||
raise _subprocess.CalledProcessError(
|
||||
-1,
|
||||
args[0] if args else "unknown",
|
||||
stderr="subprocess.Popen is not allowed in this environment",
|
||||
)
|
||||
|
||||
|
||||
_subprocess.Popen = popen_not_allowed # type: ignore
|
||||
|
||||
|
||||
import atexit as _atexit
|
||||
import builtins as _builtins
|
||||
import io as _io
|
||||
import json as _json
|
||||
import sys as _sys
|
||||
|
||||
# NB! The following "unused" imports crucial, make sure not not to remove
|
||||
# them with linters - they're used in code_execution.py
|
||||
from contextlib import ( # noqa
|
||||
contextmanager as _contextmanager,
|
||||
)
|
||||
from multiprocessing.connection import Connection as _Connection
|
||||
|
||||
# Mangle imports to avoid polluting model execution namespace.
|
||||
|
||||
_IO_SINK = _io.StringIO()
|
||||
_NETWORK_TIMEOUT = 5
|
||||
_NETWORK_CONNECTIONS = None
|
||||
|
||||
|
||||
def _open_connections():
|
||||
global _NETWORK_CONNECTIONS
|
||||
if _NETWORK_CONNECTIONS is not None:
|
||||
# Ensure connections only opened once.
|
||||
return _NETWORK_CONNECTIONS
|
||||
req_w_fd, resp_r_fd = _sys.argv[1], _sys.argv[2]
|
||||
req_con = _Connection(int(req_w_fd), readable=False)
|
||||
resp_con = _Connection(int(resp_r_fd), writable=False)
|
||||
_NETWORK_CONNECTIONS = (req_con, resp_con)
|
||||
return _NETWORK_CONNECTIONS
|
||||
|
||||
|
||||
_builtins._open_connections = _open_connections # type: ignore
|
||||
|
||||
|
||||
@_atexit.register
|
||||
def _close_connections():
|
||||
global _NETWORK_CONNECTIONS
|
||||
if _NETWORK_CONNECTIONS is None:
|
||||
return
|
||||
for con in _NETWORK_CONNECTIONS:
|
||||
con.close()
|
||||
del _NETWORK_CONNECTIONS
|
||||
|
||||
|
||||
def _network_call(request):
|
||||
# NOTE: We communicate with the parent process in json, encoded
|
||||
# in raw bytes. We do this because native send/recv methods use
|
||||
# pickle which involves execution of arbitrary code.
|
||||
_open_connections()
|
||||
req_con, resp_con = _NETWORK_CONNECTIONS
|
||||
|
||||
req_con.send_bytes(_json.dumps(request).encode("utf-8"))
|
||||
if resp_con.poll(timeout=_NETWORK_TIMEOUT) is None:
|
||||
raise Exception(f"Network request timed out: {_json.dumps(request)}")
|
||||
else:
|
||||
return _json.loads(resp_con.recv_bytes().decode("utf-8"))
|
|
@ -1,257 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from .utils import get_code_env_prefix
|
||||
|
||||
TOOLS_ATTACHMENT_KEY = "__tools_attachment__"
|
||||
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||
|
||||
DIRNAME = Path(__file__).parent
|
||||
|
||||
CODE_EXEC_TIMEOUT = 20
|
||||
CODE_ENV_PREFIX = get_code_env_prefix()
|
||||
|
||||
STDOUTERR_SINK_WRAPPER_TEMPLATE = """\
|
||||
with _redirect_stdout(_IO_SINK), _redirect_stderr(_IO_SINK):
|
||||
{code}\
|
||||
"""
|
||||
|
||||
TRYEXCEPT_WRAPPER_TEMPLATE = """\
|
||||
try:
|
||||
{code}
|
||||
except:
|
||||
pass\
|
||||
"""
|
||||
|
||||
|
||||
def generate_bwrap_command(bind_dirs: List[str]) -> str:
|
||||
"""
|
||||
Generate the bwrap command string for binding all
|
||||
directories in the current directory read-only.
|
||||
"""
|
||||
bwrap_args = ""
|
||||
bwrap_args += "--ro-bind / / "
|
||||
# Add the --dev flag to mount device files
|
||||
bwrap_args += "--dev /dev "
|
||||
for d in bind_dirs:
|
||||
bwrap_args += f"--bind {d} {d} "
|
||||
|
||||
# Add the --unshare-all flag to isolate the sandbox from the rest of the system
|
||||
bwrap_args += "--unshare-all "
|
||||
# Add the --die-with-parent flag to ensure the child process dies when bwrap's parent dies
|
||||
bwrap_args += "--die-with-parent "
|
||||
return bwrap_args
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeExecutionContext:
|
||||
matplotlib_dump_dir: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeExecutionRequest:
|
||||
scripts: List[str]
|
||||
only_last_cell_stdouterr: bool = True
|
||||
only_last_cell_fail: bool = True
|
||||
seed: int = 0
|
||||
strip_fpaths_in_stderr: bool = True
|
||||
use_bwrap: bool = True
|
||||
|
||||
|
||||
class CodeExecutor:
|
||||
def __init__(self, context: CodeExecutionContext):
|
||||
self.context = context
|
||||
|
||||
def execute(self, req: CodeExecutionRequest) -> dict:
|
||||
scripts = req.scripts
|
||||
for i in range(len(scripts) - 1):
|
||||
if req.only_last_cell_stdouterr:
|
||||
scripts[i] = STDOUTERR_SINK_WRAPPER_TEMPLATE.format(code=textwrap.indent(scripts[i], " " * 4))
|
||||
if req.only_last_cell_fail:
|
||||
scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format(code=textwrap.indent(scripts[i], " " * 4))
|
||||
|
||||
# Seeds prefix:
|
||||
seed = req.seed
|
||||
seeds_prefix = f"""\
|
||||
def _set_seeds():
|
||||
import random
|
||||
random.seed({seed})
|
||||
import numpy as np
|
||||
np.random.seed({seed})
|
||||
_set_seeds()\
|
||||
"""
|
||||
|
||||
script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts)
|
||||
with tempfile.TemporaryDirectory() as dpath:
|
||||
code_fpath = os.path.join(dpath, "code.py")
|
||||
with open(code_fpath, "w") as f:
|
||||
f.write(script)
|
||||
|
||||
try:
|
||||
python_path = os.environ.get("PYTHONPATH", "")
|
||||
env = dict(
|
||||
os.environ,
|
||||
PYTHONHASHSEED=str(seed),
|
||||
MPLCONFIGDIR=dpath,
|
||||
MPLBACKEND="module://matplotlib_custom_backend",
|
||||
PYTHONPATH=f"{DIRNAME}:{python_path}",
|
||||
)
|
||||
|
||||
if req.use_bwrap:
|
||||
bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
|
||||
cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
|
||||
else:
|
||||
cmd = [sys.executable, "-c", script]
|
||||
|
||||
stdout, stderr, returncode = do_subprocess(
|
||||
cmd=cmd,
|
||||
env=env,
|
||||
ctx=self.context,
|
||||
)
|
||||
|
||||
stderr = stderr.strip()
|
||||
if req.strip_fpaths_in_stderr:
|
||||
pattern = r'File "([^"]+)", line (\d+)'
|
||||
stderr = re.sub(pattern, r"line \2", stderr)
|
||||
|
||||
return {
|
||||
"process_status": "completed",
|
||||
"returncode": returncode,
|
||||
"stdout": stdout.strip(),
|
||||
"stderr": stderr,
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {
|
||||
"process_status": "timeout",
|
||||
"stdout": "Timed out",
|
||||
"stderr": "Timed out",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"process_status": "error",
|
||||
"error_type": type(e).__name__,
|
||||
"stderr": str(e),
|
||||
"stdout": str(e),
|
||||
}
|
||||
|
||||
|
||||
def process_matplotlib_response(response, matplotlib_dump_dir: str):
|
||||
image_data = response["image_data"]
|
||||
# Convert the base64 string to a bytes object
|
||||
images_raw = [base64.b64decode(d["image_base64"]) for d in image_data]
|
||||
# Create a list of PIL images from the bytes objects
|
||||
images = [Image.open(BytesIO(img)) for img in images_raw]
|
||||
# Create a list of image paths
|
||||
image_paths = []
|
||||
for i, img in enumerate(images):
|
||||
# create new directory for each day to better organize data:
|
||||
dump_dname = datetime.today().strftime("%Y-%m-%d") # noqa: DTZ002 - we don't care about timezones here since we are displaying the date
|
||||
dump_dpath = Path(matplotlib_dump_dir, dump_dname)
|
||||
dump_dpath.mkdir(parents=True, exist_ok=True)
|
||||
# save image into a file
|
||||
dump_fname = f"matplotlib_{str(time.time()).replace('.', '_')}_{i}.png"
|
||||
dump_fpath = dump_dpath / dump_fname
|
||||
img.save(dump_fpath, "PNG")
|
||||
image_paths.append(str(dump_fpath))
|
||||
|
||||
# this is kind of convoluted, we send back this response to the subprocess which
|
||||
# prints it out
|
||||
info = {
|
||||
"filepath": str(image_paths[-1]),
|
||||
"mimetype": "image/png",
|
||||
}
|
||||
return f"{TOOLS_ATTACHMENT_KEY}={json.dumps(info)}"
|
||||
|
||||
|
||||
def execute_subprocess_request(request, ctx: CodeExecutionContext):
|
||||
"Route requests from the subprocess (via network Pipes) to the internet/tools."
|
||||
if request["type"] == "matplotlib":
|
||||
return process_matplotlib_response(request, ctx.matplotlib_dump_dir)
|
||||
else:
|
||||
raise Exception(f"Unrecognised network request type: {request['type']}")
|
||||
|
||||
|
||||
def do_subprocess(*, cmd: list, env: dict, ctx: CodeExecutionContext):
|
||||
# Create Pipes to be used for any external tool/network requests.
|
||||
req_r, req_w = multiprocessing.Pipe(duplex=False)
|
||||
resp_r, resp_w = multiprocessing.Pipe(duplex=False)
|
||||
|
||||
cmd += [str(req_w.fileno()), str(resp_r.fileno())]
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
pass_fds=(req_w.fileno(), resp_r.fileno()),
|
||||
text=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
close_fds=True,
|
||||
env=env,
|
||||
)
|
||||
|
||||
# Close unnecessary fds.
|
||||
req_w.close()
|
||||
resp_r.close()
|
||||
|
||||
pipe_close = False
|
||||
done_read = False
|
||||
start = time.monotonic()
|
||||
while proc.poll() is None and not pipe_close:
|
||||
if req_r.poll(0.1):
|
||||
# NB: Python pipe semantics for poll and recv mean that
|
||||
# poll() returns True is a pipe is closed.
|
||||
# CF old school PEP from '09
|
||||
# https://bugs.python.org/issue5573
|
||||
try:
|
||||
request = json.loads(req_r.recv_bytes().decode("utf-8"))
|
||||
response = execute_subprocess_request(request, ctx)
|
||||
|
||||
resp_w.send_bytes(json.dumps(response).encode("utf-8"))
|
||||
except EOFError:
|
||||
# The request pipe is closed - set a marker to exit
|
||||
# after the next attempt at reading stdout/stderr.
|
||||
pipe_close = True
|
||||
|
||||
try:
|
||||
# If lots has been printed, pipe might be full but
|
||||
# proc cannot exit until all the stdout/stderr
|
||||
# been written/read.
|
||||
stdout, stderr = proc.communicate(timeout=0.3)
|
||||
done_read = True
|
||||
except subprocess.TimeoutExpired:
|
||||
# The program has not terminated. Ignore it, there
|
||||
# may be more network/tool requests.
|
||||
continue
|
||||
if time.monotonic() - start > CODE_EXEC_TIMEOUT:
|
||||
proc.terminate()
|
||||
raise subprocess.TimeoutExpired(cmd, CODE_EXEC_TIMEOUT)
|
||||
|
||||
if not done_read:
|
||||
# Solve race condition where process terminates before
|
||||
# we hit the while loop.
|
||||
stdout, stderr = proc.communicate(timeout=0.3)
|
||||
|
||||
resp_w.close()
|
||||
req_r.close()
|
||||
return stdout, stderr, proc.returncode
|
|
@ -1,80 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
Tool,
|
||||
ToolDef,
|
||||
ToolInvocationResult,
|
||||
ToolParameter,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
|
||||
from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor
|
||||
from .config import CodeInterpreterToolConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||
def __init__(self, config: CodeInterpreterToolConfig):
|
||||
self.config = config
|
||||
ctx = CodeExecutionContext(
|
||||
matplotlib_dump_dir=tempfile.mkdtemp(),
|
||||
)
|
||||
self.code_executor = CodeExecutor(ctx)
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def register_tool(self, tool: Tool) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> None:
|
||||
return
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
ToolDef(
|
||||
name="code_interpreter",
|
||||
description="Execute code",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="code",
|
||||
description="The code to execute",
|
||||
parameter_type="string",
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
script = kwargs["code"]
|
||||
# Use environment variable to control bwrap usage
|
||||
force_disable_bwrap = os.environ.get("DISABLE_CODE_SANDBOX", "").lower() in ("1", "true", "yes")
|
||||
req = CodeExecutionRequest(scripts=[script], use_bwrap=not force_disable_bwrap)
|
||||
res = await asyncio.to_thread(self.code_executor.execute, req)
|
||||
pieces = [res["process_status"]]
|
||||
for out_type in ["stdout", "stderr"]:
|
||||
res_out = res[out_type]
|
||||
if res_out != "":
|
||||
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
|
||||
if out_type == "stderr":
|
||||
log.error(f"ipython tool error: ↓\n{res_out}")
|
||||
return ToolInvocationResult(content="\n".join(pieces))
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeInterpreterToolConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {}
|
|
@ -1,93 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
A custom Matplotlib backend that overrides the show method to return image bytes.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json as _json
|
||||
import logging
|
||||
|
||||
import matplotlib
|
||||
from matplotlib.backend_bases import FigureManagerBase
|
||||
|
||||
# Import necessary components from Matplotlib
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomFigureCanvas(FigureCanvasAgg):
|
||||
def show(self):
|
||||
# Save the figure to a BytesIO object
|
||||
buf = io.BytesIO()
|
||||
self.print_png(buf)
|
||||
image_bytes = buf.getvalue()
|
||||
buf.close()
|
||||
return image_bytes
|
||||
|
||||
|
||||
class CustomFigureManager(FigureManagerBase):
|
||||
def __init__(self, canvas, num):
|
||||
super().__init__(canvas, num)
|
||||
|
||||
|
||||
# Mimic module initialization that integrates with the Matplotlib backend system
|
||||
def _create_figure_manager(num, *args, **kwargs):
|
||||
"""
|
||||
Create a custom figure manager instance.
|
||||
"""
|
||||
FigureClass = kwargs.pop("FigureClass", None) # noqa: N806
|
||||
if FigureClass is None:
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
FigureClass = Figure # noqa: N806
|
||||
fig = FigureClass(*args, **kwargs)
|
||||
canvas = CustomFigureCanvas(fig)
|
||||
manager = CustomFigureManager(canvas, num)
|
||||
return manager
|
||||
|
||||
|
||||
def show():
|
||||
"""
|
||||
Handle all figures and potentially return their images as bytes.
|
||||
|
||||
This function iterates over all figures registered with the custom backend,
|
||||
renders them as images in bytes format, and could return a list of bytes objects,
|
||||
one for each figure, or handle them as needed.
|
||||
"""
|
||||
image_data = []
|
||||
for manager in matplotlib._pylab_helpers.Gcf.get_all_fig_managers():
|
||||
# Get the figure from the manager
|
||||
fig = manager.canvas.figure
|
||||
buf = io.BytesIO() # Create a buffer for the figure
|
||||
fig.savefig(buf, format="png") # Save the figure to the buffer in PNG format
|
||||
buf.seek(0) # Go to the beginning of the buffer
|
||||
image_bytes = buf.getvalue() # Retrieve bytes value
|
||||
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
image_data.append({"image_base64": image_base64})
|
||||
buf.close()
|
||||
|
||||
# The _open_connections method is dynamically made available to
|
||||
# the interpreter by bundling code from "code_env_prefix.py" -- by literally prefixing it -- and
|
||||
# then "eval"ing it within a sandboxed interpreter.
|
||||
req_con, resp_con = _open_connections() # noqa: F821
|
||||
|
||||
_json_dump = _json.dumps(
|
||||
{
|
||||
"type": "matplotlib",
|
||||
"image_data": image_data,
|
||||
}
|
||||
)
|
||||
req_con.send_bytes(_json_dump.encode("utf-8"))
|
||||
resp = _json.loads(resp_con.recv_bytes().decode("utf-8"))
|
||||
log.info(resp)
|
||||
|
||||
|
||||
FigureCanvas = CustomFigureCanvas
|
||||
FigureManager = CustomFigureManager
|
|
@ -1,21 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
CODE_ENV_PREFIX_FILE = os.path.join(DIR, "code_env_prefix.py")
|
||||
CODE_ENV_PREFIX = None
|
||||
|
||||
|
||||
def get_code_env_prefix() -> str:
|
||||
global CODE_ENV_PREFIX
|
||||
|
||||
if CODE_ENV_PREFIX is None:
|
||||
with open(CODE_ENV_PREFIX_FILE, "r") as f:
|
||||
CODE_ENV_PREFIX = f.read()
|
||||
|
||||
return CODE_ENV_PREFIX
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import RagToolRuntimeConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[Api, Any]):
|
||||
async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
|
||||
from .memory import MemoryToolRuntimeImpl
|
||||
|
||||
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RagToolRuntimeConfig(BaseModel):
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {}
|
||||
|
|
|
@ -8,7 +8,7 @@ import asyncio
|
|||
import logging
|
||||
import secrets
|
||||
import string
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
|
@ -74,7 +74,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
|
||||
async def insert(
|
||||
self,
|
||||
documents: List[RAGDocument],
|
||||
documents: list[RAGDocument],
|
||||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
|
@ -101,8 +101,8 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
async def query(
|
||||
self,
|
||||
content: InterleavedContent,
|
||||
vector_db_ids: List[str],
|
||||
query_config: Optional[RAGQueryConfig] = None,
|
||||
vector_db_ids: list[str],
|
||||
query_config: RAGQueryConfig | None = None,
|
||||
) -> RAGQueryResult:
|
||||
if not vector_db_ids:
|
||||
return RAGQueryResult(content=None)
|
||||
|
@ -123,7 +123,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
)
|
||||
for vector_db_id in vector_db_ids
|
||||
]
|
||||
results: List[QueryChunksResponse] = await asyncio.gather(*tasks)
|
||||
results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
|
||||
chunks = [c for r in results for c in r.chunks]
|
||||
scores = [s for r in results for s in r.scores]
|
||||
|
||||
|
@ -168,7 +168,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
)
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
) -> ListToolDefsResponse:
|
||||
# Parameters are not listed since these methods are not yet invoked automatically
|
||||
# by the LLM. The method is only implemented so things like /tools can list without
|
||||
|
@ -193,7 +193,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
vector_db_ids = kwargs.get("vector_db_ids", [])
|
||||
query_config = kwargs.get("query_config")
|
||||
if query_config:
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import ChromaVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, Any]):
|
||||
async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import (
|
||||
ChromaVectorIOAdapter,
|
||||
)
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -13,5 +13,5 @@ class ChromaVectorIOConfig(BaseModel):
|
|||
db_path: str
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> dict[str, Any]:
|
||||
return {"db_path": db_path}
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import FaissVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, Any]):
|
||||
async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
|
||||
from .faiss import FaissVectorIOAdapter
|
||||
|
||||
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -20,7 +20,7 @@ class FaissVectorIOConfig(BaseModel):
|
|||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
|
|
|
@ -9,7 +9,7 @@ import base64
|
|||
import io
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
@ -84,7 +84,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||
# Add dimension check
|
||||
embedding_dim = embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
|
||||
if embedding_dim != self.index.d:
|
||||
|
@ -159,7 +159,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
async def list_vector_dbs(self) -> List[VectorDB]:
|
||||
async def list_vector_dbs(self) -> list[VectorDB]:
|
||||
return [i.vector_db for i in self.cache.values()]
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
|
@ -176,8 +176,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: List[Chunk],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
index = self.cache.get(vector_db_id)
|
||||
if index is None:
|
||||
|
@ -189,7 +189,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
index = self.cache.get(vector_db_id)
|
||||
if index is None:
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import MilvusVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, Any]):
|
||||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
|
||||
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference])
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -16,5 +16,5 @@ class MilvusVectorIOConfig(BaseModel):
|
|||
db_path: str
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {"db_path": "${env.MILVUS_DB_PATH}"}
|
||||
|
|
|
@ -4,14 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import QdrantVectorIOConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||
|
||||
impl = QdrantVectorIOAdapter(config, deps[Api.inference])
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -17,7 +17,7 @@ class QdrantVectorIOConfig(BaseModel):
|
|||
path: str
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"path": "${env.QDRANT_PATH:~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
|
||||
}
|
||||
|
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import SQLiteVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, Any]):
|
||||
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
|
||||
from .sqlite_vec import SQLiteVecVectorIOAdapter
|
||||
|
||||
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -13,7 +13,7 @@ class SQLiteVectorIOConfig(BaseModel):
|
|||
db_path: str
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + "sqlite_vec.db",
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ import logging
|
|||
import sqlite3
|
||||
import struct
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import sqlite_vec
|
||||
|
@ -25,7 +25,7 @@ from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, Vect
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def serialize_vector(vector: List[float]) -> bytes:
|
||||
def serialize_vector(vector: list[float]) -> bytes:
|
||||
"""Serialize a list of floats into a compact binary representation."""
|
||||
return struct.pack(f"{len(vector)}f", *vector)
|
||||
|
||||
|
@ -98,7 +98,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
|
||||
await asyncio.to_thread(_drop_tables)
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray, batch_size: int = 500):
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray, batch_size: int = 500):
|
||||
"""
|
||||
Add new chunks along with their embeddings using batch inserts.
|
||||
For each chunk, we insert its JSON into the metadata table and then insert its
|
||||
|
@ -209,7 +209,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
def __init__(self, config, inference_api: Inference) -> None:
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.cache: Dict[str, VectorDBWithIndex] = {}
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
def _setup_connection():
|
||||
|
@ -264,7 +264,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
async def list_vector_dbs(self) -> List[VectorDB]:
|
||||
async def list_vector_dbs(self) -> list[VectorDB]:
|
||||
return [v.vector_db for v in self.cache.values()]
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
|
@ -286,7 +286,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
|
||||
await asyncio.to_thread(_delete_vector_db_from_registry)
|
||||
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None:
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
|
||||
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
|
||||
|
@ -294,7 +294,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
await self.cache[vector_db_id].insert_chunks(chunks)
|
||||
|
||||
async def query_chunks(
|
||||
self, vector_db_id: str, query: Any, params: Optional[Dict[str, Any]] = None
|
||||
self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
|
||||
) -> QueryChunksResponse:
|
||||
if vector_db_id not in self.cache:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
|
@ -303,5 +303,5 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
|
||||
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
|
||||
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
|
||||
hash_input = f"{document_id}:{chunk_text}".encode("utf-8")
|
||||
hash_input = f"{document_id}:{chunk_text}".encode()
|
||||
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
Api,
|
||||
|
@ -14,7 +13,7 @@ from llama_stack.providers.datatypes import (
|
|||
from llama_stack.providers.utils.kvstore import kvstore_dependencies
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.agents,
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
|
@ -15,7 +14,7 @@ from llama_stack.providers.datatypes import (
|
|||
)
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.datasetio,
|
||||
|
|
|
@ -4,12 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.eval,
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
|
@ -29,7 +28,7 @@ META_REFERENCE_DEPS = [
|
|||
]
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
|
|
|
@ -4,12 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.post_training,
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
|
@ -15,7 +14,7 @@ from llama_stack.providers.datatypes import (
|
|||
)
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.safety,
|
||||
|
|
|
@ -4,12 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.scoring,
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
Api,
|
||||
|
@ -13,7 +12,7 @@ from llama_stack.providers.datatypes import (
|
|||
)
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.telemetry,
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
|
@ -15,7 +14,7 @@ from llama_stack.providers.datatypes import (
|
|||
)
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
|
@ -36,13 +35,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
|
||||
api_dependencies=[Api.vector_io, Api.inference],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
provider_type="inline::code-interpreter",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.inline.tool_runtime.code_interpreter",
|
||||
config_class="llama_stack.providers.inline.tool_runtime.code_interpreter.config.CodeInterpreterToolConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
|
@ -15,7 +14,7 @@ from llama_stack.providers.datatypes import (
|
|||
)
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -17,7 +17,7 @@ class HuggingfaceDatasetIOConfig(BaseModel):
|
|||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue