chore: mypy for remote::ollama

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka 2025-03-25 11:53:57 -04:00
parent 3e6c47ce10
commit d95b92571b
3 changed files with 28 additions and 19 deletions
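
This commit turns on mypy for the remote Ollama inference provider: the completion and chat-completion entry points get precise return annotations, response formats are narrowed with isinstance checks, the Optional model_store attribute is asserted before use, and the module is removed from the mypy exclude list. For context, the Optional-narrowing pattern behind the new asserts looks roughly like this sketch (simplified, hypothetical names rather than the actual llama_stack classes):

from typing import Optional, Protocol


class Model:
    provider_resource_id: str


class ModelStore(Protocol):
    def get_model(self, identifier: str) -> Model: ...


class Adapter:
    # The store is attached after construction, so it starts out unset.
    model_store: Optional[ModelStore] = None

    def resolve(self, model_id: str) -> str:
        # Without this assert, mypy reports that model_store may be None
        # at the call below; the assert narrows Optional[ModelStore] to
        # ModelStore for the rest of the function.
        assert self.model_store is not None
        return self.model_store.get_model(model_id).provider_resource_id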

llama_stack/providers/remote/inference/ollama/ollama.py

@@ -5,7 +5,7 @@
# the root directory of this source tree.
from typing import AsyncGenerator, List, Optional, Union
from typing import Any, AsyncGenerator, List, Optional, Union
import httpx
from ollama import AsyncClient
@@ -19,10 +19,15 @@ from llama_stack.apis.common.content_types import (
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
Inference,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
ResponseFormat,
@@ -94,10 +99,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
assert self.model_store is not None
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
model = self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
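
With the widened annotation above, the completion path now tells the type checker that callers get either a full CompletionResponse or an async stream of chunks, depending on the stream flag. A minimal sketch of how such a union return type is produced and consumed (toy names, not the llama_stack API):

import asyncio
from typing import AsyncGenerator, Optional


async def complete(
    prompt: str, stream: Optional[bool] = False
) -> str | AsyncGenerator[str, None]:
    # Streaming callers get an async generator; the rest get the final text.
    if stream:
        return _stream(prompt)
    return prompt.upper()


async def _stream(prompt: str) -> AsyncGenerator[str, None]:
    for ch in prompt:
        yield ch


async def main() -> None:
    result = await complete("hi", stream=True)
    if isinstance(result, str):
        print(result)
    else:
        async for chunk in result:
            print(chunk, end="")
        print()


asyncio.run(main())
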
@@ -111,7 +117,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
else:
return await self._nonstream_completion(request)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
async def _stream_completion(
self, request: CompletionRequest
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
@@ -129,7 +137,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async for chunk in process_completion_stream_response(stream):
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self.client.generate(**params)
@@ -148,17 +156,18 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
assert self.model_store is not None
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
model = self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
@@ -181,7 +190,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
if sampling_options.get("max_tokens") is not None:
sampling_options["num_predict"] = sampling_options["max_tokens"]
input_dict = {}
input_dict: dict[str, Any] = {}
media_present = request_has_media(request)
llama_model = self.register_helper.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
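
The annotation on input_dict is needed because mypy cannot infer an element type for a bare {} literal and reports "Need type annotation", so the value type is spelled out as Any for the heterogeneous entries that follow. A tiny sketch of the same situation with made-up names:

from typing import Any


def build_params(stream: bool) -> dict[str, Any]:
    # A bare `params = {}` would trigger mypy's "Need type annotation"
    # error; dict[str, Any] allows the mixed-type values assigned below.
    params: dict[str, Any] = {}
    params["stream"] = stream
    params["options"] = {"temperature": 0.7}
    return params
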
@@ -201,9 +210,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
input_dict["raw"] = True
if fmt := request.response_format:
if fmt.type == "json_schema":
if isinstance(fmt, JsonSchemaResponseFormat):
input_dict["format"] = fmt.json_schema
elif fmt.type == "grammar":
elif isinstance(fmt, GrammarResponseFormat):
raise NotImplementedError("Grammar response format is not supported")
else:
raise ValueError(f"Unknown response format type: {fmt.type}")
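
Replacing the fmt.type string comparisons with isinstance checks lets mypy narrow the response-format union and verify the attribute access that follows, since json_schema is only read on the JSON-schema branch. A reduced sketch of the idea, using stand-in dataclasses rather than the real llama_stack response-format models:

from dataclasses import dataclass, field
from typing import Any, Union


@dataclass
class JsonSchemaFormat:
    json_schema: dict[str, Any] = field(default_factory=dict)
    type: str = "json_schema"


@dataclass
class GrammarFormat:
    bnf: str = ""
    type: str = "grammar"


ResponseFormatSketch = Union[JsonSchemaFormat, GrammarFormat]


def to_request_format(fmt: ResponseFormatSketch) -> dict[str, Any]:
    # Comparing fmt.type to "json_schema" does not narrow the union here,
    # because type is a plain str; isinstance() does narrow it, so the
    # json_schema attribute access type-checks.
    if isinstance(fmt, JsonSchemaFormat):
        return {"format": fmt.json_schema}
    raise NotImplementedError("grammar formats are not handled in this sketch")
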
@@ -240,7 +249,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
@@ -275,7 +286,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
assert self.model_store is not None
model = self.model_store.get_model(model_id)
assert all(not content_has_media(content) for content in contents), (
"Ollama does not support media for embeddings"
@@ -288,7 +300,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return EmbeddingsResponse(embeddings=embeddings)
async def register_model(self, model: Model) -> Model:
async def register_model(self, model: Model):
model = await self.register_helper.register_model(model)
if model.model_type == ModelType.embedding:
logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
@@ -302,8 +314,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
)
return model
async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
async def _convert_content(content) -> dict:

llama_stack/providers/utils/inference/openai_compat.py

@@ -300,7 +300,7 @@ def process_chat_completion_response(
async def process_completion_stream_response(
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
) -> AsyncGenerator:
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
stop_reason = None
async for chunk in stream:
@@ -337,7 +337,7 @@ async def process_completion_stream_response(
async def process_chat_completion_stream_response(
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
request: ChatCompletionRequest,
) -> AsyncGenerator:
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
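
Giving the stream processors a concrete AsyncGenerator[<chunk type>, None] return type lets mypy check both what they yield and how callers consume the chunks. A self-contained sketch of the annotation at work (toy types rather than the llama_stack chunk models):

import asyncio
from typing import AsyncGenerator


async def raw_chunks() -> AsyncGenerator[str, None]:
    for part in ("hel", "lo"):
        yield part


async def process_stream(stream: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
    # With the yield type spelled out, mypy flags a stray `yield 42` here
    # and also checks what callers do with each chunk they iterate over.
    async for chunk in stream:
        yield chunk.upper()


async def main() -> None:
    async for chunk in process_stream(raw_chunks()):
        print(chunk, end="")
    print()


asyncio.run(main())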

pyproject.toml

@@ -245,7 +245,6 @@ exclude = [
"^llama_stack/providers/remote/inference/gemini/",
"^llama_stack/providers/remote/inference/groq/",
"^llama_stack/providers/remote/inference/nvidia/",
"^llama_stack/providers/remote/inference/ollama/",
"^llama_stack/providers/remote/inference/openai/",
"^llama_stack/providers/remote/inference/passthrough/",
"^llama_stack/providers/remote/inference/runpod/",