forked from phoenix-oss/llama-stack-mirror
chore: more mypy checks (ollama, vllm, ...) (#1777)
# What does this PR do? - **chore: mypy for strong_typing** - **chore: mypy for remote::vllm** - **chore: mypy for remote::ollama** - **chore: mypy for providers.datatype** --------- Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
d5e0f32485
commit
66d6c2580e
15 changed files with 103 additions and 72 deletions
|
@ -394,7 +394,7 @@ class EmbeddingsResponse(BaseModel):
|
|||
|
||||
|
||||
class ModelStore(Protocol):
|
||||
def get_model(self, identifier: str) -> Model: ...
|
||||
async def get_model(self, identifier: str) -> Model: ...
|
||||
|
||||
|
||||
class TextTruncation(Enum):
|
||||
|
@ -431,7 +431,7 @@ class Inference(Protocol):
|
|||
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||
"""
|
||||
|
||||
model_store: ModelStore
|
||||
model_store: ModelStore | None = None
|
||||
|
||||
@webmethod(route="/inference/completion", method="POST")
|
||||
async def completion(
|
||||
|
|
|
@ -21,7 +21,7 @@ from llama_stack.schema_utils import json_schema_type
|
|||
|
||||
|
||||
class ModelsProtocolPrivate(Protocol):
|
||||
async def register_model(self, model: Model) -> None: ...
|
||||
async def register_model(self, model: Model) -> Model: ...
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None: ...
|
||||
|
||||
|
@ -113,8 +113,7 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
config_class: Optional[str] = Field(
|
||||
default=None,
|
||||
config_class: str = Field(
|
||||
description="Fully-qualified classname of the config for this provider",
|
||||
)
|
||||
provider_data_validator: Optional[str] = Field(
|
||||
|
@ -162,7 +161,8 @@ class RemoteProviderConfig(BaseModel):
|
|||
@classmethod
|
||||
def from_url(cls, url: str) -> "RemoteProviderConfig":
|
||||
parsed = urlparse(url)
|
||||
return cls(host=parsed.hostname, port=parsed.port, protocol=parsed.scheme)
|
||||
attrs = {k: v for k, v in parsed._asdict().items() if v is not None}
|
||||
return cls(**attrs)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -43,7 +43,7 @@ class SentenceTransformersInferenceImpl(
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_model(self, model: Model) -> None:
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from ollama import AsyncClient
|
||||
|
@ -19,10 +19,15 @@ from llama_stack.apis.common.content_types import (
|
|||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GrammarResponseFormat,
|
||||
Inference,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
|
@ -86,6 +91,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def _get_model(self, model_id: str) -> Model:
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store not set")
|
||||
return await self.model_store.get_model(model_id)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -94,10 +104,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
model = await self._get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
content=content,
|
||||
|
@ -111,7 +121,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
async def _stream_completion(
|
||||
self, request: CompletionRequest
|
||||
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
params = await self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
|
@ -129,7 +141,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = await self.client.generate(**params)
|
||||
|
||||
|
@ -148,17 +160,17 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
model = await self._get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
|
@ -181,7 +193,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
if sampling_options.get("max_tokens") is not None:
|
||||
sampling_options["num_predict"] = sampling_options["max_tokens"]
|
||||
|
||||
input_dict = {}
|
||||
input_dict: dict[str, Any] = {}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.register_helper.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
|
@ -201,9 +213,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
input_dict["raw"] = True
|
||||
|
||||
if fmt := request.response_format:
|
||||
if fmt.type == "json_schema":
|
||||
if isinstance(fmt, JsonSchemaResponseFormat):
|
||||
input_dict["format"] = fmt.json_schema
|
||||
elif fmt.type == "grammar":
|
||||
elif isinstance(fmt, GrammarResponseFormat):
|
||||
raise NotImplementedError("Grammar response format is not supported")
|
||||
else:
|
||||
raise ValueError(f"Unknown response format type: {fmt.type}")
|
||||
|
@ -240,7 +252,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
)
|
||||
return process_chat_completion_response(response, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
params = await self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
|
@ -275,7 +289,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
model = await self._get_model(model_id)
|
||||
|
||||
assert all(not content_has_media(content) for content in contents), (
|
||||
"Ollama does not support media for embeddings"
|
||||
|
|
|
@ -83,7 +83,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_model(self, model: Model) -> None:
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model = await self.register_helper.register_model(model)
|
||||
if model.provider_resource_id != self.model_id:
|
||||
raise ValueError(
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai import AsyncOpenAI
|
||||
|
@ -32,11 +32,12 @@ from llama_stack.apis.inference import (
|
|||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GrammarResponseFormat,
|
||||
Inference,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
|
@ -102,9 +103,6 @@ def _convert_to_vllm_tool_calls_in_response(
|
|||
|
||||
|
||||
def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
|
||||
if tools is None:
|
||||
return tools
|
||||
|
||||
compat_tools = []
|
||||
|
||||
for tool in tools:
|
||||
|
@ -141,9 +139,7 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]
|
|||
|
||||
compat_tools.append(compat_tool)
|
||||
|
||||
if len(compat_tools) > 0:
|
||||
return compat_tools
|
||||
return None
|
||||
return compat_tools
|
||||
|
||||
|
||||
def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
|
||||
|
@ -206,9 +202,10 @@ async def _process_vllm_chat_completion_stream_response(
|
|||
)
|
||||
elif choice.delta.tool_calls:
|
||||
tool_call = convert_tool_call(choice.delta.tool_calls[0])
|
||||
tool_call_buf.tool_name += tool_call.tool_name
|
||||
tool_call_buf.tool_name += str(tool_call.tool_name)
|
||||
tool_call_buf.call_id += tool_call.call_id
|
||||
tool_call_buf.arguments += tool_call.arguments
|
||||
# TODO: remove str() when dict type for 'arguments' is no longer allowed
|
||||
tool_call_buf.arguments += str(tool_call.arguments)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
|
@ -240,6 +237,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def _get_model(self, model_id: str) -> Model:
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store not set")
|
||||
return await self.model_store.get_model(model_id)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -248,10 +250,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
model = await self._get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
content=content,
|
||||
|
@ -270,17 +272,17 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
model = await self._get_model(model_id)
|
||||
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
|
||||
# References:
|
||||
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
|
||||
|
@ -318,11 +320,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
)
|
||||
return result
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: AsyncOpenAI
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = await client.chat.completions.create(**params)
|
||||
if len(request.tools) > 0:
|
||||
if request.tools:
|
||||
res = _process_vllm_chat_completion_stream_response(stream)
|
||||
else:
|
||||
res = process_chat_completion_stream_response(stream, request)
|
||||
|
@ -330,11 +334,15 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
yield chunk
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
assert self.client is not None
|
||||
params = await self._get_params(request)
|
||||
r = await self.client.completions.create(**params)
|
||||
return process_completion_response(r)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
async def _stream_completion(
|
||||
self, request: CompletionRequest
|
||||
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
assert self.client is not None
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = await self.client.completions.create(**params)
|
||||
|
@ -342,6 +350,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
yield chunk
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
assert self.client is not None
|
||||
model = await self.register_helper.register_model(model)
|
||||
res = await self.client.models.list()
|
||||
available_models = [m.id async for m in res]
|
||||
|
@ -357,7 +366,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
if "max_tokens" not in options:
|
||||
options["max_tokens"] = self.config.max_tokens
|
||||
|
||||
input_dict = {}
|
||||
input_dict: dict[str, Any] = {}
|
||||
if isinstance(request, ChatCompletionRequest) and request.tools is not None:
|
||||
input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
|
||||
|
||||
|
@ -368,9 +377,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
input_dict["prompt"] = await completion_request_to_prompt(request)
|
||||
|
||||
if fmt := request.response_format:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
input_dict["extra_body"] = {"guided_json": request.response_format.json_schema}
|
||||
elif fmt.type == ResponseFormatType.grammar.value:
|
||||
if isinstance(fmt, JsonSchemaResponseFormat):
|
||||
input_dict["extra_body"] = {"guided_json": fmt.json_schema}
|
||||
elif isinstance(fmt, GrammarResponseFormat):
|
||||
raise NotImplementedError("Grammar response format not supported yet")
|
||||
else:
|
||||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
@ -393,7 +402,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
assert self.client is not None
|
||||
model = await self._get_model(model_id)
|
||||
|
||||
kwargs = {}
|
||||
assert model.model_type == ModelType.embedding
|
||||
|
|
|
@ -104,3 +104,6 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
)
|
||||
|
||||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
|
|
@ -137,7 +137,10 @@ def get_sampling_strategy_options(params: SamplingParams) -> dict:
|
|||
return options
|
||||
|
||||
|
||||
def get_sampling_options(params: SamplingParams) -> dict:
|
||||
def get_sampling_options(params: SamplingParams | None) -> dict:
|
||||
if not params:
|
||||
return {}
|
||||
|
||||
options = {}
|
||||
if params:
|
||||
options.update(get_sampling_strategy_options(params))
|
||||
|
@ -297,7 +300,7 @@ def process_chat_completion_response(
|
|||
|
||||
async def process_completion_stream_response(
|
||||
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
|
||||
) -> AsyncGenerator:
|
||||
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
stop_reason = None
|
||||
|
||||
async for chunk in stream:
|
||||
|
@ -334,7 +337,7 @@ async def process_completion_stream_response(
|
|||
async def process_chat_completion_stream_response(
|
||||
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
|
||||
request: ChatCompletionRequest,
|
||||
) -> AsyncGenerator:
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
|
|
|
@ -77,7 +77,9 @@ def typeannotation(
|
|||
"""
|
||||
|
||||
def wrap(cls: Type[T]) -> Type[T]:
|
||||
cls.__repr__ = _compact_dataclass_repr
|
||||
# mypy fails to equate bound-y functions (first argument interpreted as
|
||||
# the bound object) with class methods, hence the `ignore` directive.
|
||||
cls.__repr__ = _compact_dataclass_repr # type: ignore[method-assign]
|
||||
if not dataclasses.is_dataclass(cls):
|
||||
cls = dataclasses.dataclass( # type: ignore[call-overload]
|
||||
cls,
|
||||
|
|
|
@ -627,7 +627,8 @@ class NamedTupleDeserializer(ClassDeserializer[NamedTuple]):
|
|||
super().assign(property_parsers)
|
||||
|
||||
def create(self, **field_values: Any) -> NamedTuple:
|
||||
return self.class_type(**field_values)
|
||||
# mypy fails to deduce that this class returns NamedTuples only, hence the `ignore` directive
|
||||
return self.class_type(**field_values) # type: ignore[no-any-return]
|
||||
|
||||
|
||||
class DataclassDeserializer(ClassDeserializer[T]):
|
||||
|
|
|
@ -48,7 +48,7 @@ class DocstringParam:
|
|||
|
||||
name: str
|
||||
description: str
|
||||
param_type: type = inspect.Signature.empty
|
||||
param_type: type | str = inspect.Signature.empty
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f":param {self.name}: {self.description}"
|
||||
|
|
|
@ -260,7 +260,8 @@ def extend_enum(
|
|||
values: Dict[str, Any] = {}
|
||||
values.update((e.name, e.value) for e in source)
|
||||
values.update((e.name, e.value) for e in extend)
|
||||
enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore
|
||||
# mypy fails to determine that __name__ is always a string; hence the `ignore` directive.
|
||||
enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values) # type: ignore[misc]
|
||||
|
||||
# assign the newly created type to the same module where the extending class is defined
|
||||
enum_class.__module__ = extend.__module__
|
||||
|
@ -327,9 +328,7 @@ def _unwrap_optional_type(typ: Type[Optional[T]]) -> Type[T]:
|
|||
raise TypeError("optional type must have un-subscripted type of Union")
|
||||
|
||||
# will automatically unwrap Union[T] into T
|
||||
return Union[
|
||||
tuple(filter(lambda item: item is not type(None), typing.get_args(typ))) # type: ignore
|
||||
]
|
||||
return Union[tuple(filter(lambda item: item is not type(None), typing.get_args(typ)))] # type: ignore[return-value]
|
||||
|
||||
|
||||
def is_type_union(typ: object) -> bool:
|
||||
|
@ -431,7 +430,7 @@ def _unwrap_generic_list(typ: Type[List[T]]) -> Type[T]:
|
|||
"Extracts the item type of a list type (e.g. returns `T` for `List[T]`)."
|
||||
|
||||
(list_type,) = typing.get_args(typ) # unpack single tuple element
|
||||
return list_type
|
||||
return list_type # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def is_generic_set(typ: object) -> TypeGuard[Type[set]]:
|
||||
|
@ -456,7 +455,7 @@ def _unwrap_generic_set(typ: Type[Set[T]]) -> Type[T]:
|
|||
"Extracts the item type of a set type (e.g. returns `T` for `Set[T]`)."
|
||||
|
||||
(set_type,) = typing.get_args(typ) # unpack single tuple element
|
||||
return set_type
|
||||
return set_type # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def is_generic_dict(typ: object) -> TypeGuard[Type[dict]]:
|
||||
|
@ -513,7 +512,7 @@ def unwrap_annotated_type(typ: T) -> T:
|
|||
|
||||
if is_type_annotated(typ):
|
||||
# type is Annotated[T, ...]
|
||||
return typing.get_args(typ)[0]
|
||||
return typing.get_args(typ)[0] # type: ignore[no-any-return]
|
||||
else:
|
||||
# type is a regular type
|
||||
return typ
|
||||
|
@ -538,7 +537,7 @@ def rewrap_annotated_type(transform: Callable[[Type[S]], Type[T]], typ: Type[S])
|
|||
transformed_type = transform(inner_type)
|
||||
|
||||
if metadata is not None:
|
||||
return Annotated[(transformed_type, *metadata)] # type: ignore
|
||||
return Annotated[(transformed_type, *metadata)] # type: ignore[return-value]
|
||||
else:
|
||||
return transformed_type
|
||||
|
||||
|
@ -563,7 +562,7 @@ else:
|
|||
return typing.get_type_hints(typ)
|
||||
|
||||
|
||||
def get_class_properties(typ: type) -> Iterable[Tuple[str, type]]:
|
||||
def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]:
|
||||
"Returns all properties of a class."
|
||||
|
||||
if is_dataclass_type(typ):
|
||||
|
@ -573,7 +572,7 @@ def get_class_properties(typ: type) -> Iterable[Tuple[str, type]]:
|
|||
return resolved_hints.items()
|
||||
|
||||
|
||||
def get_class_property(typ: type, name: str) -> Optional[type]:
|
||||
def get_class_property(typ: type, name: str) -> Optional[type | str]:
|
||||
"Looks up the annotated type of a property in a class by its property name."
|
||||
|
||||
for property_name, property_type in get_class_properties(typ):
|
||||
|
|
|
@ -460,13 +460,17 @@ class JsonSchemaGenerator:
|
|||
discriminator = None
|
||||
if typing.get_origin(data_type) is Annotated:
|
||||
discriminator = typing.get_args(data_type)[1].discriminator
|
||||
ret = {"oneOf": [self.type_to_schema(union_type) for union_type in typing.get_args(typ)]}
|
||||
ret: Schema = {"oneOf": [self.type_to_schema(union_type) for union_type in typing.get_args(typ)]}
|
||||
if discriminator:
|
||||
# for each union type, we need to read the value of the discriminator
|
||||
mapping = {}
|
||||
mapping: dict[str, JsonType] = {}
|
||||
for union_type in typing.get_args(typ):
|
||||
props = self.type_to_schema(union_type, force_expand=True)["properties"]
|
||||
mapping[props[discriminator]["default"]] = self.type_to_schema(union_type)["$ref"]
|
||||
# mypy is confused here because JsonType allows multiple types, some of them
|
||||
# not indexable (bool?) or not indexable by string (list?). The correctness of
|
||||
# types depends on correct model definitions. Hence multiple ignore statements below.
|
||||
discriminator_value = props[discriminator]["default"] # type: ignore[index,call-overload]
|
||||
mapping[discriminator_value] = self.type_to_schema(union_type)["$ref"] # type: ignore[index]
|
||||
|
||||
ret["discriminator"] = {
|
||||
"propertyName": discriminator,
|
||||
|
|
|
@ -134,7 +134,10 @@ class IPv6Serializer(Serializer[ipaddress.IPv6Address]):
|
|||
|
||||
class EnumSerializer(Serializer[enum.Enum]):
|
||||
def generate(self, obj: enum.Enum) -> Union[int, str]:
|
||||
return obj.value
|
||||
value = obj.value
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
class UntypedListSerializer(Serializer[list]):
|
||||
|
|
|
@ -214,7 +214,6 @@ exclude = [
|
|||
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
|
||||
"^llama_stack/models/llama/llama3_3/prompts\\.py$",
|
||||
"^llama_stack/models/llama/sku_list\\.py$",
|
||||
"^llama_stack/providers/datatypes\\.py$",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
|
||||
"^llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
|
||||
|
@ -248,7 +247,6 @@ exclude = [
|
|||
"^llama_stack/providers/remote/inference/gemini/",
|
||||
"^llama_stack/providers/remote/inference/groq/",
|
||||
"^llama_stack/providers/remote/inference/nvidia/",
|
||||
"^llama_stack/providers/remote/inference/ollama/",
|
||||
"^llama_stack/providers/remote/inference/openai/",
|
||||
"^llama_stack/providers/remote/inference/passthrough/",
|
||||
"^llama_stack/providers/remote/inference/runpod/",
|
||||
|
@ -256,7 +254,6 @@ exclude = [
|
|||
"^llama_stack/providers/remote/inference/sample/",
|
||||
"^llama_stack/providers/remote/inference/tgi/",
|
||||
"^llama_stack/providers/remote/inference/together/",
|
||||
"^llama_stack/providers/remote/inference/vllm/",
|
||||
"^llama_stack/providers/remote/safety/bedrock/",
|
||||
"^llama_stack/providers/remote/safety/nvidia/",
|
||||
"^llama_stack/providers/remote/safety/sample/",
|
||||
|
@ -292,11 +289,6 @@ exclude = [
|
|||
"^llama_stack/providers/utils/telemetry/dataset_mixin\\.py$",
|
||||
"^llama_stack/providers/utils/telemetry/trace_protocol\\.py$",
|
||||
"^llama_stack/providers/utils/telemetry/tracing\\.py$",
|
||||
"^llama_stack/strong_typing/auxiliary\\.py$",
|
||||
"^llama_stack/strong_typing/deserializer\\.py$",
|
||||
"^llama_stack/strong_typing/inspection\\.py$",
|
||||
"^llama_stack/strong_typing/schema\\.py$",
|
||||
"^llama_stack/strong_typing/serializer\\.py$",
|
||||
"^llama_stack/templates/dev/dev\\.py$",
|
||||
"^llama_stack/templates/groq/groq\\.py$",
|
||||
"^llama_stack/templates/sambanova/sambanova\\.py$",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue