From c223b1862b1740494c354eac8f26ee9f9996f9f3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Han?=
Date: Tue, 25 Feb 2025 20:06:47 +0100
Subject: [PATCH] fix: resolve type hint issues and import dependencies
 (#1176)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# What does this PR do?

- Fixed type hinting and missing imports across multiple modules.
- Improved compatibility by using `TYPE_CHECKING` for conditional imports (see the sketch below).
- Updated `pyproject.toml` to enforce stricter linting.
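For reviewers unfamiliar with the pattern, here is a minimal sketch of the `TYPE_CHECKING` idiom this PR applies. It is illustrative only, not code from this repo; `load_model` is a hypothetical helper:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, pyright); skipped at
    # runtime, so the heavy dependency is never imported on module load.
    from sentence_transformers import SentenceTransformer


def load_model(name: str) -> "SentenceTransformer":
    # The quoted annotation is a forward reference: it is not evaluated at
    # function-definition time, so the name only needs to exist for the
    # type checker, not at import time.
    from sentence_transformers import SentenceTransformer  # lazy runtime import

    return SentenceTransformer(name)
```

The checker sees the guarded import and resolves the annotation, while the interpreter pays no import cost until the model is actually loaded.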
Signed-off-by: Sébastien Han
---
 .../inline/inference/meta_reference/parallel_utils.py      | 2 +-
 .../inference/meta_reference/quantization/loader.py        | 1 +
 .../code_interpreter/matplotlib_custom_backend.py          | 5 ++++-
 .../providers/remote/inference/databricks/databricks.py   | 1 +
 .../providers/tests/post_training/test_post_training.py   | 2 ++
 llama_stack/providers/utils/inference/embedding_mixin.py  | 7 +++++--
 pyproject.toml                                            | 1 -
 7 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
index 711a4632d..658267f7f 100644
--- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
+++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
@@ -231,7 +231,7 @@ def worker_process_entrypoint(
     while True:
         try:
             task = req_gen.send(result)
-            if isinstance(task, str) and task == _END_SENTINEL:
+            if isinstance(task, str) and task == EndSentinel():
                 break
 
             assert isinstance(task, TaskRequest)
diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py
index a2dc00916..ba45acc2b 100644
--- a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py
+++ b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py
@@ -12,6 +12,7 @@ import os
 from typing import Any, Dict, List, Optional
 
 import torch
+from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
 from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 from llama_models.llama3.api.args import ModelArgs
diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/matplotlib_custom_backend.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/matplotlib_custom_backend.py
index 7fec08cf2..6454358a5 100644
--- a/llama_stack/providers/inline/tool_runtime/code_interpreter/matplotlib_custom_backend.py
+++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/matplotlib_custom_backend.py
@@ -73,7 +73,10 @@ def show():
         image_data.append({"image_base64": image_base64})
         buf.close()
 
-    req_con, resp_con = _open_connections()
+    # The _open_connections method is dynamically made available to
+    # the interpreter by bundling code from "code_env_prefix.py" -- by literally prefixing it -- and
+    # then "eval"ing it within a sandboxed interpreter.
+    req_con, resp_con = _open_connections()  # noqa: F821
 
     _json_dump = _json.dumps(
         {
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index 75751c8b1..9db430e4d 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
     SamplingParams,
     TextTruncation,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
 )
diff --git a/llama_stack/providers/tests/post_training/test_post_training.py b/llama_stack/providers/tests/post_training/test_post_training.py
index c2bb4d98b..aefef5332 100644
--- a/llama_stack/providers/tests/post_training/test_post_training.py
+++ b/llama_stack/providers/tests/post_training/test_post_training.py
@@ -3,6 +3,8 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from typing import List
+
 import pytest
 
 from llama_stack.apis.common.job_types import JobStatus
diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py
index f43475554..2a16e7f40 100644
--- a/llama_stack/providers/utils/inference/embedding_mixin.py
+++ b/llama_stack/providers/utils/inference/embedding_mixin.py
@@ -5,7 +5,10 @@
 # the root directory of this source tree.
 
 import logging
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
+
+if TYPE_CHECKING:
+    from sentence_transformers import SentenceTransformer
 
 from llama_stack.apis.inference import (
     EmbeddingsResponse,
@@ -40,7 +43,7 @@ class SentenceTransformerEmbeddingMixin:
         )
         return EmbeddingsResponse(embeddings=embeddings)
 
-    def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
+    def _load_sentence_transformer_model(self, model: str) -> SentenceTransformer:
         global EMBEDDING_MODELS
 
         loaded_model = EMBEDDING_MODELS.get(model)
diff --git a/pyproject.toml b/pyproject.toml
index d65f30c30..2ed2c4fa9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -129,7 +129,6 @@ ignore = [
     "E721",
     "E741",
     "F405",
-    "F821",
     "F841",
     "C408", # ignored because we like the dict keyword argument syntax
     "E302",