From 3d119a86d4ceebdd1364c18fe94d422f57241431 Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Fri, 22 Aug 2025 16:17:30 -0500
Subject: [PATCH] chore: indicate to mypy that InferenceProvider.batch_completion/batch_chat_completion is concrete (#3239)

# What does this PR do?

Closes https://github.com/llamastack/llama-stack/issues/3236

mypy considered our default implementations (`raise NotImplementedError`) to be
trivial, so we ended up implementing the same stubs in the providers. This change
puts enough into the default implementations for mypy to consider them
non-trivial, which lets us remove the duplicate implementations.
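For illustration, a minimal sketch of the pattern (the class and method names below are hypothetical, not the real llama-stack ones): when a protocol is subclassed explicitly, mypy treats a method whose body only raises as trivial, so each provider had to re-declare the stub; an unreachable `return` after the raise makes the body non-trivial, and the default is then inherited as a concrete method.

```python
# Hypothetical sketch of the mypy behavior this patch works around; names are
# illustrative only, not code from this repository.
from typing import Protocol


class ExampleInferenceProvider(Protocol):
    async def batch_completion(self, model_id: str) -> None:
        # A body that only raises is "trivial" to mypy, so explicit subclasses
        # are expected to supply their own implementation.
        raise NotImplementedError("Batch completion is not implemented")
        return  # unreachable, but makes mypy treat this default as concrete


class ExampleAdapter(ExampleInferenceProvider):
    # With the non-trivial default above, no batch_completion stub is needed
    # here; callers still get NotImplementedError at runtime.
    pass
```

The two `return` lines added to `inference.py` in the diff below apply the same idea to the real `InferenceProvider` protocol.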
---
 llama_stack/apis/inference/inference.py        |  2 ++
 .../sentence_transformers.py                   | 23 -------------------
 .../remote/inference/ollama/ollama.py          | 22 ------------------
 .../providers/remote/inference/vllm/vllm.py    | 22 ------------------
 .../utils/inference/litellm_openai_mixin.py    | 22 ------------------
 5 files changed, 2 insertions(+), 89 deletions(-)

diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 570ed3d2b..bd4737ca7 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -1068,6 +1068,7 @@ class InferenceProvider(Protocol):
         :returns: A BatchCompletionResponse with the full completions.
         """
         raise NotImplementedError("Batch completion is not implemented")
+        return  # this is so mypy's safe-super rule will consider the method concrete
 
     @webmethod(route="/inference/chat-completion", method="POST")
     async def chat_completion(
@@ -1132,6 +1133,7 @@
         :returns: A BatchChatCompletionResponse with the full completions.
         """
         raise NotImplementedError("Batch chat completion is not implemented")
+        return  # this is so mypy's safe-super rule will consider the method concrete
 
     @webmethod(route="/inference/embeddings", method="POST")
     async def embeddings(
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index 600a5bd37..34665b63e 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -9,7 +9,6 @@ from collections.abc import AsyncGenerator
 from llama_stack.apis.inference import (
     CompletionResponse,
     InferenceProvider,
-    InterleavedContent,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -100,25 +99,3 @@ class SentenceTransformersInferenceImpl(
         tool_config: ToolConfig | None = None,
     ) -> AsyncGenerator:
         raise ValueError("Sentence transformers don't support chat completion")
-
-    async def batch_completion(
-        self,
-        model_id: str,
-        content_batch: list[InterleavedContent],
-        sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
-
-    async def batch_chat_completion(
-        self,
-        model_id: str,
-        messages_batch: list[list[Message]],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_config: ToolConfig | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index d8b331ef7..fcaf5ee92 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -619,28 +619,6 @@ class OllamaInferenceAdapter(
         response.id = id
         return response
 
-    async def batch_completion(
-        self,
-        model_id: str,
-        content_batch: list[InterleavedContent],
-        sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch completion is not supported for Ollama")
-
-    async def batch_chat_completion(
-        self,
-        model_id: str,
-        messages_batch: list[list[Message]],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_config: ToolConfig | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch chat completion is not supported for Ollama")
-
 
 async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
     async def _convert_content(content) -> dict:
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index f71068318..9e9a80ca5 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -711,25 +711,3 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             user=user,
         )
         return await self.client.chat.completions.create(**params)  # type: ignore
-
-    async def batch_completion(
-        self,
-        model_id: str,
-        content_batch: list[InterleavedContent],
-        sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch completion is not supported for Ollama")
-
-    async def batch_chat_completion(
-        self,
-        model_id: str,
-        messages_batch: list[list[Message]],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_config: ToolConfig | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch chat completion is not supported for vLLM")
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 880348805..9bd43e4c9 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -429,28 +429,6 @@ class LiteLLMOpenAIMixin(
         )
         return await litellm.acompletion(**params)
 
-    async def batch_completion(
-        self,
-        model_id: str,
-        content_batch: list[InterleavedContent],
-        sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch completion is not supported for OpenAI Compat")
-
-    async def batch_chat_completion(
-        self,
-        model_id: str,
-        messages_batch: list[list[Message]],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_config: ToolConfig | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")
-
     async def check_model_availability(self, model: str) -> bool:
         """
         Check if a specific model is available via LiteLLM for the current