commit d95b92571b (parent 3e6c47ce10)
https://github.com/meta-llama/llama-stack.git

chore: mypy for remote::ollama

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>

3 changed files with 28 additions and 19 deletions
@@ -5,7 +5,7 @@
 # the root directory of this source tree.


-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, List, Optional, Union

 import httpx
 from ollama import AsyncClient
@@ -19,10 +19,15 @@ from llama_stack.apis.common.content_types import (
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ChatCompletionResponseStreamChunk,
     CompletionRequest,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GrammarResponseFormat,
     Inference,
+    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -94,10 +99,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
+        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
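The `assert self.model_store is not None` line added above is the usual way to satisfy mypy when an attribute is declared Optional (as `model_store` presumably is) and only set after construction: the assert narrows the type from the Optional to the concrete type for the rest of the method. A minimal, self-contained sketch of the pattern (Store and Adapter are illustrative stand-ins, not llama-stack types):

from typing import Optional

class Store:
    def get_model(self, model_id: str) -> str:
        return model_id

class Adapter:
    # Assumed to be populated later by the framework; None until then.
    model_store: Optional[Store] = None

    def lookup(self, model_id: str) -> str:
        # Without the assert, mypy reports something like:
        #   Item "None" of "Optional[Store]" has no attribute "get_model"
        assert self.model_store is not None
        return self.model_store.get_model(model_id)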
@@ -111,7 +117,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         else:
             return await self._nonstream_completion(request)

-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(
+        self, request: CompletionRequest
+    ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
         params = await self._get_params(request)

         async def _generate_and_convert_to_openai_compat():
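Parameterizing `AsyncGenerator` as above lets mypy check both what `_stream_completion` yields and what callers get back from `async for`. A small self-contained illustration with a stand-in chunk type:

import asyncio
from typing import AsyncGenerator

async def stream_chunks(n: int) -> AsyncGenerator[str, None]:
    # Yield type is str, send type is None; yielding anything else is a mypy error.
    for i in range(n):
        yield f"chunk-{i}"

async def main() -> None:
    async for chunk in stream_chunks(3):
        print(chunk.upper())  # mypy knows chunk is a str here

asyncio.run(main())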
@@ -129,7 +137,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         async for chunk in process_completion_stream_response(stream):
             yield chunk

-    async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
         r = await self.client.generate(**params)

@@ -148,17 +156,18 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
+        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
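With the return annotation widened to a union, callers have to distinguish the non-streaming response from the stream of chunks before using the result, and mypy narrows the union on an isinstance check or on the else branch. A self-contained sketch of that pattern (Response and Chunk are stand-ins, not the llama-stack models):

import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator

@dataclass
class Chunk:
    delta: str

@dataclass
class Response:
    text: str

async def _stream() -> AsyncGenerator[Chunk, None]:
    for part in ("Hel", "lo"):
        yield Chunk(delta=part)

async def chat(stream: bool) -> Response | AsyncGenerator[Chunk, None]:
    if stream:
        return _stream()  # hand the generator back without awaiting it
    return Response(text="Hello")

async def main() -> None:
    result = await chat(stream=True)
    if isinstance(result, Response):
        print(result.text)
    else:
        async for chunk in result:  # narrowed to the generator branch
            print(chunk.delta, end="")
        print()

asyncio.run(main())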
@@ -181,7 +190,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         if sampling_options.get("max_tokens") is not None:
             sampling_options["num_predict"] = sampling_options["max_tokens"]

-        input_dict = {}
+        input_dict: dict[str, Any] = {}
         media_present = request_has_media(request)
         llama_model = self.register_helper.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
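The explicit `dict[str, Any]` annotation is needed because mypy does not infer an element type for a bare `{}` assignment, and the values stored in `input_dict` afterwards are of mixed types (booleans, nested dicts, and so on). A minimal illustration:

from typing import Any

def build_params(raw: bool) -> dict[str, Any]:
    # A bare `params = {}` would trigger mypy's
    #   error: Need type annotation for "params"  [var-annotated]
    # and the values below have different types, hence dict[str, Any].
    params: dict[str, Any] = {}
    params["raw"] = raw                      # bool
    params["format"] = {"type": "object"}    # nested dict
    return params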
@@ -201,9 +210,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             input_dict["raw"] = True

         if fmt := request.response_format:
-            if fmt.type == "json_schema":
+            if isinstance(fmt, JsonSchemaResponseFormat):
                 input_dict["format"] = fmt.json_schema
-            elif fmt.type == "grammar":
+            elif isinstance(fmt, GrammarResponseFormat):
                 raise NotImplementedError("Grammar response format is not supported")
             else:
                 raise ValueError(f"Unknown response format type: {fmt.type}")
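Switching from string comparison on `fmt.type` to `isinstance` checks gives mypy explicit narrowing: inside the `isinstance(fmt, JsonSchemaResponseFormat)` branch the checker knows `fmt.json_schema` exists, which a plain `fmt.type == "json_schema"` comparison may not establish. A self-contained sketch of the idea with stand-in classes:

from dataclasses import dataclass

@dataclass
class JsonSchemaFormat:
    json_schema: dict

@dataclass
class GrammarFormat:
    bnf: str

def to_backend_format(fmt: JsonSchemaFormat | GrammarFormat) -> dict:
    if isinstance(fmt, JsonSchemaFormat):
        # Narrowed: mypy accepts the .json_schema attribute access here.
        return fmt.json_schema
    if isinstance(fmt, GrammarFormat):
        raise NotImplementedError("grammar response formats are not supported")
    raise ValueError(f"unknown response format: {fmt}")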
@@ -240,7 +249,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         )
         return process_chat_completion_response(response, request)

-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         params = await self._get_params(request)

         async def _generate_and_convert_to_openai_compat():
@@ -275,7 +286,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        model = await self.model_store.get_model(model_id)
+        assert self.model_store is not None
+        model = self.model_store.get_model(model_id)

         assert all(not content_has_media(content) for content in contents), (
             "Ollama does not support media for embeddings"
@@ -288,7 +300,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):

         return EmbeddingsResponse(embeddings=embeddings)

-    async def register_model(self, model: Model) -> Model:
+    async def register_model(self, model: Model):
         model = await self.register_helper.register_model(model)
         if model.model_type == ModelType.embedding:
             logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
@@ -302,8 +314,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
             )

-        return model
-

 async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
     async def _convert_content(content) -> dict:
@@ -300,7 +300,7 @@ def process_chat_completion_response(

 async def process_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
-) -> AsyncGenerator:
+) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
     stop_reason = None

     async for chunk in stream:
@@ -337,7 +337,7 @@ async def process_completion_stream_response(
 async def process_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
     request: ChatCompletionRequest,
-) -> AsyncGenerator:
+) -> AsyncGenerator[ChatCompletionResponseStreamChunk]:
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
             event_type=ChatCompletionResponseEventType.start,
@@ -245,7 +245,6 @@ exclude = [
     "^llama_stack/providers/remote/inference/gemini/",
     "^llama_stack/providers/remote/inference/groq/",
     "^llama_stack/providers/remote/inference/nvidia/",
-    "^llama_stack/providers/remote/inference/ollama/",
     "^llama_stack/providers/remote/inference/openai/",
     "^llama_stack/providers/remote/inference/passthrough/",
     "^llama_stack/providers/remote/inference/runpod/",