From 926c3ada41a40e034cd3b57067107e755b747ed1 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 26 Sep 2025 11:44:43 -0400 Subject: [PATCH] chore: prune mypy exclude list (#3561) # What does this PR do? prune the mypy exclude list, build a stronger foundation for quality code ## Test Plan ci --- llama_stack/apis/inference/inference.py | 1 + .../remote/inference/databricks/databricks.py | 2 +- .../remote/inference/groq/__init__.py | 4 +--- .../remote/inference/sambanova/__init__.py | 4 +--- .../remote/inference/sambanova/sambanova.py | 2 +- .../utils/inference/model_registry.py | 2 +- llama_stack/providers/utils/kvstore/config.py | 10 +++++----- .../utils/kvstore/mongodb/mongodb.py | 11 +++++++--- .../providers/utils/kvstore/sqlite/sqlite.py | 7 +++++++ pyproject.toml | 20 ------------------- 10 files changed, 26 insertions(+), 37 deletions(-) diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 4f5332b5f..c43cee6a8 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -914,6 +914,7 @@ class OpenAIEmbeddingData(BaseModel): """ object: Literal["embedding"] = "embedding" + # TODO: consider dropping str and using openai.types.embeddings.Embedding instead of OpenAIEmbeddingData embedding: list[float] | str index: int diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 25fd9f3b7..6eac6e4f4 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -24,7 +24,6 @@ from llama_stack.apis.inference import ( LogProbConfig, Message, Model, - ModelType, OpenAICompletion, ResponseFormat, SamplingParams, @@ -34,6 +33,7 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) +from llama_stack.apis.models import ModelType from llama_stack.log import get_logger from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin diff --git a/llama_stack/providers/remote/inference/groq/__init__.py b/llama_stack/providers/remote/inference/groq/__init__.py index 1506e0b06..cca333ccf 100644 --- a/llama_stack/providers/remote/inference/groq/__init__.py +++ b/llama_stack/providers/remote/inference/groq/__init__.py @@ -4,12 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.apis.inference import Inference - from .config import GroqConfig -async def get_adapter_impl(config: GroqConfig, _deps) -> Inference: +async def get_adapter_impl(config: GroqConfig, _deps): # import dynamically so the import is used only when it is needed from .groq import GroqInferenceAdapter diff --git a/llama_stack/providers/remote/inference/sambanova/__init__.py b/llama_stack/providers/remote/inference/sambanova/__init__.py index a3a7b8fbd..2a5448041 100644 --- a/llama_stack/providers/remote/inference/sambanova/__init__.py +++ b/llama_stack/providers/remote/inference/sambanova/__init__.py @@ -4,12 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.apis.inference import Inference - from .config import SambaNovaImplConfig -async def get_adapter_impl(config: SambaNovaImplConfig, _deps) -> Inference: +async def get_adapter_impl(config: SambaNovaImplConfig, _deps): from .sambanova import SambaNovaInferenceAdapter assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}" diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 6121e81f7..4d8fd11cd 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -25,7 +25,7 @@ class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): def __init__(self, config: SambaNovaImplConfig): self.config = config - self.environment_available_models = [] + self.environment_available_models: list[str] = [] LiteLLMOpenAIMixin.__init__( self, litellm_provider_name="sambanova", diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index ff15b2d43..746ebd8f6 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils") class RemoteInferenceProviderConfig(BaseModel): - allowed_models: list[str] | None = Field( + allowed_models: list[str] | None = Field( # TODO: make this non-optional and give a list() default default=None, description="List of models that should be registered with the model registry. If None, all models are allowed.", ) diff --git a/llama_stack/providers/utils/kvstore/config.py b/llama_stack/providers/utils/kvstore/config.py index d1747d65b..7b6a79350 100644 --- a/llama_stack/providers/utils/kvstore/config.py +++ b/llama_stack/providers/utils/kvstore/config.py @@ -28,7 +28,7 @@ class CommonConfig(BaseModel): class RedisKVStoreConfig(CommonConfig): - type: Literal[KVStoreType.redis.value] = KVStoreType.redis.value + type: Literal["redis"] = KVStoreType.redis.value host: str = "localhost" port: int = 6379 @@ -50,7 +50,7 @@ class RedisKVStoreConfig(CommonConfig): class SqliteKVStoreConfig(CommonConfig): - type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value + type: Literal["sqlite"] = KVStoreType.sqlite.value db_path: str = Field( default=(RUNTIME_BASE_DIR / "kvstore.db").as_posix(), description="File path for the sqlite database", @@ -69,7 +69,7 @@ class SqliteKVStoreConfig(CommonConfig): class PostgresKVStoreConfig(CommonConfig): - type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value + type: Literal["postgres"] = KVStoreType.postgres.value host: str = "localhost" port: int = 5432 db: str = "llamastack" @@ -113,11 +113,11 @@ class PostgresKVStoreConfig(CommonConfig): class MongoDBKVStoreConfig(CommonConfig): - type: Literal[KVStoreType.mongodb.value] = KVStoreType.mongodb.value + type: Literal["mongodb"] = KVStoreType.mongodb.value host: str = "localhost" port: int = 27017 db: str = "llamastack" - user: str = None + user: str | None = None password: str | None = None collection_name: str = "llamastack_kvstore" diff --git a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py index bab87a4aa..4d60949c1 100644 --- a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py +++ b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py @@ -7,6 +7,7 @@ from datetime import datetime from pymongo import AsyncMongoClient +from pymongo.asynchronous.collection import AsyncCollection from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore @@ -19,8 +20,13 @@ log = get_logger(name=__name__, category="providers::utils") class MongoDBKVStoreImpl(KVStore): def __init__(self, config: MongoDBKVStoreConfig): self.config = config - self.conn = None - self.collection = None + self.conn: AsyncMongoClient | None = None + + @property + def collection(self) -> AsyncCollection: + if self.conn is None: + raise RuntimeError("MongoDB connection is not initialized") + return self.conn[self.config.db][self.config.collection_name] async def initialize(self) -> None: try: @@ -32,7 +38,6 @@ class MongoDBKVStoreImpl(KVStore): } conn_creds = {k: v for k, v in conn_creds.items() if v is not None} self.conn = AsyncMongoClient(**conn_creds) - self.collection = self.conn[self.config.db][self.config.collection_name] except Exception as e: log.exception("Could not connect to MongoDB database server") raise RuntimeError("Could not connect to MongoDB database server") from e diff --git a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py index 6a6a170dc..5b782902e 100644 --- a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py +++ b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py @@ -9,9 +9,13 @@ from datetime import datetime import aiosqlite +from llama_stack.log import get_logger + from ..api import KVStore from ..config import SqliteKVStoreConfig +logger = get_logger(name=__name__, category="providers::utils") + class SqliteKVStoreImpl(KVStore): def __init__(self, config: SqliteKVStoreConfig): @@ -50,6 +54,9 @@ class SqliteKVStoreImpl(KVStore): if row is None: return None value, expiration = row + if not isinstance(value, str): + logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None") + return None return value async def delete(self, key: str) -> None: diff --git a/pyproject.toml b/pyproject.toml index 86a32f978..a26c4d645 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -259,15 +259,12 @@ exclude = [ "^llama_stack/models/llama/llama3/tokenizer\\.py$", "^llama_stack/models/llama/llama3/tool_utils\\.py$", "^llama_stack/providers/inline/agents/meta_reference/", - "^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$", - "^llama_stack/providers/inline/agents/meta_reference/agents\\.py$", "^llama_stack/providers/inline/datasetio/localfs/", "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$", "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$", "^llama_stack/models/llama/llama3/generation\\.py$", "^llama_stack/models/llama/llama3/multimodal/model\\.py$", "^llama_stack/models/llama/llama4/", - "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$", "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$", "^llama_stack/providers/inline/post_training/common/validator\\.py$", "^llama_stack/providers/inline/safety/code_scanner/", @@ -278,19 +275,13 @@ exclude = [ "^llama_stack/providers/remote/agents/sample/", "^llama_stack/providers/remote/datasetio/huggingface/", "^llama_stack/providers/remote/datasetio/nvidia/", - "^llama_stack/providers/remote/inference/anthropic/", "^llama_stack/providers/remote/inference/bedrock/", "^llama_stack/providers/remote/inference/cerebras/", "^llama_stack/providers/remote/inference/databricks/", "^llama_stack/providers/remote/inference/fireworks/", - "^llama_stack/providers/remote/inference/gemini/", - "^llama_stack/providers/remote/inference/groq/", "^llama_stack/providers/remote/inference/nvidia/", - "^llama_stack/providers/remote/inference/openai/", "^llama_stack/providers/remote/inference/passthrough/", "^llama_stack/providers/remote/inference/runpod/", - "^llama_stack/providers/remote/inference/sambanova/", - "^llama_stack/providers/remote/inference/sample/", "^llama_stack/providers/remote/inference/tgi/", "^llama_stack/providers/remote/inference/together/", "^llama_stack/providers/remote/inference/watsonx/", @@ -310,7 +301,6 @@ exclude = [ "^llama_stack/providers/remote/vector_io/qdrant/", "^llama_stack/providers/remote/vector_io/sample/", "^llama_stack/providers/remote/vector_io/weaviate/", - "^llama_stack/providers/tests/conftest\\.py$", "^llama_stack/providers/utils/bedrock/client\\.py$", "^llama_stack/providers/utils/bedrock/refreshable_boto_session\\.py$", "^llama_stack/providers/utils/inference/embedding_mixin\\.py$", @@ -318,12 +308,9 @@ exclude = [ "^llama_stack/providers/utils/inference/model_registry\\.py$", "^llama_stack/providers/utils/inference/openai_compat\\.py$", "^llama_stack/providers/utils/inference/prompt_adapter\\.py$", - "^llama_stack/providers/utils/kvstore/config\\.py$", "^llama_stack/providers/utils/kvstore/kvstore\\.py$", - "^llama_stack/providers/utils/kvstore/mongodb/mongodb\\.py$", "^llama_stack/providers/utils/kvstore/postgres/postgres\\.py$", "^llama_stack/providers/utils/kvstore/redis/redis\\.py$", - "^llama_stack/providers/utils/kvstore/sqlite/sqlite\\.py$", "^llama_stack/providers/utils/memory/vector_store\\.py$", "^llama_stack/providers/utils/scoring/aggregation_utils\\.py$", "^llama_stack/providers/utils/scoring/base_scoring_fn\\.py$", @@ -331,13 +318,6 @@ exclude = [ "^llama_stack/providers/utils/telemetry/trace_protocol\\.py$", "^llama_stack/providers/utils/telemetry/tracing\\.py$", "^llama_stack/strong_typing/auxiliary\\.py$", - "^llama_stack/strong_typing/deserializer\\.py$", - "^llama_stack/strong_typing/inspection\\.py$", - "^llama_stack/strong_typing/schema\\.py$", - "^llama_stack/strong_typing/serializer\\.py$", - "^llama_stack/distributions/groq/groq\\.py$", - "^llama_stack/distributions/llama_api/llama_api\\.py$", - "^llama_stack/distributions/sambanova/sambanova\\.py$", "^llama_stack/distributions/template\\.py$", ]