mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
chore: indicate to mypy that InferenceProvider.batch_completion/batch_chat_completion is concrete (#3239)
# What does this PR do?

Closes https://github.com/llamastack/llama-stack/issues/3236

mypy considered our default implementations (`raise NotImplementedError`) to be trivial. As a result, we implemented the same stubs in every provider. This change puts enough into the default implementations for mypy to consider them non-trivial, which allows us to remove the duplicate provider implementations.
This commit is contained in:
parent 2ee898cc4c
commit 3d119a86d4

5 changed files with 2 additions and 89 deletions
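For readers unfamiliar with the mypy behavior being worked around, the sketch below is a minimal, hypothetical reproduction, not code from this repository; InferenceLike, NoBatchProvider, and the simplified batch_completion signature are made up for illustration. The premise, taken from the commit's own comment, is that mypy treats a body consisting only of raise NotImplementedError as trivial, while an unreachable return after the raise makes the method count as concrete.

from typing import Protocol


class InferenceLike(Protocol):
    """Hypothetical protocol standing in for InferenceProvider."""

    async def batch_completion(self, model_id: str) -> list[str]:
        """Default implementation for providers that don't support batching."""
        raise NotImplementedError("Batch completion is not implemented")
        # Unreachable on purpose: with only the raise above, mypy considers the
        # body trivial and providers end up re-declaring the same stub; the bare
        # return below is the trick this commit uses to mark the method concrete.
        return


class NoBatchProvider(InferenceLike):
    """A provider without batching support can now simply inherit the default
    instead of defining its own stub (see the deletions in the hunks below)."""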
@@ -1068,6 +1068,7 @@ class InferenceProvider(Protocol):
         :returns: A BatchCompletionResponse with the full completions.
         """
         raise NotImplementedError("Batch completion is not implemented")
+        return  # this is so mypy's safe-super rule will consider the method concrete

     @webmethod(route="/inference/chat-completion", method="POST")
     async def chat_completion(
@@ -1132,6 +1133,7 @@ class InferenceProvider(Protocol):
         :returns: A BatchChatCompletionResponse with the full completions.
         """
         raise NotImplementedError("Batch chat completion is not implemented")
+        return  # this is so mypy's safe-super rule will consider the method concrete

     @webmethod(route="/inference/embeddings", method="POST")
     async def embeddings(
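The provider-side deletions in the remaining hunks are the payoff. As a self-contained sanity check (again hypothetical, reusing the made-up InferenceLike shape from the sketch above), the runtime behavior for a provider that inherits the default is unchanged: callers still get a NotImplementedError, exactly as the per-provider stubs used to raise.

import asyncio
from typing import Protocol


class InferenceLike(Protocol):
    async def batch_completion(self, model_id: str) -> list[str]:
        raise NotImplementedError("Batch completion is not implemented")
        return  # unreachable; marks the default as concrete for mypy


class NoBatchProvider(InferenceLike):
    """No batch_completion override needed anymore."""


async def main() -> None:
    try:
        await NoBatchProvider().batch_completion("some-model")
    except NotImplementedError as exc:
        # Same error callers saw before the duplicate stubs were removed.
        print(f"still raises as before: {exc}")


asyncio.run(main())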
@@ -9,7 +9,6 @@ from collections.abc import AsyncGenerator
 from llama_stack.apis.inference import (
     CompletionResponse,
     InferenceProvider,
-    InterleavedContent,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -100,25 +99,3 @@ class SentenceTransformersInferenceImpl(
         tool_config: ToolConfig | None = None,
     ) -> AsyncGenerator:
         raise ValueError("Sentence transformers don't support chat completion")
-
-    async def batch_completion(
-        self,
-        model_id: str,
-        content_batch: list[InterleavedContent],
-        sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
-
-    async def batch_chat_completion(
-        self,
-        model_id: str,
-        messages_batch: list[list[Message]],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_config: ToolConfig | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")
@@ -619,28 +619,6 @@ class OllamaInferenceAdapter(
         response.id = id
         return response
-
-    async def batch_completion(
-        self,
-        model_id: str,
-        content_batch: list[InterleavedContent],
-        sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch completion is not supported for Ollama")
-
-    async def batch_chat_completion(
-        self,
-        model_id: str,
-        messages_batch: list[list[Message]],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_config: ToolConfig | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch chat completion is not supported for Ollama")


 async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
     async def _convert_content(content) -> dict:
@@ -711,25 +711,3 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             user=user,
         )
         return await self.client.chat.completions.create(**params)  # type: ignore
-
-    async def batch_completion(
-        self,
-        model_id: str,
-        content_batch: list[InterleavedContent],
-        sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch completion is not supported for Ollama")
-
-    async def batch_chat_completion(
-        self,
-        model_id: str,
-        messages_batch: list[list[Message]],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_config: ToolConfig | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch chat completion is not supported for vLLM")
@@ -429,28 +429,6 @@ class LiteLLMOpenAIMixin(
         )
         return await litellm.acompletion(**params)
-
-    async def batch_completion(
-        self,
-        model_id: str,
-        content_batch: list[InterleavedContent],
-        sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch completion is not supported for OpenAI Compat")
-
-    async def batch_chat_completion(
-        self,
-        model_id: str,
-        messages_batch: list[list[Message]],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_config: ToolConfig | None = None,
-        response_format: ResponseFormat | None = None,
-        logprobs: LogProbConfig | None = None,
-    ):
-        raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")

     async def check_model_availability(self, model: str) -> bool:
         """
         Check if a specific model is available via LiteLLM for the current