chore: more mypy checks (ollama, vllm, ...) (#1777)

# What does this PR do?

- **chore: mypy for strong_typing**
- **chore: mypy for remote::vllm**
- **chore: mypy for remote::ollama**
- **chore: mypy for providers.datatype**
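
The changes follow a common pattern: bare `AsyncGenerator` return annotations are narrowed to an explicit union of the non-streaming response and a typed stream, and untyped containers get explicit annotations. Below is a minimal, self-contained sketch of that pattern, using hypothetical stand-in classes rather than the actual llama_stack types:

```python
# Minimal sketch of the typing pattern applied across the providers
# (hypothetical stand-in classes, not the real llama_stack types).
from typing import Any, AsyncGenerator


class CompletionResponse:
    def __init__(self, text: str) -> None:
        self.text = text


class CompletionResponseStreamChunk:
    def __init__(self, delta: str) -> None:
        self.delta = delta


async def _stream(prompt: str) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
    # Typed stream: every yielded item is known to be a stream chunk.
    for token in prompt.split():
        yield CompletionResponseStreamChunk(delta=token)


async def completion(
    prompt: str, stream: bool = False
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
    # Instead of a bare `-> AsyncGenerator`, spell out both branches so
    # callers have to handle the union explicitly.
    params: dict[str, Any] = {"prompt": prompt}  # annotated, not just `= {}`
    if stream:
        return _stream(params["prompt"])
    return CompletionResponse(text=params["prompt"])
```

With the explicit union, mypy forces call sites to distinguish the streaming and non-streaming cases before using the result.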

---------

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
commit 66d6c2580e (parent d5e0f32485)
Author: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Date: 2025-04-01 11:12:39 -04:00 (committed via GitHub)
15 changed files with 103 additions and 72 deletions


@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, List, Optional, Union
 
 import httpx
 from ollama import AsyncClient
@@ -19,10 +19,15 @@ from llama_stack.apis.common.content_types import (
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ChatCompletionResponseStreamChunk,
     CompletionRequest,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GrammarResponseFormat,
     Inference,
+    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -86,6 +91,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass
 
+    async def _get_model(self, model_id: str) -> Model:
+        if not self.model_store:
+            raise ValueError("Model store not set")
+        return await self.model_store.get_model(model_id)
+
     async def completion(
         self,
         model_id: str,
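
The `_get_model` helper above addresses a recurring mypy complaint: `model_store` is presumably declared as an Optional attribute on the adapter, so calling `self.model_store.get_model(...)` directly does not type-check. A rough sketch of the narrowing pattern, with stand-in classes rather than the real llama_stack API:

```python
# Sketch of the Optional-attribute guard (assumes model_store is declared
# Optional on the adapter; names are stand-ins, not the llama_stack API).
from typing import Optional


class Model:
    def __init__(self, provider_resource_id: str) -> None:
        self.provider_resource_id = provider_resource_id


class ModelStore:
    async def get_model(self, model_id: str) -> Model:
        return Model(provider_resource_id=model_id)


class Adapter:
    model_store: Optional[ModelStore] = None

    async def _get_model(self, model_id: str) -> Model:
        # The runtime check narrows Optional[ModelStore] to ModelStore,
        # so mypy accepts the .get_model() call below.
        if not self.model_store:
            raise ValueError("Model store not set")
        return await self.model_store.get_model(model_id)
```

Centralizing the check in one helper also avoids repeating the guard in `completion`, `chat_completion`, and `embeddings`.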
@@ -94,10 +104,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -111,7 +121,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         else:
             return await self._nonstream_completion(request)
 
-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(
+        self, request: CompletionRequest
+    ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
         params = await self._get_params(request)
 
         async def _generate_and_convert_to_openai_compat():
@@ -129,7 +141,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         async for chunk in process_completion_stream_response(stream):
             yield chunk
 
-    async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
         r = await self.client.generate(**params)
@@ -148,17 +160,17 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -181,7 +193,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         if sampling_options.get("max_tokens") is not None:
             sampling_options["num_predict"] = sampling_options["max_tokens"]
 
-        input_dict = {}
+        input_dict: dict[str, Any] = {}
         media_present = request_has_media(request)
         llama_model = self.register_helper.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
@@ -201,9 +213,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             input_dict["raw"] = True
 
         if fmt := request.response_format:
-            if fmt.type == "json_schema":
+            if isinstance(fmt, JsonSchemaResponseFormat):
                 input_dict["format"] = fmt.json_schema
-            elif fmt.type == "grammar":
+            elif isinstance(fmt, GrammarResponseFormat):
                 raise NotImplementedError("Grammar response format is not supported")
             else:
                 raise ValueError(f"Unknown response format type: {fmt.type}")
@@ -240,7 +252,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         )
         return process_chat_completion_response(response, request)
 
-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         params = await self._get_params(request)
 
         async def _generate_and_convert_to_openai_compat():
@@ -275,7 +289,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         assert all(not content_has_media(content) for content in contents), (
             "Ollama does not support media for embeddings"