commit d95b92571b
parent 3e6c47ce10
Author: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Date:   2025-03-25 11:53:57 -04:00

    chore: mypy for remote::ollama

    Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>

 3 changed files with 28 additions and 19 deletions

llama_stack/providers/remote/inference/ollama/ollama.py

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, List, Optional, Union

 import httpx
 from ollama import AsyncClient
@@ -19,10 +19,15 @@ from llama_stack.apis.common.content_types import (
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ChatCompletionResponseStreamChunk,
     CompletionRequest,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GrammarResponseFormat,
     Inference,
+    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -94,10 +99,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
+        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
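The added `assert self.model_store is not None` is how mypy is told that the optional attribute is populated by the time the method runs. A minimal sketch of the narrowing pattern, using hypothetical `Store`/`Adapter` classes rather than the real ones:

```python
from typing import Optional


class Store:
    def get_model(self, model_id: str) -> str:
        return model_id


class Adapter:
    # Set during startup, hence Optional at the class level.
    model_store: Optional[Store] = None

    def lookup(self, model_id: str) -> str:
        # Without the assert, mypy reports:
        #   Item "None" of "Optional[Store]" has no attribute "get_model"
        assert self.model_store is not None
        return self.model_store.get_model(model_id)
```

The same assert recurs in `chat_completion` and `embeddings` below, since each method dereferences `model_store` independently.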
@@ -111,7 +117,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         else:
             return await self._nonstream_completion(request)

-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(
+        self, request: CompletionRequest
+    ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
         params = await self._get_params(request)

         async def _generate_and_convert_to_openai_compat():
@@ -129,7 +137,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         async for chunk in process_completion_stream_response(stream):
             yield chunk

-    async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)

         r = await self.client.generate(**params)
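Splitting the streaming and non-streaming helpers lets each carry a precise annotation: the streaming one is an async generator, so its return type takes the yield type plus a `None` send type. A small self-contained sketch with a stand-in chunk type:

```python
from typing import AsyncGenerator


async def stream_chunks(text: str) -> AsyncGenerator[str, None]:
    # An async generator function: the annotation describes what callers
    # receive when they `async for` over it (str chunks, nothing sent back).
    for word in text.split():
        yield word


async def collect() -> list[str]:
    return [chunk async for chunk in stream_chunks("hello typed world")]
```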
@@ -148,17 +156,18 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
-        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
+        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
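The union return type `ChatCompletionResponse | AsyncGenerator[...]` encodes what the `stream` flag actually does, replacing the catch-all `AsyncGenerator`. A sketch of the dispatch pattern with hypothetical stand-in types (the `|` union syntax needs Python 3.10+, matching the diff's style):

```python
from typing import AsyncGenerator


class Response:
    text = "done"


async def _stream() -> AsyncGenerator[str, None]:
    yield "partial"
    yield "done"


async def complete(stream: bool = False) -> Response | AsyncGenerator[str, None]:
    if stream:
        # Returned, not awaited: the caller iterates the generator itself.
        return _stream()
    return Response()
```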
@@ -181,7 +190,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         if sampling_options.get("max_tokens") is not None:
             sampling_options["num_predict"] = sampling_options["max_tokens"]

-        input_dict = {}
+        input_dict: dict[str, Any] = {}
         media_present = request_has_media(request)
         llama_model = self.register_helper.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
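The `dict[str, Any]` annotation (the reason `Any` was added to the imports) matters because the dict is filled with heterogeneous values across several branches; from an empty literal alone mypy cannot pick a value type. A reduced example:

```python
from typing import Any


def build_params(raw: bool, schema: dict[str, Any]) -> dict[str, Any]:
    # Without the annotation, mypy either demands one ("Need type
    # annotation") or infers a too-narrow value type from the first write.
    params: dict[str, Any] = {}
    params["raw"] = raw       # bool value
    params["format"] = schema  # dict value
    return params
```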
@@ -201,9 +210,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 input_dict["raw"] = True

         if fmt := request.response_format:
-            if fmt.type == "json_schema":
+            if isinstance(fmt, JsonSchemaResponseFormat):
                 input_dict["format"] = fmt.json_schema
-            elif fmt.type == "grammar":
+            elif isinstance(fmt, GrammarResponseFormat):
                 raise NotImplementedError("Grammar response format is not supported")
             else:
                 raise ValueError(f"Unknown response format type: {fmt.type}")
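Switching from `fmt.type == "json_schema"` to `isinstance(fmt, JsonSchemaResponseFormat)` gives mypy a narrowing it understands: after the isinstance check it knows `fmt` has a `json_schema` attribute, which a string comparison on `.type` does not establish. A sketch with hypothetical stand-in classes:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class JsonSchemaFormat:
    json_schema: dict[str, Any]


@dataclass
class GrammarFormat:
    bnf: str


def to_format(fmt: JsonSchemaFormat | GrammarFormat) -> dict[str, Any]:
    if isinstance(fmt, JsonSchemaFormat):
        return fmt.json_schema  # mypy has narrowed fmt to JsonSchemaFormat
    raise NotImplementedError("Grammar response format is not supported")
```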
@@ -240,7 +249,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         )
         return process_chat_completion_response(response, request)

-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         params = await self._get_params(request)

         async def _generate_and_convert_to_openai_compat():
@@ -275,7 +286,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        model = await self.model_store.get_model(model_id)
+        assert self.model_store is not None
+        model = self.model_store.get_model(model_id)

         assert all(not content_has_media(content) for content in contents), (
             "Ollama does not support media for embeddings"
@@ -288,7 +300,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         return EmbeddingsResponse(embeddings=embeddings)

-    async def register_model(self, model: Model) -> Model:
+    async def register_model(self, model: Model):
         model = await self.register_helper.register_model(model)
         if model.model_type == ModelType.embedding:
             logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
@@ -302,8 +314,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
             )
-        return model
-

 async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
     async def _convert_content(content) -> dict:

llama_stack/providers/utils/inference/openai_compat.py

@@ -300,7 +300,7 @@ def process_chat_completion_response(

 async def process_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
-) -> AsyncGenerator:
+) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
     stop_reason = None

     async for chunk in stream:
@@ -337,7 +337,7 @@ async def process_completion_stream_response(
 async def process_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
     request: ChatCompletionRequest,
-) -> AsyncGenerator:
+) -> AsyncGenerator[ChatCompletionResponseStreamChunk]:
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
             event_type=ChatCompletionResponseEventType.start,
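One detail worth noting: this second hunk uses the single-parameter form `AsyncGenerator[ChatCompletionResponseStreamChunk]`, while the first spells out both parameters. Historically `AsyncGenerator` required a yield type and a send type; the one-argument spelling relies on the send type defaulting to `None` (PEP 696, available from Python 3.13 or via typing_extensions). The two-parameter form works on all supported versions:

```python
from collections.abc import AsyncGenerator


async def events() -> AsyncGenerator[str, None]:
    # Explicit send type (None) keeps the annotation valid everywhere.
    yield "start"
    yield "complete"
```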

pyproject.toml

@@ -245,7 +245,6 @@ exclude = [
     "^llama_stack/providers/remote/inference/gemini/",
     "^llama_stack/providers/remote/inference/groq/",
     "^llama_stack/providers/remote/inference/nvidia/",
-    "^llama_stack/providers/remote/inference/ollama/",
     "^llama_stack/providers/remote/inference/openai/",
     "^llama_stack/providers/remote/inference/passthrough/",
     "^llama_stack/providers/remote/inference/runpod/",