chore: enable mypy type checking for sentence transformers

Fix method signature incompatibilities with InferenceProvider protocol:
 - Update completion() content parameter type to InterleavedContent
 - Reorder chat_completion() parameters to match protocol
 - Add type ignores for mixin inheritance conflicts

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Author: Mustafa Elbehery
Date: 2025-08-25 14:59:40 +02:00
Parent: ed418653ec
Commit: 821d09af3d

3 changed files with 6 additions and 6 deletions

llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py

@@ -9,6 +9,7 @@ from collections.abc import AsyncGenerator
 from llama_stack.apis.inference import (
     CompletionResponse,
     InferenceProvider,
+    InterleavedContent,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -34,7 +35,7 @@ from .config import SentenceTransformersInferenceConfig
 log = get_logger(name=__name__, category="inference")
 
 
-class SentenceTransformersInferenceImpl(
+class SentenceTransformersInferenceImpl(  # type: ignore[misc]
     OpenAIChatCompletionToLlamaStackMixin,
     OpenAICompletionToLlamaStackMixin,
     SentenceTransformerEmbeddingMixin,
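The `# type: ignore[misc]` is needed because mypy emits a `misc` error when two base classes define the same attribute with incompatible signatures, the kind of conflict that can arise when several mixins meet in one class. A minimal sketch of the pattern, with hypothetical classes (not from this repo):

class CompletionMixinA:
    def completion(self, model_id: str) -> str:
        return model_id

class CompletionMixinB:
    def completion(self, model_id: str, extra: int = 0) -> int:
        return extra

# mypy: Definition of "completion" in base class "CompletionMixinA" is
# incompatible with definition in base class "CompletionMixinB"  [misc]
class Impl(CompletionMixinA, CompletionMixinB):  # type: ignore[misc]
    pass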
@@ -77,7 +78,7 @@ class SentenceTransformersInferenceImpl(
     async def completion(
         self,
         model_id: str,
-        content: str,
+        content: InterleavedContent,
         sampling_params: SamplingParams | None = None,
         response_format: ResponseFormat | None = None,
         stream: bool | None = False,
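The protocol's `completion()` takes `InterleavedContent`, which in llama_stack is roughly `str | InterleavedContentItem | list[InterleavedContentItem]`, so the old `content: str` narrowed the parameter. Overriding parameters must be at least as wide as what they replace (contravariance), and mypy rejects the narrowing. A minimal sketch with hypothetical stand-in types:

from typing import Protocol

Content = str | list[str]  # stand-in for the broader InterleavedContent union

class Provider(Protocol):
    async def completion(self, model_id: str, content: Content) -> str: ...

class NarrowImpl:
    # Accepts less than the protocol promises; mypy rejects NarrowImpl
    # wherever a Provider is expected.
    async def completion(self, model_id: str, content: str) -> str:
        return content

class WideImpl:
    # Matches the protocol's parameter type, so the structural check passes.
    async def completion(self, model_id: str, content: Content) -> str:
        return content if isinstance(content, str) else " ".join(content)

def register(provider: Provider) -> None: ...

register(NarrowImpl())  # error: incompatible type "NarrowImpl"; expected "Provider"
register(WideImpl())    # OK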
@@ -90,10 +91,10 @@ class SentenceTransformersInferenceImpl(
         model_id: str,
         messages: list[Message],
         sampling_params: SamplingParams | None = None,
+        response_format: ResponseFormat | None = None,
         tools: list[ToolDefinition] | None = None,
         tool_choice: ToolChoice | None = ToolChoice.auto,
         tool_prompt_format: ToolPromptFormat | None = None,
-        response_format: ResponseFormat | None = None,
         stream: bool | None = False,
         logprobs: LogProbConfig | None = None,
         tool_config: ToolConfig | None = None,
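This hunk only moves `response_format` up to sit after `sampling_params`, matching the parameter order in the `InferenceProvider` protocol. Order matters to mypy because these are positional-or-keyword parameters: a conforming class must accept the same positional call pattern the protocol allows. A minimal sketch with hypothetical names:

from typing import Protocol

class Proto(Protocol):
    def chat(
        self,
        model_id: str,
        response_format: str | None = None,
        tools: list[str] | None = None,
    ) -> None: ...

class Reordered:
    # Proto permits chat("m", "json_schema"), which here would bind
    # "json_schema" to tools -- so mypy reports an incompatible override.
    def chat(
        self,
        model_id: str,
        tools: list[str] | None = None,
        response_format: str | None = None,
    ) -> None: ...

def use(p: Proto) -> None: ...

use(Reordered())  # error: incompatible type "Reordered"; expected "Proto"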

llama_stack/providers/registry/post_training.py

@@ -5,13 +5,13 @@
 # the root directory of this source tree.
 
-from typing import cast
+from typing import Any, cast
 
 from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
 
 # We provide two versions of these providers so that distributions can package the appropriate version of torch.
 # The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
-torchtune_def = dict(
+torchtune_def: dict[str, Any] = dict(
     api=Api.post_training,
     pip_packages=["numpy"],
     module="llama_stack.providers.inline.post_training.torchtune",

pyproject.toml

@@ -268,7 +268,6 @@ exclude = [
     "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
     "^llama_stack/models/llama/llama4/",
     "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
-    "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
     "^llama_stack/providers/inline/post_training/common/validator\\.py$",
     "^llama_stack/providers/inline/safety/code_scanner/",
     "^llama_stack/providers/inline/safety/llama_guard/",