mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
chore: fix typing hints for get_provider_impl deps arguments (#1544)
# What does this PR do? It's a dict that may contain different types, as per resolver:instantiate_provider implementation. (AFAIU it also never contains ProviderSpecs, but *instances* of provider implementations.) [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan mypy passing if enabled checks for these modules. (See #1543) [//]: # (## Documentation) Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
04106b94aa
commit
c3d7d17bc4
18 changed files with 52 additions and 40 deletions
|
@ -4,14 +4,14 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
from llama_stack.distribution.datatypes import Api
|
||||||
|
|
||||||
from .config import MetaReferenceAgentsImplConfig
|
from .config import MetaReferenceAgentsImplConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]):
|
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
|
||||||
from .agents import MetaReferenceAgentsImpl
|
from .agents import MetaReferenceAgentsImpl
|
||||||
|
|
||||||
impl = MetaReferenceAgentsImpl(
|
impl = MetaReferenceAgentsImpl(
|
||||||
|
|
|
@ -4,12 +4,14 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
from .config import LocalFSDatasetIOConfig
|
from .config import LocalFSDatasetIOConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(
|
async def get_provider_impl(
|
||||||
config: LocalFSDatasetIOConfig,
|
config: LocalFSDatasetIOConfig,
|
||||||
_deps,
|
_deps: Dict[str, Any],
|
||||||
):
|
):
|
||||||
from .datasetio import LocalFSDatasetIOImpl
|
from .datasetio import LocalFSDatasetIOImpl
|
||||||
|
|
||||||
|
|
|
@ -3,16 +3,16 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from typing import Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
from llama_stack.distribution.datatypes import Api
|
||||||
|
|
||||||
from .config import MetaReferenceEvalConfig
|
from .config import MetaReferenceEvalConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(
|
async def get_provider_impl(
|
||||||
config: MetaReferenceEvalConfig,
|
config: MetaReferenceEvalConfig,
|
||||||
deps: Dict[Api, ProviderSpec],
|
deps: Dict[Api, Any],
|
||||||
):
|
):
|
||||||
from .eval import MetaReferenceEvalImpl
|
from .eval import MetaReferenceEvalImpl
|
||||||
|
|
||||||
|
|
|
@ -4,14 +4,14 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Union
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(
|
async def get_provider_impl(
|
||||||
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
|
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
|
||||||
_deps,
|
_deps: Dict[str, Any],
|
||||||
):
|
):
|
||||||
from .inference import MetaReferenceInferenceImpl
|
from .inference import MetaReferenceInferenceImpl
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,8 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
from llama_stack.providers.inline.inference.sentence_transformers.config import (
|
from llama_stack.providers.inline.inference.sentence_transformers.config import (
|
||||||
SentenceTransformersInferenceConfig,
|
SentenceTransformersInferenceConfig,
|
||||||
)
|
)
|
||||||
|
@ -11,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import
|
||||||
|
|
||||||
async def get_provider_impl(
|
async def get_provider_impl(
|
||||||
config: SentenceTransformersInferenceConfig,
|
config: SentenceTransformersInferenceConfig,
|
||||||
_deps,
|
_deps: Dict[str, Any],
|
||||||
):
|
):
|
||||||
from .sentence_transformers import SentenceTransformersInferenceImpl
|
from .sentence_transformers import SentenceTransformersInferenceImpl
|
||||||
|
|
||||||
|
|
|
@ -4,12 +4,12 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any, Dict
|
||||||
|
|
||||||
from .config import VLLMConfig
|
from .config import VLLMConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
|
async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]):
|
||||||
from .vllm import VLLMInferenceImpl
|
from .vllm import VLLMInferenceImpl
|
||||||
|
|
||||||
impl = VLLMInferenceImpl(config)
|
impl = VLLMInferenceImpl(config)
|
||||||
|
|
|
@ -4,9 +4,9 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
from llama_stack.distribution.datatypes import Api
|
||||||
|
|
||||||
from .config import TorchtunePostTrainingConfig
|
from .config import TorchtunePostTrainingConfig
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
|
||||||
|
|
||||||
async def get_provider_impl(
|
async def get_provider_impl(
|
||||||
config: TorchtunePostTrainingConfig,
|
config: TorchtunePostTrainingConfig,
|
||||||
deps: Dict[Api, ProviderSpec],
|
deps: Dict[Api, Any],
|
||||||
):
|
):
|
||||||
from .post_training import TorchtunePostTrainingImpl
|
from .post_training import TorchtunePostTrainingImpl
|
||||||
|
|
||||||
|
|
|
@ -4,10 +4,12 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
from .config import CodeScannerConfig
|
from .config import CodeScannerConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: CodeScannerConfig, deps):
|
async def get_provider_impl(config: CodeScannerConfig, deps: Dict[str, Any]):
|
||||||
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
|
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
|
||||||
|
|
||||||
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
|
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
|
||||||
|
|
|
@ -4,10 +4,12 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
from .config import LlamaGuardConfig
|
from .config import LlamaGuardConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: LlamaGuardConfig, deps):
|
async def get_provider_impl(config: LlamaGuardConfig, deps: Dict[str, Any]):
|
||||||
from .llama_guard import LlamaGuardSafetyImpl
|
from .llama_guard import LlamaGuardSafetyImpl
|
||||||
|
|
||||||
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"
|
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"
|
||||||
|
|
|
@ -4,10 +4,12 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
from .config import PromptGuardConfig # noqa: F401
|
from .config import PromptGuardConfig # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: PromptGuardConfig, deps):
|
async def get_provider_impl(config: PromptGuardConfig, deps: Dict[str, Any]):
|
||||||
from .prompt_guard import PromptGuardSafetyImpl
|
from .prompt_guard import PromptGuardSafetyImpl
|
||||||
|
|
||||||
impl = PromptGuardSafetyImpl(config, deps)
|
impl = PromptGuardSafetyImpl(config, deps)
|
||||||
|
|
|
@ -3,16 +3,16 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from typing import Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
from llama_stack.distribution.datatypes import Api
|
||||||
|
|
||||||
from .config import BasicScoringConfig
|
from .config import BasicScoringConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(
|
async def get_provider_impl(
|
||||||
config: BasicScoringConfig,
|
config: BasicScoringConfig,
|
||||||
deps: Dict[Api, ProviderSpec],
|
deps: Dict[Api, Any],
|
||||||
):
|
):
|
||||||
from .scoring import BasicScoringImpl
|
from .scoring import BasicScoringImpl
|
||||||
|
|
||||||
|
|
|
@ -3,11 +3,11 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from typing import Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
from llama_stack.distribution.datatypes import Api
|
||||||
|
|
||||||
from .config import BraintrustScoringConfig
|
from .config import BraintrustScoringConfig
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):
|
||||||
|
|
||||||
async def get_provider_impl(
|
async def get_provider_impl(
|
||||||
config: BraintrustScoringConfig,
|
config: BraintrustScoringConfig,
|
||||||
deps: Dict[Api, ProviderSpec],
|
deps: Dict[Api, Any],
|
||||||
):
|
):
|
||||||
from .braintrust import BraintrustScoringImpl
|
from .braintrust import BraintrustScoringImpl
|
||||||
|
|
||||||
|
|
|
@ -3,16 +3,16 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from typing import Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
from llama_stack.distribution.datatypes import Api
|
||||||
|
|
||||||
from .config import LlmAsJudgeScoringConfig
|
from .config import LlmAsJudgeScoringConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(
|
async def get_provider_impl(
|
||||||
config: LlmAsJudgeScoringConfig,
|
config: LlmAsJudgeScoringConfig,
|
||||||
deps: Dict[Api, ProviderSpec],
|
deps: Dict[Api, Any],
|
||||||
):
|
):
|
||||||
from .scoring import LlmAsJudgeScoringImpl
|
from .scoring import LlmAsJudgeScoringImpl
|
||||||
|
|
||||||
|
|
|
@ -4,12 +4,14 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
from .config import CodeInterpreterToolConfig
|
from .config import CodeInterpreterToolConfig
|
||||||
|
|
||||||
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
|
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
|
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: Dict[str, Any]):
|
||||||
from .code_interpreter import CodeInterpreterToolRuntimeImpl
|
from .code_interpreter import CodeInterpreterToolRuntimeImpl
|
||||||
|
|
||||||
impl = CodeInterpreterToolRuntimeImpl(config)
|
impl = CodeInterpreterToolRuntimeImpl(config)
|
||||||
|
|
|
@ -4,14 +4,14 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
from .config import ChromaVectorIOConfig
|
from .config import ChromaVectorIOConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, Any]):
|
||||||
from llama_stack.providers.remote.vector_io.chroma.chroma import (
|
from llama_stack.providers.remote.vector_io.chroma.chroma import (
|
||||||
ChromaVectorIOAdapter,
|
ChromaVectorIOAdapter,
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,14 +4,14 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
from .config import FaissVectorIOConfig
|
from .config import FaissVectorIOConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, Any]):
|
||||||
from .faiss import FaissVectorIOAdapter
|
from .faiss import FaissVectorIOAdapter
|
||||||
|
|
||||||
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
|
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||||
|
|
|
@ -4,14 +4,14 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
from .config import MilvusVectorIOConfig
|
from .config import MilvusVectorIOConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, Any]):
|
||||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
|
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
|
||||||
|
|
||||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference])
|
impl = MilvusVectorIOAdapter(config, deps[Api.inference])
|
||||||
|
|
|
@ -4,14 +4,14 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
from .config import SQLiteVectorIOConfig
|
from .config import SQLiteVectorIOConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, Any]):
|
||||||
from .sqlite_vec import SQLiteVecVectorIOAdapter
|
from .sqlite_vec import SQLiteVecVectorIOAdapter
|
||||||
|
|
||||||
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
|
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue