chore: fix typing hints for get_provider_impl deps arguments (#1544)

# What does this PR do?

It's a dict that may contain different types, as per
resolver:instantiate_provider implementation. (AFAIU it also never
contains ProviderSpecs, but *instances* of provider implementations.)

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

mypy passing if enabled checks for these modules. (See #1543)

[//]: # (## Documentation)

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-03-11 13:07:28 -04:00 committed by GitHub
parent 04106b94aa
commit c3d7d17bc4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 52 additions and 40 deletions

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec from llama_stack.distribution.datatypes import Api
from .config import MetaReferenceAgentsImplConfig from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]): async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
from .agents import MetaReferenceAgentsImpl from .agents import MetaReferenceAgentsImpl
impl = MetaReferenceAgentsImpl( impl = MetaReferenceAgentsImpl(

View file

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict
from .config import LocalFSDatasetIOConfig from .config import LocalFSDatasetIOConfig
async def get_provider_impl( async def get_provider_impl(
config: LocalFSDatasetIOConfig, config: LocalFSDatasetIOConfig,
_deps, _deps: Dict[str, Any],
): ):
from .datasetio import LocalFSDatasetIOImpl from .datasetio import LocalFSDatasetIOImpl

View file

@ -3,16 +3,16 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec from llama_stack.distribution.datatypes import Api
from .config import MetaReferenceEvalConfig from .config import MetaReferenceEvalConfig
async def get_provider_impl( async def get_provider_impl(
config: MetaReferenceEvalConfig, config: MetaReferenceEvalConfig,
deps: Dict[Api, ProviderSpec], deps: Dict[Api, Any],
): ):
from .eval import MetaReferenceEvalImpl from .eval import MetaReferenceEvalImpl

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Union from typing import Any, Dict, Union
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
async def get_provider_impl( async def get_provider_impl(
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig], config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
_deps, _deps: Dict[str, Any],
): ):
from .inference import MetaReferenceInferenceImpl from .inference import MetaReferenceInferenceImpl

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict
from llama_stack.providers.inline.inference.sentence_transformers.config import ( from llama_stack.providers.inline.inference.sentence_transformers.config import (
SentenceTransformersInferenceConfig, SentenceTransformersInferenceConfig,
) )
@ -11,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import
async def get_provider_impl( async def get_provider_impl(
config: SentenceTransformersInferenceConfig, config: SentenceTransformersInferenceConfig,
_deps, _deps: Dict[str, Any],
): ):
from .sentence_transformers import SentenceTransformersInferenceImpl from .sentence_transformers import SentenceTransformersInferenceImpl

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any from typing import Any, Dict
from .config import VLLMConfig from .config import VLLMConfig
async def get_provider_impl(config: VLLMConfig, _deps) -> Any: async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]):
from .vllm import VLLMInferenceImpl from .vllm import VLLMInferenceImpl
impl = VLLMInferenceImpl(config) impl = VLLMInferenceImpl(config)

View file

@ -4,9 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec from llama_stack.distribution.datatypes import Api
from .config import TorchtunePostTrainingConfig from .config import TorchtunePostTrainingConfig
@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
async def get_provider_impl( async def get_provider_impl(
config: TorchtunePostTrainingConfig, config: TorchtunePostTrainingConfig,
deps: Dict[Api, ProviderSpec], deps: Dict[Api, Any],
): ):
from .post_training import TorchtunePostTrainingImpl from .post_training import TorchtunePostTrainingImpl

View file

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict
from .config import CodeScannerConfig from .config import CodeScannerConfig
async def get_provider_impl(config: CodeScannerConfig, deps): async def get_provider_impl(config: CodeScannerConfig, deps: Dict[str, Any]):
from .code_scanner import MetaReferenceCodeScannerSafetyImpl from .code_scanner import MetaReferenceCodeScannerSafetyImpl
impl = MetaReferenceCodeScannerSafetyImpl(config, deps) impl = MetaReferenceCodeScannerSafetyImpl(config, deps)

View file

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict
from .config import LlamaGuardConfig from .config import LlamaGuardConfig
async def get_provider_impl(config: LlamaGuardConfig, deps): async def get_provider_impl(config: LlamaGuardConfig, deps: Dict[str, Any]):
from .llama_guard import LlamaGuardSafetyImpl from .llama_guard import LlamaGuardSafetyImpl
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}" assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"

View file

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict
from .config import PromptGuardConfig # noqa: F401 from .config import PromptGuardConfig # noqa: F401
async def get_provider_impl(config: PromptGuardConfig, deps): async def get_provider_impl(config: PromptGuardConfig, deps: Dict[str, Any]):
from .prompt_guard import PromptGuardSafetyImpl from .prompt_guard import PromptGuardSafetyImpl
impl = PromptGuardSafetyImpl(config, deps) impl = PromptGuardSafetyImpl(config, deps)

View file

@ -3,16 +3,16 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec from llama_stack.distribution.datatypes import Api
from .config import BasicScoringConfig from .config import BasicScoringConfig
async def get_provider_impl( async def get_provider_impl(
config: BasicScoringConfig, config: BasicScoringConfig,
deps: Dict[Api, ProviderSpec], deps: Dict[Api, Any],
): ):
from .scoring import BasicScoringImpl from .scoring import BasicScoringImpl

View file

@ -3,11 +3,11 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.distribution.datatypes import Api, ProviderSpec from llama_stack.distribution.datatypes import Api
from .config import BraintrustScoringConfig from .config import BraintrustScoringConfig
@ -18,7 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):
async def get_provider_impl( async def get_provider_impl(
config: BraintrustScoringConfig, config: BraintrustScoringConfig,
deps: Dict[Api, ProviderSpec], deps: Dict[Api, Any],
): ):
from .braintrust import BraintrustScoringImpl from .braintrust import BraintrustScoringImpl

View file

@ -3,16 +3,16 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec from llama_stack.distribution.datatypes import Api
from .config import LlmAsJudgeScoringConfig from .config import LlmAsJudgeScoringConfig
async def get_provider_impl( async def get_provider_impl(
config: LlmAsJudgeScoringConfig, config: LlmAsJudgeScoringConfig,
deps: Dict[Api, ProviderSpec], deps: Dict[Api, Any],
): ):
from .scoring import LlmAsJudgeScoringImpl from .scoring import LlmAsJudgeScoringImpl

View file

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict
from .config import CodeInterpreterToolConfig from .config import CodeInterpreterToolConfig
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"] __all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps): async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: Dict[str, Any]):
from .code_interpreter import CodeInterpreterToolRuntimeImpl from .code_interpreter import CodeInterpreterToolRuntimeImpl
impl = CodeInterpreterToolRuntimeImpl(config) impl = CodeInterpreterToolRuntimeImpl(config)

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.datatypes import Api
from .config import ChromaVectorIOConfig from .config import ChromaVectorIOConfig
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]): async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, Any]):
from llama_stack.providers.remote.vector_io.chroma.chroma import ( from llama_stack.providers.remote.vector_io.chroma.chroma import (
ChromaVectorIOAdapter, ChromaVectorIOAdapter,
) )

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.datatypes import Api
from .config import FaissVectorIOConfig from .config import FaissVectorIOConfig
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, ProviderSpec]): async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, Any]):
from .faiss import FaissVectorIOAdapter from .faiss import FaissVectorIOAdapter
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}" assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.datatypes import Api
from .config import MilvusVectorIOConfig from .config import MilvusVectorIOConfig
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]): async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, Any]):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
impl = MilvusVectorIOAdapter(config, deps[Api.inference]) impl = MilvusVectorIOAdapter(config, deps[Api.inference])

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.datatypes import Api
from .config import SQLiteVectorIOConfig from .config import SQLiteVectorIOConfig
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, ProviderSpec]): async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, Any]):
from .sqlite_vec import SQLiteVecVectorIOAdapter from .sqlite_vec import SQLiteVecVectorIOAdapter
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}" assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"