From 66d6c2580e62c267cb7fc42c7135a1a768c87629 Mon Sep 17 00:00:00 2001
From: Ihar Hrachyshka
Date: Tue, 1 Apr 2025 11:12:39 -0400
Subject: [PATCH] chore: more mypy checks (ollama, vllm, ...) (#1777)

# What does this PR do?

- **chore: mypy for strong_typing**
- **chore: mypy for remote::vllm**
- **chore: mypy for remote::ollama**
- **chore: mypy for providers.datatype**

---------

Signed-off-by: Ihar Hrachyshka
---
 llama_stack/apis/inference/inference.py      |  4 +-
 llama_stack/providers/datatypes.py           |  8 +--
 .../sentence_transformers.py                 |  2 +-
 .../remote/inference/ollama/ollama.py        | 40 ++++++++-----
 .../providers/remote/inference/tgi/tgi.py    |  2 +-
 .../providers/remote/inference/vllm/vllm.py  | 56 +++++++++++--------
 .../utils/inference/model_registry.py        |  3 +
 .../utils/inference/openai_compat.py         |  9 ++-
 llama_stack/strong_typing/auxiliary.py       |  4 +-
 llama_stack/strong_typing/deserializer.py    |  3 +-
 llama_stack/strong_typing/docstring.py       |  2 +-
 llama_stack/strong_typing/inspection.py      | 19 +++----
 llama_stack/strong_typing/schema.py          | 10 +++-
 llama_stack/strong_typing/serializer.py      |  5 +-
 pyproject.toml                               |  8 ---
 15 files changed, 103 insertions(+), 72 deletions(-)

diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 7d3539dcb..1d4012c19 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -394,7 +394,7 @@ class EmbeddingsResponse(BaseModel):


 class ModelStore(Protocol):
-    def get_model(self, identifier: str) -> Model: ...
+    async def get_model(self, identifier: str) -> Model: ...


 class TextTruncation(Enum):
@@ -431,7 +431,7 @@ class Inference(Protocol):
     - Embedding models: these models generate embeddings to be used for semantic search.
     """

-    model_store: ModelStore
+    model_store: ModelStore | None = None

     @webmethod(route="/inference/completion", method="POST")
     async def completion(
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 384582423..32dfba30c 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -21,7 +21,7 @@ from llama_stack.schema_utils import json_schema_type


 class ModelsProtocolPrivate(Protocol):
-    async def register_model(self, model: Model) -> None: ...
+    async def register_model(self, model: Model) -> Model: ...

     async def unregister_model(self, model_id: str) -> None: ...

@@ -113,8 +113,7 @@ Fully-qualified name of the module to import. The module is expected to have:
         default_factory=list,
         description="The pip dependencies needed for this implementation",
     )
-    config_class: Optional[str] = Field(
-        default=None,
+    config_class: str = Field(
         description="Fully-qualified classname of the config for this provider",
     )
     provider_data_validator: Optional[str] = Field(
@@ -162,7 +161,8 @@ class RemoteProviderConfig(BaseModel):
     @classmethod
     def from_url(cls, url: str) -> "RemoteProviderConfig":
         parsed = urlparse(url)
-        return cls(host=parsed.hostname, port=parsed.port, protocol=parsed.scheme)
+        attrs = {k: v for k, v in parsed._asdict().items() if v is not None}
+        return cls(**attrs)


 @json_schema_type
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index b583896ad..39847e085 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -43,7 +43,7 @@ class SentenceTransformersInferenceImpl(
     async def shutdown(self) -> None:
         pass

-    async def register_model(self, model: Model) -> None:
+    async def register_model(self, model: Model) -> Model:
         return model

     async def unregister_model(self, model_id: str) -> None:
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 36941480c..5a78c07cc 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.


-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, List, Optional, Union

 import httpx
 from ollama import AsyncClient
@@ -19,10 +19,15 @@ from llama_stack.apis.common.content_types import (
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ChatCompletionResponseStreamChunk,
     CompletionRequest,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GrammarResponseFormat,
     Inference,
+    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -86,6 +91,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass

+    async def _get_model(self, model_id: str) -> Model:
+        if not self.model_store:
+            raise ValueError("Model store not set")
+        return await self.model_store.get_model(model_id)
+
     async def completion(
         self,
         model_id: str,
@@ -94,10 +104,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -111,7 +121,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         else:
             return await self._nonstream_completion(request)

-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(
+        self, request: CompletionRequest
+    ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
         params = await self._get_params(request)

         async def _generate_and_convert_to_openai_compat():
@@ -129,7 +141,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         async for chunk in process_completion_stream_response(stream):
             yield chunk

-    async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
         r = await self.client.generate(**params)
@@ -148,17 +160,17 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -181,7 +193,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         if sampling_options.get("max_tokens") is not None:
             sampling_options["num_predict"] = sampling_options["max_tokens"]

-        input_dict = {}
+        input_dict: dict[str, Any] = {}
         media_present = request_has_media(request)
         llama_model = self.register_helper.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
@@ -201,9 +213,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             input_dict["raw"] = True

         if fmt := request.response_format:
-            if fmt.type == "json_schema":
+            if isinstance(fmt, JsonSchemaResponseFormat):
                 input_dict["format"] = fmt.json_schema
-            elif fmt.type == "grammar":
+            elif isinstance(fmt, GrammarResponseFormat):
                 raise NotImplementedError("Grammar response format is not supported")
             else:
                 raise ValueError(f"Unknown response format type: {fmt.type}")
@@ -240,7 +252,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         )
         return process_chat_completion_response(response, request)

-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         params = await self._get_params(request)

         async def _generate_and_convert_to_openai_compat():
@@ -275,7 +289,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)

         assert all(not content_has_media(content) for content in contents), (
             "Ollama does not support media for embeddings"
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 757085fb1..fe99fafe1 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -83,7 +83,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    async def register_model(self, model: Model) -> None:
+    async def register_model(self, model: Model) -> Model:
         model = await self.register_helper.register_model(model)
         if model.provider_resource_id != self.model_id:
             raise ValueError(
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index eda1a179c..6a828322f 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import json
 import logging
-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, List, Optional, Union

 import httpx
 from openai import AsyncOpenAI
@@ -32,11 +32,12 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GrammarResponseFormat,
     Inference,
+    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
-    ResponseFormatType,
     SamplingParams,
     TextTruncation,
     ToolChoice,
@@ -102,9 +103,6 @@ def _convert_to_vllm_tool_calls_in_response(


 def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
-    if tools is None:
-        return tools
-
     compat_tools = []

     for tool in tools:
@@ -141,9 +139,7 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]

         compat_tools.append(compat_tool)

-    if len(compat_tools) > 0:
-        return compat_tools
-    return None
+    return compat_tools


 def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
@@ -206,9 +202,10 @@ async def _process_vllm_chat_completion_stream_response(
             )
         elif choice.delta.tool_calls:
             tool_call = convert_tool_call(choice.delta.tool_calls[0])
-            tool_call_buf.tool_name += tool_call.tool_name
+            tool_call_buf.tool_name += str(tool_call.tool_name)
             tool_call_buf.call_id += tool_call.call_id
-            tool_call_buf.arguments += tool_call.arguments
+            # TODO: remove str() when dict type for 'arguments' is no longer allowed
+            tool_call_buf.arguments += str(tool_call.arguments)
         else:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
@@ -240,6 +237,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         pass

+    async def _get_model(self, model_id: str) -> Model:
+        if not self.model_store:
+            raise ValueError("Model store not set")
+        return await self.model_store.get_model(model_id)
+
     async def completion(
         self,
         model_id: str,
@@ -248,10 +250,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -270,17 +272,17 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = await self._get_model(model_id)
         # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
         # References:
         # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
@@ -318,11 +320,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         )
         return result

-    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: AsyncOpenAI
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         params = await self._get_params(request)

         stream = await client.chat.completions.create(**params)
-        if len(request.tools) > 0:
+        if request.tools:
             res = _process_vllm_chat_completion_stream_response(stream)
         else:
             res = process_chat_completion_stream_response(stream, request)
@@ -330,11 +334,15 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk

     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
+        assert self.client is not None
         params = await self._get_params(request)
         r = await self.client.completions.create(**params)
         return process_completion_response(r)

-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(
+        self, request: CompletionRequest
+    ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
+        assert self.client is not None
         params = await self._get_params(request)

         stream = await self.client.completions.create(**params)
@@ -342,6 +350,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk

     async def register_model(self, model: Model) -> Model:
+        assert self.client is not None
         model = await self.register_helper.register_model(model)
         res = await self.client.models.list()
         available_models = [m.id async for m in res]
@@ -357,7 +366,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         if "max_tokens" not in options:
             options["max_tokens"] = self.config.max_tokens

-        input_dict = {}
+        input_dict: dict[str, Any] = {}
         if isinstance(request, ChatCompletionRequest) and request.tools is not None:
             input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
@@ -368,9 +377,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             input_dict["prompt"] = await completion_request_to_prompt(request)

         if fmt := request.response_format:
-            if fmt.type == ResponseFormatType.json_schema.value:
-                input_dict["extra_body"] = {"guided_json": request.response_format.json_schema}
-            elif fmt.type == ResponseFormatType.grammar.value:
+            if isinstance(fmt, JsonSchemaResponseFormat):
+                input_dict["extra_body"] = {"guided_json": fmt.json_schema}
+            elif isinstance(fmt, GrammarResponseFormat):
                 raise NotImplementedError("Grammar response format not supported yet")
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")
@@ -393,7 +402,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        model = await self.model_store.get_model(model_id)
+        assert self.client is not None
+        model = await self._get_model(model_id)

         kwargs = {}
         assert model.model_type == ModelType.embedding
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index d9e24662a..a11c734df 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -104,3 +104,6 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             )

         return model
+
+    async def unregister_model(self, model_id: str) -> None:
+        pass
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 07976e811..340ce8923 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -137,7 +137,10 @@ def get_sampling_strategy_options(params: SamplingParams) -> dict:
     return options


-def get_sampling_options(params: SamplingParams) -> dict:
+def get_sampling_options(params: SamplingParams | None) -> dict:
+    if not params:
+        return {}
+
     options = {}
     if params:
         options.update(get_sampling_strategy_options(params))
@@ -297,7 +300,7 @@ def process_chat_completion_response(

 async def process_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
-) -> AsyncGenerator:
+) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
     stop_reason = None

     async for chunk in stream:
@@ -334,7 +337,7 @@ async def process_completion_stream_response(
 async def process_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
     request: ChatCompletionRequest,
-) -> AsyncGenerator:
+) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
             event_type=ChatCompletionResponseEventType.start,
diff --git a/llama_stack/strong_typing/auxiliary.py b/llama_stack/strong_typing/auxiliary.py
index cf19d6083..965ffa079 100644
--- a/llama_stack/strong_typing/auxiliary.py
+++ b/llama_stack/strong_typing/auxiliary.py
@@ -77,7 +77,9 @@ def typeannotation(
     """

     def wrap(cls: Type[T]) -> Type[T]:
-        cls.__repr__ = _compact_dataclass_repr
+        # mypy fails to equate bound-y functions (first argument interpreted as
+        # the bound object) with class methods, hence the `ignore` directive.
+        cls.__repr__ = _compact_dataclass_repr  # type: ignore[method-assign]
         if not dataclasses.is_dataclass(cls):
             cls = dataclasses.dataclass(  # type: ignore[call-overload]
                 cls,
diff --git a/llama_stack/strong_typing/deserializer.py b/llama_stack/strong_typing/deserializer.py
index fc0f40f83..883590862 100644
--- a/llama_stack/strong_typing/deserializer.py
+++ b/llama_stack/strong_typing/deserializer.py
@@ -627,7 +627,8 @@ class NamedTupleDeserializer(ClassDeserializer[NamedTuple]):
         super().assign(property_parsers)

     def create(self, **field_values: Any) -> NamedTuple:
-        return self.class_type(**field_values)
+        # mypy fails to deduce that this class returns NamedTuples only, hence the `ignore` directive
+        return self.class_type(**field_values)  # type: ignore[no-any-return]


 class DataclassDeserializer(ClassDeserializer[T]):
diff --git a/llama_stack/strong_typing/docstring.py b/llama_stack/strong_typing/docstring.py
index 9169aadfe..b038d1024 100644
--- a/llama_stack/strong_typing/docstring.py
+++ b/llama_stack/strong_typing/docstring.py
@@ -48,7 +48,7 @@ class DocstringParam:

     name: str
     description: str
-    param_type: type = inspect.Signature.empty
+    param_type: type | str = inspect.Signature.empty

     def __str__(self) -> str:
         return f":param {self.name}: {self.description}"
diff --git a/llama_stack/strong_typing/inspection.py b/llama_stack/strong_typing/inspection.py
index 8bc313021..a75a170cf 100644
--- a/llama_stack/strong_typing/inspection.py
+++ b/llama_stack/strong_typing/inspection.py
@@ -260,7 +260,8 @@ def extend_enum(
     values: Dict[str, Any] = {}
     values.update((e.name, e.value) for e in source)
     values.update((e.name, e.value) for e in extend)
-    enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values)  # type: ignore
+    # mypy fails to determine that __name__ is always a string; hence the `ignore` directive.
+    enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values)  # type: ignore[misc]

     # assign the newly created type to the same module where the extending class is defined
     enum_class.__module__ = extend.__module__
@@ -327,9 +328,7 @@ def _unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]:
         raise TypeError("optional type must have un-subscripted type of Union")

     # will automatically unwrap Union[T] into T
-    return Union[
-        tuple(filter(lambda item: item is not type(None), typing.get_args(typ)))  # type: ignore
-    ]
+    return Union[tuple(filter(lambda item: item is not type(None), typing.get_args(typ)))]  # type: ignore[return-value]


 def is_type_union(typ: object) -> bool:
@@ -431,7 +430,7 @@ def _unwrap_generic_list(typ: Type[List[T]]) -> Type[T]:
     "Extracts the item type of a list type (e.g. returns `T` for `List[T]`)."

     (list_type,) = typing.get_args(typ)  # unpack single tuple element
-    return list_type
+    return list_type  # type: ignore[no-any-return]


 def is_generic_set(typ: object) -> TypeGuard[Type[set]]:
@@ -456,7 +455,7 @@ def _unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]:
     "Extracts the item type of a set type (e.g. returns `T` for `Set[T]`)."

     (set_type,) = typing.get_args(typ)  # unpack single tuple element
-    return set_type
+    return set_type  # type: ignore[no-any-return]


 def is_generic_dict(typ: object) -> TypeGuard[Type[dict]]:
@@ -513,7 +512,7 @@ def unwrap_annotated_type(typ: T) -> T:

     if is_type_annotated(typ):
         # type is Annotated[T, ...]
-        return typing.get_args(typ)[0]
+        return typing.get_args(typ)[0]  # type: ignore[no-any-return]
     else:
         # type is a regular type
         return typ
@@ -538,7 +537,7 @@ def rewrap_annotated_type(transform: Callable[[Type[S]], Type[T]], typ: Type[S])
     transformed_type = transform(inner_type)

     if metadata is not None:
-        return Annotated[(transformed_type, *metadata)]  # type: ignore
+        return Annotated[(transformed_type, *metadata)]  # type: ignore[return-value]
     else:
         return transformed_type
@@ -563,7 +562,7 @@ else:
         return typing.get_type_hints(typ)


-def get_class_properties(typ: type) -> Iterable[Tuple[str, type]]:
+def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]:
     "Returns all properties of a class."

     if is_dataclass_type(typ):
@@ -573,7 +572,7 @@ def get_class_properties(typ: type) -> Iterable[Tuple[str, type]]:
         return resolved_hints.items()


-def get_class_property(typ: type, name: str) -> Optional[type]:
+def get_class_property(typ: type, name: str) -> Optional[type | str]:
     "Looks up the annotated type of a property in a class by its property name."

     for property_name, property_type in get_class_properties(typ):
diff --git a/llama_stack/strong_typing/schema.py b/llama_stack/strong_typing/schema.py
index de69c9b82..0f5121906 100644
--- a/llama_stack/strong_typing/schema.py
+++ b/llama_stack/strong_typing/schema.py
@@ -460,13 +460,17 @@ class JsonSchemaGenerator:
             discriminator = None
             if typing.get_origin(data_type) is Annotated:
                 discriminator = typing.get_args(data_type)[1].discriminator
-            ret = {"oneOf": [self.type_to_schema(union_type) for union_type in typing.get_args(typ)]}
+            ret: Schema = {"oneOf": [self.type_to_schema(union_type) for union_type in typing.get_args(typ)]}
             if discriminator:
                 # for each union type, we need to read the value of the discriminator
-                mapping = {}
+                mapping: dict[str, JsonType] = {}
                 for union_type in typing.get_args(typ):
                     props = self.type_to_schema(union_type, force_expand=True)["properties"]
-                    mapping[props[discriminator]["default"]] = self.type_to_schema(union_type)["$ref"]
+                    # mypy is confused here because JsonType allows multiple types, some of them
+                    # not indexable (bool?) or not indexable by string (list?). The correctness of
+                    # types depends on correct model definitions. Hence multiple ignore statements below.
+                    discriminator_value = props[discriminator]["default"]  # type: ignore[index,call-overload]
+                    mapping[discriminator_value] = self.type_to_schema(union_type)["$ref"]  # type: ignore[index]

                 ret["discriminator"] = {
                     "propertyName": discriminator,
diff --git a/llama_stack/strong_typing/serializer.py b/llama_stack/strong_typing/serializer.py
index 4ca4a4119..17848c14b 100644
--- a/llama_stack/strong_typing/serializer.py
+++ b/llama_stack/strong_typing/serializer.py
@@ -134,7 +134,10 @@ class IPv6Serializer(Serializer[ipaddress.IPv6Address]):

 class EnumSerializer(Serializer[enum.Enum]):
     def generate(self, obj: enum.Enum) -> Union[int, str]:
-        return obj.value
+        value = obj.value
+        if isinstance(value, int):
+            return value
+        return str(value)


 class UntypedListSerializer(Serializer[list]):
diff --git a/pyproject.toml b/pyproject.toml
index 1f7073411..5ddab1065 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -214,7 +214,6 @@ exclude = [
     "^llama_stack/models/llama/llama3/tool_utils\\.py$",
     "^llama_stack/models/llama/llama3_3/prompts\\.py$",
     "^llama_stack/models/llama/sku_list\\.py$",
-    "^llama_stack/providers/datatypes\\.py$",
     "^llama_stack/providers/inline/agents/meta_reference/",
     "^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
     "^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
@@ -248,7 +247,6 @@ exclude = [
     "^llama_stack/providers/remote/inference/gemini/",
     "^llama_stack/providers/remote/inference/groq/",
     "^llama_stack/providers/remote/inference/nvidia/",
-    "^llama_stack/providers/remote/inference/ollama/",
     "^llama_stack/providers/remote/inference/openai/",
     "^llama_stack/providers/remote/inference/passthrough/",
     "^llama_stack/providers/remote/inference/runpod/",
@@ -256,7 +254,6 @@ exclude = [
     "^llama_stack/providers/remote/inference/sample/",
     "^llama_stack/providers/remote/inference/tgi/",
     "^llama_stack/providers/remote/inference/together/",
-    "^llama_stack/providers/remote/inference/vllm/",
     "^llama_stack/providers/remote/safety/bedrock/",
     "^llama_stack/providers/remote/safety/nvidia/",
     "^llama_stack/providers/remote/safety/sample/",
@@ -292,11 +289,6 @@ exclude = [
     "^llama_stack/providers/utils/telemetry/dataset_mixin\\.py$",
     "^llama_stack/providers/utils/telemetry/trace_protocol\\.py$",
     "^llama_stack/providers/utils/telemetry/tracing\\.py$",
-    "^llama_stack/strong_typing/auxiliary\\.py$",
-    "^llama_stack/strong_typing/deserializer\\.py$",
-    "^llama_stack/strong_typing/inspection\\.py$",
-    "^llama_stack/strong_typing/schema\\.py$",
-    "^llama_stack/strong_typing/serializer\\.py$",
     "^llama_stack/templates/dev/dev\\.py$",
     "^llama_stack/templates/groq/groq\\.py$",
     "^llama_stack/templates/sambanova/sambanova\\.py$",