chore: more mypy checks (ollama, vllm, ...) (#1777)

# What does this PR do?

- **chore: mypy for strong_typing**
- **chore: mypy for remote::vllm**
- **chore: mypy for remote::ollama**
- **chore: mypy for providers.datatypes**

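For context, the recurring fix across these files is tightening signatures mypy could not verify: methods declared `-> None` that actually return values, attributes that stay unset until `initialize()` runs, and bare `AsyncGenerator` annotations. A minimal, self-contained sketch of the first pattern (class names mirror the diff but are redefined here, and the implementation is hypothetical):

```python
from typing import Protocol


class Model:
    def __init__(self, identifier: str) -> None:
        self.identifier = identifier


class ModelsProtocolPrivate(Protocol):
    # Previously annotated `-> None`; implementations actually return the
    # (possibly rewritten) model, so the protocol now says so.
    async def register_model(self, model: Model) -> Model: ...


class ExampleProviderImpl:
    # Satisfies the protocol; mypy now checks the return type end to end.
    async def register_model(self, model: Model) -> Model:
        return model
```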
---------

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka, 2025-04-01 11:12:39 -04:00 (committed by GitHub)
commit 66d6c2580e, parent d5e0f32485
15 changed files with 103 additions and 72 deletions

@@ -394,7 +394,7 @@ class EmbeddingsResponse(BaseModel):
class ModelStore(Protocol):
def get_model(self, identifier: str) -> Model: ...
async def get_model(self, identifier: str) -> Model: ...
class TextTruncation(Enum):
@@ -431,7 +431,7 @@ class Inference(Protocol):
- Embedding models: these models generate embeddings to be used for semantic search.
"""
model_store: ModelStore
model_store: ModelStore | None = None
@webmethod(route="/inference/completion", method="POST")
async def completion(

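Making `model_store` optional with a `None` default reflects how adapters are wired (the store is injected after construction), but it obliges every call site to narrow before use. A minimal sketch of the guard pattern the adapters adopt below (the wiring here is hypothetical):

```python
from typing import Optional, Protocol


class Model:
    def __init__(self, identifier: str) -> None:
        self.identifier = identifier


class ModelStore(Protocol):
    async def get_model(self, identifier: str) -> Model: ...


class Adapter:
    # Injected after construction, hence Optional.
    model_store: Optional[ModelStore] = None

    async def _get_model(self, model_id: str) -> Model:
        # Without this guard, mypy reports a union-attr error on the
        # `.get_model` access below.
        if not self.model_store:
            raise ValueError("Model store not set")
        return await self.model_store.get_model(model_id)
```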
@@ -21,7 +21,7 @@ from llama_stack.schema_utils import json_schema_type
class ModelsProtocolPrivate(Protocol):
async def register_model(self, model: Model) -> None: ...
async def register_model(self, model: Model) -> Model: ...
async def unregister_model(self, model_id: str) -> None: ...
@@ -113,8 +113,7 @@ Fully-qualified name of the module to import. The module is expected to have:
default_factory=list,
description="The pip dependencies needed for this implementation",
)
config_class: Optional[str] = Field(
default=None,
config_class: str = Field(
description="Fully-qualified classname of the config for this provider",
)
provider_data_validator: Optional[str] = Field(
@@ -162,7 +161,8 @@ class RemoteProviderConfig(BaseModel):
@classmethod
def from_url(cls, url: str) -> "RemoteProviderConfig":
parsed = urlparse(url)
return cls(host=parsed.hostname, port=parsed.port, protocol=parsed.scheme)
attrs = {k: v for k, v in parsed._asdict().items() if v is not None}
return cls(**attrs)
@json_schema_type
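The `from_url` rewrite above uses a filter-then-splat idiom: drop the `None` entries before constructing the model so that field defaults apply instead of explicit `None`s being passed into non-Optional fields. A standalone sketch of the idiom with hypothetical fields:

```python
from pydantic import BaseModel


class RemoteConfig(BaseModel):
    host: str = "localhost"
    port: int = 80
    protocol: str = "http"


raw = {"host": "example.com", "port": None, "protocol": "https"}
attrs = {k: v for k, v in raw.items() if v is not None}
print(RemoteConfig(**attrs))  # host='example.com' port=80 protocol='https'
```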

@@ -43,7 +43,7 @@ class SentenceTransformersInferenceImpl(
async def shutdown(self) -> None:
pass
async def register_model(self, model: Model) -> None:
async def register_model(self, model: Model) -> Model:
return model
async def unregister_model(self, model_id: str) -> None:

@@ -5,7 +5,7 @@
# the root directory of this source tree.
from typing import AsyncGenerator, List, Optional, Union
from typing import Any, AsyncGenerator, List, Optional, Union
import httpx
from ollama import AsyncClient
@@ -19,10 +19,15 @@ from llama_stack.apis.common.content_types import (
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
Inference,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
ResponseFormat,
@@ -86,6 +91,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def unregister_model(self, model_id: str) -> None:
pass
async def _get_model(self, model_id: str) -> Model:
if not self.model_store:
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
async def completion(
self,
model_id: str,
@@ -94,10 +104,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
model = await self._get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
@@ -111,7 +121,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
else:
return await self._nonstream_completion(request)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
async def _stream_completion(
self, request: CompletionRequest
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
@@ -129,7 +141,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async for chunk in process_completion_stream_response(stream):
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self.client.generate(**params)
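Parameterizing the generator types (`AsyncGenerator[CompletionResponseStreamChunk, None]` instead of a bare `AsyncGenerator`) tells mypy what the stream yields, so consumers get checked too. A toy illustration of the same shape (`Chunk` is a hypothetical stand-in):

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator


@dataclass
class Chunk:
    delta: str


async def stream_completion(prompt: str) -> AsyncGenerator[Chunk, None]:
    for token in prompt.split():
        yield Chunk(delta=token)


async def main() -> None:
    async for chunk in stream_completion("hello typed world"):
        print(chunk.delta.upper())  # mypy knows `chunk` is a Chunk


asyncio.run(main())
```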
@@ -148,17 +160,17 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
model = await self._get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
@@ -181,7 +193,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
if sampling_options.get("max_tokens") is not None:
sampling_options["num_predict"] = sampling_options["max_tokens"]
input_dict = {}
input_dict: dict[str, Any] = {}
media_present = request_has_media(request)
llama_model = self.register_helper.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
@@ -201,9 +213,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
input_dict["raw"] = True
if fmt := request.response_format:
if fmt.type == "json_schema":
if isinstance(fmt, JsonSchemaResponseFormat):
input_dict["format"] = fmt.json_schema
elif fmt.type == "grammar":
elif isinstance(fmt, GrammarResponseFormat):
raise NotImplementedError("Grammar response format is not supported")
else:
raise ValueError(f"Unknown response format type: {fmt.type}")
@@ -240,7 +252,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
@@ -275,7 +289,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
model = await self._get_model(model_id)
assert all(not content_has_media(content) for content in contents), (
"Ollama does not support media for embeddings"

@@ -83,7 +83,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
async def shutdown(self) -> None:
pass
async def register_model(self, model: Model) -> None:
async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
if model.provider_resource_id != self.model_id:
raise ValueError(

@@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
import logging
from typing import AsyncGenerator, List, Optional, Union
from typing import Any, AsyncGenerator, List, Optional, Union
import httpx
from openai import AsyncOpenAI
@@ -32,11 +32,12 @@ from llama_stack.apis.inference import (
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
Inference,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
@@ -102,9 +103,6 @@ def _convert_to_vllm_tool_calls_in_response(
def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
if tools is None:
return tools
compat_tools = []
for tool in tools:
@@ -141,9 +139,7 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]
compat_tools.append(compat_tool)
if len(compat_tools) > 0:
return compat_tools
return None
return compat_tools
def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
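Returning the (possibly empty) list unconditionally keeps the annotated `List[dict]` honest; the old `return None` branch silently widened the effective type to `Optional[List[dict]]` and pushed `None` checks onto every caller. A sketch of the simplification:

```python
from typing import List, Optional


def convert_tools(tools: Optional[List[str]]) -> List[dict]:
    if not tools:
        return []  # empty list, never None: callers can iterate unconditionally
    return [{"name": name} for name in tools]


assert convert_tools(None) == []
assert convert_tools(["web_search"]) == [{"name": "web_search"}]
```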
@@ -206,9 +202,10 @@ async def _process_vllm_chat_completion_stream_response(
)
elif choice.delta.tool_calls:
tool_call = convert_tool_call(choice.delta.tool_calls[0])
tool_call_buf.tool_name += tool_call.tool_name
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
tool_call_buf.arguments += tool_call.arguments
# TODO: remove str() when dict type for 'arguments' is no longer allowed
tool_call_buf.arguments += str(tool_call.arguments)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
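The added `str()` calls keep the accumulator's `str` fields honest: streamed tool-call deltas may carry `tool_name` as an enum member and `arguments` as a dict, and `+=` on a `str` buffer rejects both. A toy version of the buffering loop (the delta shape here is hypothetical):

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class ToolCallDelta:
    call_id: str
    tool_name: str
    arguments: Union[str, dict]  # the dict form is still allowed upstream


@dataclass
class ToolCallBuf:
    call_id: str = ""
    tool_name: str = ""
    arguments: str = ""


buf = ToolCallBuf()
for delta in [ToolCallDelta("c1", "get_weather", {"city": "Oslo"})]:
    buf.call_id += delta.call_id
    buf.tool_name += str(delta.tool_name)
    # str() is a stopgap until dict-typed arguments are disallowed (see the TODO).
    buf.arguments += str(delta.arguments)
```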
@@ -240,6 +237,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def unregister_model(self, model_id: str) -> None:
pass
async def _get_model(self, model_id: str) -> Model:
if not self.model_store:
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
async def completion(
self,
model_id: str,
@@ -248,10 +250,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
model = await self._get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
@@ -270,17 +272,17 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
model = await self._get_model(model_id)
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
@@ -318,11 +320,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
)
return result
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: AsyncOpenAI
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
params = await self._get_params(request)
stream = await client.chat.completions.create(**params)
if len(request.tools) > 0:
if request.tools:
res = _process_vllm_chat_completion_stream_response(stream)
else:
res = process_chat_completion_stream_response(stream, request)
@@ -330,11 +334,15 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
assert self.client is not None
params = await self._get_params(request)
r = await self.client.completions.create(**params)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
async def _stream_completion(
self, request: CompletionRequest
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
assert self.client is not None
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
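The `assert self.client is not None` lines are narrowing asserts: the client is presumably created in `initialize()`, so its declared type is Optional, and the assert tells mypy it is non-None past that point. A minimal sketch of the pattern:

```python
from typing import Optional


class FakeClient:
    async def ping(self) -> str:
        return "pong"


class Adapter:
    client: Optional[FakeClient] = None

    async def initialize(self) -> None:
        self.client = FakeClient()

    async def ping(self) -> str:
        # Narrows Optional[FakeClient] to FakeClient for the call below.
        assert self.client is not None
        return await self.client.ping()
```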
@@ -342,6 +350,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
yield chunk
async def register_model(self, model: Model) -> Model:
assert self.client is not None
model = await self.register_helper.register_model(model)
res = await self.client.models.list()
available_models = [m.id async for m in res]
@@ -357,7 +366,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if "max_tokens" not in options:
options["max_tokens"] = self.config.max_tokens
input_dict = {}
input_dict: dict[str, Any] = {}
if isinstance(request, ChatCompletionRequest) and request.tools is not None:
input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
@@ -368,9 +377,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
input_dict["prompt"] = await completion_request_to_prompt(request)
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
input_dict["extra_body"] = {"guided_json": request.response_format.json_schema}
elif fmt.type == ResponseFormatType.grammar.value:
if isinstance(fmt, JsonSchemaResponseFormat):
input_dict["extra_body"] = {"guided_json": fmt.json_schema}
elif isinstance(fmt, GrammarResponseFormat):
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
@@ -393,7 +402,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
assert self.client is not None
model = await self._get_model(model_id)
kwargs = {}
assert model.model_type == ModelType.embedding

@@ -104,3 +104,6 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
)
return model
async def unregister_model(self, model_id: str) -> None:
pass

@@ -137,7 +137,10 @@ def get_sampling_strategy_options(params: SamplingParams) -> dict:
return options
def get_sampling_options(params: SamplingParams) -> dict:
def get_sampling_options(params: SamplingParams | None) -> dict:
if not params:
return {}
options = {}
if params:
options.update(get_sampling_strategy_options(params))
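Widening the parameter to `SamplingParams | None` matches how callers actually invoke the helper, and the early return makes the `None` case explicit instead of relying on the later `if params:` check. A sketch with a simplified params type:

```python
from dataclasses import dataclass


@dataclass
class SamplingParams:
    temperature: float = 1.0


def get_sampling_options(params: SamplingParams | None) -> dict:
    if not params:
        return {}
    return {"temperature": params.temperature}


assert get_sampling_options(None) == {}
assert get_sampling_options(SamplingParams(0.2)) == {"temperature": 0.2}
```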
@@ -297,7 +300,7 @@ def process_chat_completion_response(
async def process_completion_stream_response(
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
) -> AsyncGenerator:
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
stop_reason = None
async for chunk in stream:
@@ -334,7 +337,7 @@ async def process_completion_stream_response(
async def process_chat_completion_stream_response(
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
request: ChatCompletionRequest,
) -> AsyncGenerator:
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,

@@ -77,7 +77,9 @@ def typeannotation(
"""
def wrap(cls: Type[T]) -> Type[T]:
cls.__repr__ = _compact_dataclass_repr
# mypy fails to equate bound-y functions (first argument interpreted as
# the bound object) with class methods, hence the `ignore` directive.
cls.__repr__ = _compact_dataclass_repr # type: ignore[method-assign]
if not dataclasses.is_dataclass(cls):
cls = dataclasses.dataclass( # type: ignore[call-overload]
cls,
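The narrow error code is deliberate: `# type: ignore[method-assign]` silences only mypy's objection to assigning into a method slot, not any other error that might appear on the same line. A standalone illustration:

```python
def _compact_repr(obj: object) -> str:
    # Stand-in for _compact_dataclass_repr.
    return f"<{type(obj).__name__}>"


class Example:
    pass


# Assigning a plain function over a method slot trips [method-assign];
# the scoped ignore suppresses exactly that check and nothing else.
Example.__repr__ = _compact_repr  # type: ignore[method-assign]
print(repr(Example()))  # <Example>
```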

@@ -627,7 +627,8 @@ class NamedTupleDeserializer(ClassDeserializer[NamedTuple]):
super().assign(property_parsers)
def create(self, **field_values: Any) -> NamedTuple:
return self.class_type(**field_values)
# mypy fails to deduce that this class returns NamedTuples only, hence the `ignore` directive
return self.class_type(**field_values) # type: ignore[no-any-return]
class DataclassDeserializer(ClassDeserializer[T]):

@@ -48,7 +48,7 @@ class DocstringParam:
name: str
description: str
param_type: type = inspect.Signature.empty
param_type: type | str = inspect.Signature.empty
def __str__(self) -> str:
return f":param {self.name}: {self.description}"

@@ -260,7 +260,8 @@ def extend_enum(
values: Dict[str, Any] = {}
values.update((e.name, e.value) for e in source)
values.update((e.name, e.value) for e in extend)
enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore
# mypy fails to determine that __name__ is always a string; hence the `ignore` directive.
enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore[misc]
# assign the newly created type to the same module where the extending class is defined
enum_class.__module__ = extend.__module__
@@ -327,9 +328,7 @@ def _unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]:
raise TypeError("optional type must have un-subscripted type of Union")
# will automatically unwrap Union[T] into T
return Union[
tuple(filter(lambda item: item is not type(None), typing.get_args(typ))) # type: ignore
]
return Union[tuple(filter(lambda item: item is not type(None), typing.get_args(typ)))] # type: ignore[return-value]
def is_type_union(typ: object) -> bool:
@@ -431,7 +430,7 @@ def _unwrap_generic_list(typ: Type[List[T]]) -> Type[T]:
"Extracts the item type of a list type (e.g. returns `T` for `List[T]`)."
(list_type,) = typing.get_args(typ) # unpack single tuple element
return list_type
return list_type # type: ignore[no-any-return]
def is_generic_set(typ: object) -> TypeGuard[Type[set]]:
@@ -456,7 +455,7 @@ def _unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]:
"Extracts the item type of a set type (e.g. returns `T` for `Set[T]`)."
(set_type,) = typing.get_args(typ) # unpack single tuple element
return set_type
return set_type # type: ignore[no-any-return]
def is_generic_dict(typ: object) -> TypeGuard[Type[dict]]:
@@ -513,7 +512,7 @@ def unwrap_annotated_type(typ: T) -> T:
if is_type_annotated(typ):
# type is Annotated[T, ...]
return typing.get_args(typ)[0]
return typing.get_args(typ)[0] # type: ignore[no-any-return]
else:
# type is a regular type
return typ
@@ -538,7 +537,7 @@ def rewrap_annotated_type(transform: Callable[[Type[S]], Type[T]], typ: Type[S])
transformed_type = transform(inner_type)
if metadata is not None:
return Annotated[(transformed_type, *metadata)] # type: ignore
return Annotated[(transformed_type, *metadata)] # type: ignore[return-value]
else:
return transformed_type
@@ -563,7 +562,7 @@ else:
return typing.get_type_hints(typ)
def get_class_properties(typ: type) -> Iterable[Tuple[str, type]]:
def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]:
"Returns all properties of a class."
if is_dataclass_type(typ):
@@ -573,7 +572,7 @@ def get_class_properties(typ: type) -> Iterable[Tuple[str, type]]:
return resolved_hints.items()
def get_class_property(typ: type, name: str) -> Optional[type]:
def get_class_property(typ: type, name: str) -> Optional[type | str]:
"Looks up the annotated type of a property in a class by its property name."
for property_name, property_type in get_class_properties(typ):

@@ -460,13 +460,17 @@ class JsonSchemaGenerator:
discriminator = None
if typing.get_origin(data_type) is Annotated:
discriminator = typing.get_args(data_type)[1].discriminator
ret = {"oneOf": [self.type_to_schema(union_type) for union_type in typing.get_args(typ)]}
ret: Schema = {"oneOf": [self.type_to_schema(union_type) for union_type in typing.get_args(typ)]}
if discriminator:
# for each union type, we need to read the value of the discriminator
mapping = {}
mapping: dict[str, JsonType] = {}
for union_type in typing.get_args(typ):
props = self.type_to_schema(union_type, force_expand=True)["properties"]
mapping[props[discriminator]["default"]] = self.type_to_schema(union_type)["$ref"]
# mypy is confused here because JsonType allows multiple types, some of them
# not indexable (bool?) or not indexable by string (list?). The correctness of
# types depends on correct model definitions. Hence multiple ignore statements below.
discriminator_value = props[discriminator]["default"] # type: ignore[index,call-overload]
mapping[discriminator_value] = self.type_to_schema(union_type)["$ref"] # type: ignore[index]
ret["discriminator"] = {
"propertyName": discriminator,

@@ -134,7 +134,10 @@ class IPv6Serializer(Serializer[ipaddress.IPv6Address]):
class EnumSerializer(Serializer[enum.Enum]):
def generate(self, obj: enum.Enum) -> Union[int, str]:
return obj.value
value = obj.value
if isinstance(value, int):
return value
return str(value)
class UntypedListSerializer(Serializer[list]):
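`enum.Enum.value` is typed `Any` by the stdlib stubs, so returning it directly from a function annotated `Union[int, str]` is unchecked; the `isinstance` branch plus the `str()` fallback makes the return type sound. Sketch:

```python
import enum
from typing import Union


class Color(enum.Enum):
    RED = 1
    BLUE = "blue"


def serialize_enum(obj: enum.Enum) -> Union[int, str]:
    value = obj.value  # typed Any by the stdlib stubs
    if isinstance(value, int):
        return value
    return str(value)  # coerces any non-int payload to str


assert serialize_enum(Color.RED) == 1
assert serialize_enum(Color.BLUE) == "blue"
```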

@@ -214,7 +214,6 @@ exclude = [
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
"^llama_stack/models/llama/llama3_3/prompts\\.py$",
"^llama_stack/models/llama/sku_list\\.py$",
"^llama_stack/providers/datatypes\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/",
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
@@ -248,7 +247,6 @@ exclude = [
"^llama_stack/providers/remote/inference/gemini/",
"^llama_stack/providers/remote/inference/groq/",
"^llama_stack/providers/remote/inference/nvidia/",
"^llama_stack/providers/remote/inference/ollama/",
"^llama_stack/providers/remote/inference/openai/",
"^llama_stack/providers/remote/inference/passthrough/",
"^llama_stack/providers/remote/inference/runpod/",
@@ -256,7 +254,6 @@ exclude = [
"^llama_stack/providers/remote/inference/sample/",
"^llama_stack/providers/remote/inference/tgi/",
"^llama_stack/providers/remote/inference/together/",
"^llama_stack/providers/remote/inference/vllm/",
"^llama_stack/providers/remote/safety/bedrock/",
"^llama_stack/providers/remote/safety/nvidia/",
"^llama_stack/providers/remote/safety/sample/",
@@ -292,11 +289,6 @@ exclude = [
"^llama_stack/providers/utils/telemetry/dataset_mixin\\.py$",
"^llama_stack/providers/utils/telemetry/trace_protocol\\.py$",
"^llama_stack/providers/utils/telemetry/tracing\\.py$",
"^llama_stack/strong_typing/auxiliary\\.py$",
"^llama_stack/strong_typing/deserializer\\.py$",
"^llama_stack/strong_typing/inspection\\.py$",
"^llama_stack/strong_typing/schema\\.py$",
"^llama_stack/strong_typing/serializer\\.py$",
"^llama_stack/templates/dev/dev\\.py$",
"^llama_stack/templates/groq/groq\\.py$",
"^llama_stack/templates/sambanova/sambanova\\.py$",