diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index 34665b63e..3c1690b6c 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -9,6 +9,7 @@ from collections.abc import AsyncGenerator
 from llama_stack.apis.inference import (
     CompletionResponse,
     InferenceProvider,
+    InterleavedContent,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -34,7 +35,7 @@ from .config import SentenceTransformersInferenceConfig
 log = get_logger(name=__name__, category="inference")


-class SentenceTransformersInferenceImpl(
+class SentenceTransformersInferenceImpl(  # type: ignore[misc]
     OpenAIChatCompletionToLlamaStackMixin,
     OpenAICompletionToLlamaStackMixin,
     SentenceTransformerEmbeddingMixin,
@@ -77,7 +78,7 @@ class SentenceTransformersInferenceImpl(
     async def completion(
         self,
         model_id: str,
-        content: str,
+        content: InterleavedContent,
         sampling_params: SamplingParams | None = None,
         response_format: ResponseFormat | None = None,
         stream: bool | None = False,
@@ -90,10 +91,10 @@ class SentenceTransformersInferenceImpl(
         model_id: str,
         messages: list[Message],
         sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
         tools: list[ToolDefinition] | None = None,
         tool_choice: ToolChoice | None = ToolChoice.auto,
         tool_prompt_format: ToolPromptFormat | None = None,
+        response_format: ResponseFormat | None = None,
         stream: bool | None = False,
         logprobs: LogProbConfig | None = None,
         tool_config: ToolConfig | None = None,
diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py
index 67238e3fc..095d5ef58 100644
--- a/llama_stack/providers/registry/post_training.py
+++ b/llama_stack/providers/registry/post_training.py
@@ -5,13 +5,13 @@
 # the root directory of this source tree.


-from typing import cast
+from typing import Any, cast

 from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec

 # We provide two versions of these providers so that distributions can package the appropriate version of torch.
 # The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
-torchtune_def = dict(
+torchtune_def: dict[str, Any] = dict(
     api=Api.post_training,
     pip_packages=["numpy"],
     module="llama_stack.providers.inline.post_training.torchtune",
diff --git a/pyproject.toml b/pyproject.toml
index dd8529546..df66eca60 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -268,7 +268,6 @@ exclude = [
    "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
    "^llama_stack/models/llama/llama4/",
    "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
-    "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
    "^llama_stack/providers/inline/post_training/common/validator\\.py$",
    "^llama_stack/providers/inline/safety/code_scanner/",
    "^llama_stack/providers/inline/safety/llama_guard/",
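
Note on the class-level ignore: with the pyproject.toml exclusion removed, sentence_transformers.py is now type-checked, and the added `# type: ignore[misc]` presumably suppresses mypy's complaint that the mixin bases define the same method with incompatible signatures. A minimal standalone sketch of the error class being silenced (illustrative names, not from this patch):

    class MixinA:
        async def openai_completion(self, model: str) -> int: ...

    class MixinB:
        async def openai_completion(self, model: str) -> str: ...

    # Without the ignore, mypy reports: Definition of "openai_completion" in
    # base class "MixinA" is incompatible with definition in base class
    # "MixinB"  [misc]
    class Impl(MixinA, MixinB):  # type: ignore[misc]
        pass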
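The signature changes follow the same theme: widening `content` from `str` to `InterleavedContent` and moving `response_format` after `tool_prompt_format` presumably brings the overrides in line with the `InferenceProvider` base signature, since mypy rejects an override whose parameters differ in type or positional order. A self-contained sketch of the failure mode the reorder avoids (hypothetical base class, not the real protocol):

    class Base:
        async def chat_completion(
            self, model_id: str, response_format: str | None = None, stream: bool = False
        ) -> None: ...

    class Reordered(Base):
        # mypy: Signature of "chat_completion" incompatible with supertype "Base",
        # because positional callers of Base would now bind arguments to the
        # wrong parameters.
        async def chat_completion(
            self, model_id: str, stream: bool = False, response_format: str | None = None
        ) -> None: ...

Keyword-argument call sites are unaffected by the reorder, which is why a parameter shuffle like this can be a safe way to satisfy the checker.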
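On the `torchtune_def` annotation: left untyped, mypy infers a narrow value type for the `dict(...)` literal from its initial entries, which then rejects heterogeneous writes or reuse of the dict as a template for provider specs; `dict[str, Any]` keeps it open-ended. A hedged sketch with hypothetical keys, not the real spec:

    from typing import Any

    # Inferred type here is narrow (roughly dict[str, object]), so adding a
    # value of a new type, or spreading into a typed constructor, can fail
    # under mypy.
    inferred = dict(pip_packages=["numpy"], module="example.module")

    # The explicit annotation makes the dict a reusable, mutable template:
    template: dict[str, Any] = dict(pip_packages=["numpy"], module="example.module")
    template["optional_api_dependencies"] = []  # accepted under dict[str, Any]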