diff --git a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 3eef1f272..223497fb8 100644
--- a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -128,7 +128,9 @@ class LiteLLMOpenAIMixin(
         return schema

     async def _get_params(self, request: ChatCompletionRequest) -> dict:
-        input_dict = {}
+        from typing import Any
+
+        input_dict: dict[str, Any] = {}
         input_dict["messages"] = [
             await convert_message_to_openai_dict_new(m, download_images=self.download_images)
             for m in request.messages
@@ -139,30 +141,27 @@ class LiteLLMOpenAIMixin(
                     f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
                 )

-            fmt = fmt.json_schema
-            name = fmt["title"]
-            del fmt["title"]
-            fmt["additionalProperties"] = False
+            # Convert to dict for manipulation
+            fmt_dict = dict(fmt.json_schema)
+            name = fmt_dict["title"]
+            del fmt_dict["title"]
+            fmt_dict["additionalProperties"] = False

             # Apply additionalProperties: False recursively to all objects
-            fmt = self._add_additional_properties_recursive(fmt)
+            fmt_dict = self._add_additional_properties_recursive(fmt_dict)

             input_dict["response_format"] = {
                 "type": "json_schema",
                 "json_schema": {
                     "name": name,
-                    "schema": fmt,
+                    "schema": fmt_dict,
                     "strict": self.json_schema_strict,
                 },
             }

         if request.tools:
             input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
-        if request.tool_config.tool_choice:
-            input_dict["tool_choice"] = (
-                request.tool_config.tool_choice.value
-                if isinstance(request.tool_config.tool_choice, ToolChoice)
-                else request.tool_config.tool_choice
-            )
+        if request.tool_config and (tool_choice := request.tool_config.tool_choice):
+            input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice

         return {
             "model": request.model,
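Note on the response_format hunk above: the recursive pass is done by the class's own `_add_additional_properties_recursive` helper (its tail is the `return schema` context at the top of the hunk). Below is a standalone sketch of the idea, a hypothetical re-implementation for illustration rather than the shipped helper: OpenAI-style strict JSON-schema mode rejects object nodes that leave `additionalProperties` unset, so the flag has to be pushed through nested `properties`, `$defs`, and `items`.

# Hypothetical re-implementation for illustration; the real logic lives in
# LiteLLMOpenAIMixin._add_additional_properties_recursive.
def add_additional_properties_recursive(schema: dict) -> dict:
    if schema.get("type") == "object":
        # setdefault keeps any explicit value the upstream schema already chose.
        schema.setdefault("additionalProperties", False)
    for key in ("properties", "$defs", "definitions"):
        for sub_schema in schema.get(key, {}).values():
            if isinstance(sub_schema, dict):
                add_additional_properties_recursive(sub_schema)
    if isinstance(schema.get("items"), dict):
        add_additional_properties_recursive(schema["items"])
    return schema

schema = {"type": "object", "properties": {"user": {"type": "object", "properties": {}}}}
add_additional_properties_recursive(schema)
assert schema["additionalProperties"] is False
assert schema["properties"]["user"]["additionalProperties"] is False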
@@ -176,10 +175,10 @@ class LiteLLMOpenAIMixin(
     def get_api_key(self) -> str:
         provider_data = self.get_request_provider_data()
         key_field = self.provider_data_api_key_field
-        if provider_data and getattr(provider_data, key_field, None):
-            api_key = getattr(provider_data, key_field)
-        else:
-            api_key = self.api_key_from_config
+        if provider_data and key_field and (api_key := getattr(provider_data, key_field, None)):
+            return str(api_key)  # type: ignore[no-any-return] # getattr returns Any, can't narrow without runtime type inspection
+
+        api_key = self.api_key_from_config
         if not api_key:
             raise ValueError(
                 "API key is not set. Please provide a valid API key in the "
@@ -192,7 +191,13 @@ class LiteLLMOpenAIMixin(
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
+        if not self.model_store:
+            raise ValueError("Model store is not initialized")
+
         model_obj = await self.model_store.get_model(params.model)
+        if model_obj.provider_resource_id is None:
+            raise ValueError(f"Model {params.model} has no provider_resource_id")
+        provider_resource_id = model_obj.provider_resource_id

         # Convert input to list if it's a string
         input_list = [params.input] if isinstance(params.input, str) else params.input
@@ -200,7 +205,7 @@ class LiteLLMOpenAIMixin(
         # Call litellm embedding function
         # litellm.drop_params = True
         response = litellm.embedding(
-            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            model=self.get_litellm_model_name(provider_resource_id),
             input=input_list,
             api_key=self.get_api_key(),
             api_base=self.api_base,
@@ -217,7 +222,7 @@ class LiteLLMOpenAIMixin(

         return OpenAIEmbeddingsResponse(
             data=data,
-            model=model_obj.provider_resource_id,
+            model=provider_resource_id,
             usage=usage,
         )

@@ -225,10 +230,16 @@ class LiteLLMOpenAIMixin(
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
+        if not self.model_store:
+            raise ValueError("Model store is not initialized")
+
         model_obj = await self.model_store.get_model(params.model)
+        if model_obj.provider_resource_id is None:
+            raise ValueError(f"Model {params.model} has no provider_resource_id")
+        provider_resource_id = model_obj.provider_resource_id

         request_params = await prepare_openai_completion_params(
-            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            model=self.get_litellm_model_name(provider_resource_id),
             prompt=params.prompt,
             best_of=params.best_of,
             echo=params.echo,
@@ -249,7 +260,8 @@ class LiteLLMOpenAIMixin(
             api_key=self.get_api_key(),
             api_base=self.api_base,
         )
-        return await litellm.atext_completion(**request_params)
+        # LiteLLM returns compatible type but mypy can't verify external library
+        return await litellm.atext_completion(**request_params)  # type: ignore[no-any-return] # external lib lacks type stubs

     async def openai_chat_completion(
         self,
@@ -265,10 +277,16 @@ class LiteLLMOpenAIMixin(
         elif "include_usage" not in stream_options:
             stream_options = {**stream_options, "include_usage": True}

+        if not self.model_store:
+            raise ValueError("Model store is not initialized")
+
         model_obj = await self.model_store.get_model(params.model)
+        if model_obj.provider_resource_id is None:
+            raise ValueError(f"Model {params.model} has no provider_resource_id")
+        provider_resource_id = model_obj.provider_resource_id

         request_params = await prepare_openai_completion_params(
-            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            model=self.get_litellm_model_name(provider_resource_id),
             messages=params.messages,
             frequency_penalty=params.frequency_penalty,
             function_call=params.function_call,
@@ -294,7 +312,8 @@ class LiteLLMOpenAIMixin(
             api_key=self.get_api_key(),
             api_base=self.api_base,
         )
-        return await litellm.acompletion(**request_params)
+        # LiteLLM returns compatible type but mypy can't verify external library
+        return await litellm.acompletion(**request_params)  # type: ignore[no-any-return] # external lib lacks type stubs

     async def check_model_availability(self, model: str) -> bool:
         """
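Note on the `model_store` / `provider_resource_id` guards repeated in the three methods above: each `raise` doubles as runtime validation and as a narrowing hint, letting mypy treat the value as a plain `str` afterwards. A minimal sketch with a hypothetical stand-in type:

from dataclasses import dataclass

@dataclass
class Model:  # hypothetical stand-in for the Llama Stack model object
    provider_resource_id: str | None

def resolve_provider_resource_id(model: Model) -> str:
    if model.provider_resource_id is None:
        raise ValueError("model has no provider_resource_id")
    # After the raise, mypy narrows str | None to str, which is what
    # downstream helpers such as get_litellm_model_name() require.
    return model.provider_resource_id

assert resolve_provider_resource_id(Model("together/meta-llama-3.1-8b")) == "together/meta-llama-3.1-8b"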
diff --git a/src/llama_stack/providers/utils/inference/openai_compat.py b/src/llama_stack/providers/utils/inference/openai_compat.py
index 7e465a14c..aabcb50f8 100644
--- a/src/llama_stack/providers/utils/inference/openai_compat.py
+++ b/src/llama_stack/providers/utils/inference/openai_compat.py
@@ -161,8 +161,10 @@ def get_sampling_strategy_options(params: SamplingParams) -> dict:
     if isinstance(params.strategy, GreedySamplingStrategy):
         options["temperature"] = 0.0
     elif isinstance(params.strategy, TopPSamplingStrategy):
-        options["temperature"] = params.strategy.temperature
-        options["top_p"] = params.strategy.top_p
+        if params.strategy.temperature is not None:
+            options["temperature"] = params.strategy.temperature
+        if params.strategy.top_p is not None:
+            options["top_p"] = params.strategy.top_p
     elif isinstance(params.strategy, TopKSamplingStrategy):
         options["top_k"] = params.strategy.top_k
     else:
@@ -192,12 +194,12 @@ def get_sampling_options(params: SamplingParams | None) -> dict:

 def text_from_choice(choice) -> str:
     if hasattr(choice, "delta") and choice.delta:
-        return choice.delta.content
+        return choice.delta.content  # type: ignore[no-any-return] # external OpenAI types lack precise annotations

     if hasattr(choice, "message"):
-        return choice.message.content
+        return choice.message.content  # type: ignore[no-any-return] # external OpenAI types lack precise annotations

-    return choice.text
+    return choice.text  # type: ignore[no-any-return] # external OpenAI types lack precise annotations


 def get_stop_reason(finish_reason: str) -> StopReason:
@@ -216,7 +218,7 @@ def convert_openai_completion_logprobs(
 ) -> list[TokenLogProbs] | None:
     if not logprobs:
         return None
-    if hasattr(logprobs, "top_logprobs"):
+    if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs:
         return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]

     # Together supports logprobs with top_k=1 only. This means for each token position,
@@ -236,7 +238,7 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenA
     if isinstance(logprobs, float):
         # Adapt response from Together CompletionChoicesChunk
         return [TokenLogProbs(logprobs_by_token={text: logprobs})]
-    if hasattr(logprobs, "top_logprobs"):
+    if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs:
         return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
     return None

@@ -245,23 +247,24 @@ def process_completion_response(
     response: OpenAICompatCompletionResponse,
 ) -> CompletionResponse:
     choice = response.choices[0]
+    text = choice.text or ""
     # drop suffix if present and return stop reason as end of turn
-    if choice.text.endswith("<|eot_id|>"):
+    if text.endswith("<|eot_id|>"):
         return CompletionResponse(
             stop_reason=StopReason.end_of_turn,
-            content=choice.text[: -len("<|eot_id|>")],
+            content=text[: -len("<|eot_id|>")],
             logprobs=convert_openai_completion_logprobs(choice.logprobs),
         )
     # drop suffix if present and return stop reason as end of message
-    if choice.text.endswith("<|eom_id|>"):
+    if text.endswith("<|eom_id|>"):
         return CompletionResponse(
             stop_reason=StopReason.end_of_message,
-            content=choice.text[: -len("<|eom_id|>")],
+            content=text[: -len("<|eom_id|>")],
             logprobs=convert_openai_completion_logprobs(choice.logprobs),
         )
     return CompletionResponse(
-        stop_reason=get_stop_reason(choice.finish_reason),
-        content=choice.text,
+        stop_reason=get_stop_reason(choice.finish_reason or "stop"),
+        content=text,
         logprobs=convert_openai_completion_logprobs(choice.logprobs),
     )

@@ -272,10 +275,10 @@ def process_chat_completion_response(
 ) -> ChatCompletionResponse:
     choice = response.choices[0]
     if choice.finish_reason == "tool_calls":
-        if not choice.message or not choice.message.tool_calls:
+        if not hasattr(choice, "message") or not choice.message or not choice.message.tool_calls:  # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed
             raise ValueError("Tool calls are not present in the response")

-        tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
+        tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]  # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed
         if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
             # If we couldn't parse a tool call, jsonify the tool calls and return them
             return ChatCompletionResponse(
@@ -287,9 +290,11 @@ def process_chat_completion_response(
             )
         else:
             # Otherwise, return tool calls as normal
+            # Filter to only valid ToolCall objects
+            valid_tool_calls = [tc for tc in tool_calls if isinstance(tc, ToolCall)]
             return ChatCompletionResponse(
                 completion_message=CompletionMessage(
-                    tool_calls=tool_calls,
+                    tool_calls=valid_tool_calls,
                     stop_reason=StopReason.end_of_turn,
                     # Content is not optional
                     content="",
@@ -299,7 +304,7 @@ def process_chat_completion_response(

     # TODO: This does not work well with tool calls for vLLM remote provider
     #   Ref: https://github.com/meta-llama/llama-stack/issues/1058
-    raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
+    raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason or "stop"))

     # NOTE: If we do not set tools in chat-completion request, we should not
     # expect the ToolCall in the response. Instead, we should return the raw
@@ -324,8 +329,8 @@ def process_chat_completion_response(

     return ChatCompletionResponse(
         completion_message=CompletionMessage(
-            content=raw_message.content,
-            stop_reason=raw_message.stop_reason,
+            content=raw_message.content,  # type: ignore[arg-type] # decode_assistant_message returns Union[str, InterleavedContent]
+            stop_reason=raw_message.stop_reason or StopReason.end_of_turn,
             tool_calls=raw_message.tool_calls,
         ),
         logprobs=None,
@@ -448,7 +453,7 @@ async def process_chat_completion_stream_response(
     )

     # parse tool calls and report errors
-    message = decode_assistant_message(buffer, stop_reason)
+    message = decode_assistant_message(buffer, stop_reason or StopReason.end_of_turn)

     parsed_tool_calls = len(message.tool_calls) > 0
     if ipython and not parsed_tool_calls:
@@ -463,7 +468,7 @@ async def process_chat_completion_stream_response(
             )
         )

-    request_tools = {t.tool_name: t for t in request.tools}
+    request_tools = {t.tool_name: t for t in (request.tools or [])}
     for tool_call in message.tool_calls:
         if tool_call.tool_name in request_tools:
             yield ChatCompletionResponseStreamChunk(
@@ -525,7 +530,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
         }

     if hasattr(message, "tool_calls") and message.tool_calls:
-        result["tool_calls"] = []
+        tool_calls_list = []
         for tc in message.tool_calls:
             # The tool.tool_name can be a str or a BuiltinTool enum. If
             # it's the latter, convert to a string.
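Note on the `get_sampling_strategy_options` change at the top of this file's hunks: `temperature` and `top_p` are now emitted only when set. A sketch of why that matters, using a hypothetical minimal strategy type — an omitted key lets the server apply its own default, while an explicit `temperature: null` would be forwarded as a value that OpenAI-compatible backends can reject or misinterpret:

from dataclasses import dataclass

@dataclass
class TopPSamplingStrategy:  # hypothetical minimal stand-in
    temperature: float | None = None
    top_p: float | None = None

def top_p_options(strategy: TopPSamplingStrategy) -> dict:
    options: dict = {}
    # Emit only keys with real values; absent keys mean "use the server default".
    if strategy.temperature is not None:
        options["temperature"] = strategy.temperature
    if strategy.top_p is not None:
        options["top_p"] = strategy.top_p
    return options

assert top_p_options(TopPSamplingStrategy()) == {}
assert top_p_options(TopPSamplingStrategy(temperature=0.7)) == {"temperature": 0.7}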
@@ -533,7 +538,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
             if isinstance(tool_name, BuiltinTool):
                 tool_name = tool_name.value

-            result["tool_calls"].append(
+            tool_calls_list.append(
                 {
                     "id": tc.call_id,
                     "type": "function",
@@ -543,6 +548,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
                     },
                 }
             )
+        result["tool_calls"] = tool_calls_list  # type: ignore[assignment] # dict allows Any value, stricter type expected
     return result


@@ -608,7 +614,7 @@ async def convert_message_to_openai_dict_new(
                 ),
             )
         elif isinstance(content_, list):
-            return [await impl(item) for item in content_]
+            return [await impl(item) for item in content_]  # type: ignore[misc] # recursive list comprehension confuses mypy's type narrowing
         else:
             raise ValueError(f"Unsupported content type: {type(content_)}")

@@ -620,7 +626,7 @@ async def convert_message_to_openai_dict_new(
         else:
             return [ret]

-    out: OpenAIChatCompletionMessage = None
+    out: OpenAIChatCompletionMessage
     if isinstance(message, UserMessage):
         out = OpenAIChatCompletionUserMessage(
             role="user",
@@ -636,7 +642,7 @@ async def convert_message_to_openai_dict_new(
                 ),
                 type="function",
             )
-            for tool in message.tool_calls
+            for tool in (message.tool_calls or [])
         ]
         params = {}
         if tool_calls:
@@ -644,18 +650,18 @@ async def convert_message_to_openai_dict_new(
         out = OpenAIChatCompletionAssistantMessage(
             role="assistant",
             content=await _convert_message_content(message.content),
-            **params,
+            **params,  # type: ignore[typeddict-item] # tool_calls dict expansion conflicts with TypedDict optional field
         )
     elif isinstance(message, ToolResponseMessage):
         out = OpenAIChatCompletionToolMessage(
             role="tool",
             tool_call_id=message.call_id,
-            content=await _convert_message_content(message.content),
+            content=await _convert_message_content(message.content),  # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement
         )
     elif isinstance(message, SystemMessage):
         out = OpenAIChatCompletionSystemMessage(
             role="system",
-            content=await _convert_message_content(message.content),
+            content=await _convert_message_content(message.content),  # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement
         )
     else:
         raise ValueError(f"Unsupported message type: {type(message)}")
@@ -758,16 +764,16 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
     function = out["function"]

     if isinstance(tool.tool_name, BuiltinTool):
-        function["name"] = tool.tool_name.value
+        function["name"] = tool.tool_name.value  # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
     else:
-        function["name"] = tool.tool_name
+        function["name"] = tool.tool_name  # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]

     if tool.description:
-        function["description"] = tool.description
+        function["description"] = tool.description  # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]

     if tool.input_schema:
         # Pass through the entire JSON Schema as-is
-        function["parameters"] = tool.input_schema
+        function["parameters"] = tool.input_schema  # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]

     # NOTE: OpenAI does not support output_schema, so we drop it here
     # It's stored in LlamaStack for validation and other provider usage
@@ -815,15 +821,15 @@ def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None
     tool_config = ToolConfig()
     if tool_choice:
         try:
-            tool_choice = ToolChoice(tool_choice)
+            tool_choice = ToolChoice(tool_choice)  # type: ignore[assignment] # reassigning to enum narrows union but mypy can't track after exception
         except ValueError:
             pass
-        tool_config.tool_choice = tool_choice
+        tool_config.tool_choice = tool_choice  # type: ignore[assignment] # ToolConfig.tool_choice accepts Union[ToolChoice, dict] but mypy tracks narrower type
     return tool_config


 def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]:
-    lls_tools = []
+    lls_tools: list[ToolDefinition] = []
     if not tools:
         return lls_tools

@@ -843,16 +849,16 @@ def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) ->


 def _convert_openai_request_response_format(
-    response_format: OpenAIResponseFormatParam = None,
+    response_format: OpenAIResponseFormatParam | None = None,
 ):
     if not response_format:
         return None
     # response_format can be a dict or a pydantic model
-    response_format = dict(response_format)
-    if response_format.get("type", "") == "json_schema":
+    response_format_dict = dict(response_format)  # type: ignore[arg-type] # OpenAIResponseFormatParam union needs dict conversion
+    if response_format_dict.get("type", "") == "json_schema":
         return JsonSchemaResponseFormat(
-            type="json_schema",
-            json_schema=response_format.get("json_schema", {}).get("schema", ""),
+            type="json_schema",  # type: ignore[arg-type] # Literal["json_schema"] incompatible with expected type
+            json_schema=response_format_dict.get("json_schema", {}).get("schema", ""),
         )
     return None

@@ -938,16 +944,15 @@ def _convert_openai_sampling_params(

     # Map an explicit temperature of 0 to greedy sampling
     if temperature == 0:
-        strategy = GreedySamplingStrategy()
+        sampling_params.strategy = GreedySamplingStrategy()
     else:
         # OpenAI defaults to 1.0 for temperature and top_p if unset
         if temperature is None:
             temperature = 1.0
         if top_p is None:
             top_p = 1.0
-        strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p)
+        sampling_params.strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p)  # type: ignore[assignment] # SamplingParams.strategy union accepts this type

-    sampling_params.strategy = strategy
     return sampling_params
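Note on the `_convert_openai_sampling_params` hunk directly above: the strategy is now assigned in each branch, but the mapping itself is unchanged. A standalone sketch of that mapping, with hypothetical stand-in strategy classes:

from dataclasses import dataclass

@dataclass
class GreedySamplingStrategy:  # hypothetical stand-ins for the real strategy types
    pass

@dataclass
class TopPSamplingStrategy:
    temperature: float
    top_p: float

def pick_strategy(temperature: float | None, top_p: float | None):
    # An explicit temperature of 0 means greedy decoding; otherwise fill in
    # OpenAI's documented defaults of 1.0 for unset values.
    if temperature == 0:
        return GreedySamplingStrategy()
    return TopPSamplingStrategy(
        temperature=temperature if temperature is not None else 1.0,
        top_p=top_p if top_p is not None else 1.0,
    )

assert isinstance(pick_strategy(0, None), GreedySamplingStrategy)
assert pick_strategy(None, None) == TopPSamplingStrategy(temperature=1.0, top_p=1.0)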
""" - converted_messages = [] + converted_messages: list[Message] = [] for message in messages: + converted_message: Message if message.role == "system": - converted_message = SystemMessage(content=openai_content_to_content(message.content)) + converted_message = SystemMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types elif message.role == "user": - converted_message = UserMessage(content=openai_content_to_content(message.content)) + converted_message = UserMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types elif message.role == "assistant": converted_message = CompletionMessage( - content=openai_content_to_content(message.content), - tool_calls=_convert_openai_tool_calls(message.tool_calls), + content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types + tool_calls=_convert_openai_tool_calls(message.tool_calls) if message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls type incompatible with conversion function stop_reason=StopReason.end_of_turn, ) elif message.role == "tool": converted_message = ToolResponseMessage( role="tool", call_id=message.tool_call_id, - content=openai_content_to_content(message.content), + content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types ) else: raise ValueError(f"Unknown role {message.role}") @@ -990,9 +996,9 @@ def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionConten return [openai_content_to_content(c) for c in content] elif hasattr(content, "type"): if content.type == "text": - return TextContentItem(type="text", text=content.text) + return TextContentItem(type="text", text=content.text) # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track elif content.type == "image_url": - return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url))) + return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url))) # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track else: raise ValueError(f"Unknown content type: {content.type}") else: @@ -1041,9 +1047,9 @@ def convert_openai_chat_completion_choice( completion_message=CompletionMessage( content=choice.message.content or "", # CompletionMessage content is not optional stop_reason=_convert_openai_finish_reason(choice.finish_reason), - tool_calls=_convert_openai_tool_calls(choice.message.tool_calls), + tool_calls=_convert_openai_tool_calls(choice.message.tool_calls) if choice.message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls Optional type broadens union ), - logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)), + logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)), # type: ignore[arg-type] # getattr returns Any, can't narrow without inspection ) @@ -1070,7 +1076,7 @@ async def convert_openai_chat_completion_stream( choice = chunk.choices[0] # assuming only one choice per chunk # we assume there's only one finish_reason in the stream - stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason + stop_reason = 
_convert_openai_finish_reason(choice.finish_reason) if choice.finish_reason else stop_reason logprobs = getattr(choice, "logprobs", None) # if there's a tool call, emit an event for each tool in the list @@ -1083,7 +1089,7 @@ async def convert_openai_chat_completion_stream( event=ChatCompletionResponseEvent( event_type=event_type, delta=TextDelta(text=choice.delta.content), - logprobs=_convert_openai_logprobs(logprobs), + logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result ) ) @@ -1101,10 +1107,10 @@ async def convert_openai_chat_completion_stream( event=ChatCompletionResponseEvent( event_type=event_type, delta=ToolCallDelta( - tool_call=_convert_openai_tool_calls([tool_call])[0], + tool_call=_convert_openai_tool_calls([tool_call])[0], # type: ignore[arg-type, list-item] # delta tool_call type differs from complete tool_call parse_status=ToolCallParseStatus.succeeded, ), - logprobs=_convert_openai_logprobs(logprobs), + logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result ) ) else: @@ -1125,12 +1131,15 @@ async def convert_openai_chat_completion_stream( if tool_call.function.name: buffer["name"] = tool_call.function.name delta = f"{buffer['name']}(" - buffer["content"] += delta + if buffer["content"] is not None: + buffer["content"] += delta if tool_call.function.arguments: delta = tool_call.function.arguments - buffer["arguments"] += delta - buffer["content"] += delta + if buffer["arguments"] is not None and delta: + buffer["arguments"] += delta + if buffer["content"] is not None and delta: + buffer["content"] += delta yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( @@ -1139,7 +1148,7 @@ async def convert_openai_chat_completion_stream( tool_call=delta, parse_status=ToolCallParseStatus.in_progress, ), - logprobs=_convert_openai_logprobs(logprobs), + logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result ) ) elif choice.delta.content: @@ -1147,7 +1156,7 @@ async def convert_openai_chat_completion_stream( event=ChatCompletionResponseEvent( event_type=event_type, delta=TextDelta(text=choice.delta.content or ""), - logprobs=_convert_openai_logprobs(logprobs), + logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result ) ) @@ -1155,7 +1164,8 @@ async def convert_openai_chat_completion_stream( logger.debug(f"toolcall_buffer[{idx}]: {buffer}") if buffer["name"]: delta = ")" - buffer["content"] += delta + if buffer["content"] is not None: + buffer["content"] += delta yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=event_type, @@ -1168,16 +1178,16 @@ async def convert_openai_chat_completion_stream( ) try: - tool_call = ToolCall( - call_id=buffer["call_id"], - tool_name=buffer["name"], - arguments=buffer["arguments"], + parsed_tool_call = ToolCall( + call_id=buffer["call_id"] or "", + tool_name=buffer["name"] or "", + arguments=buffer["arguments"] or "", ) yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.progress, delta=ToolCallDelta( - tool_call=tool_call, + tool_call=parsed_tool_call, # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall] parse_status=ToolCallParseStatus.succeeded, ), stop_reason=stop_reason, @@ -1189,7 +1199,7 @@ async def convert_openai_chat_completion_stream( 
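Note on the buffer guards above: the streamed tool-call fragments are concatenated into a per-call dict whose `name`, `arguments`, and `content` slots can be `None` before the first delta arrives, hence the `is not None` checks before every `+=`. A sketch of the accumulation under that assumption (the buffer layout is inferred from these hunks, not from the full source):

def accumulate_tool_call_delta(buffer: dict, name: str | None, arguments: str | None) -> None:
    # Mirrors the guarded "+=" pattern: never concatenate onto a None slot.
    if name:
        buffer["name"] = name
        delta = f"{name}("
        if buffer["content"] is not None:
            buffer["content"] += delta
    if arguments:
        if buffer["arguments"] is not None:
            buffer["arguments"] += arguments
        if buffer["content"] is not None:
            buffer["content"] += arguments

buf = {"call_id": "call_1", "name": "", "arguments": "", "content": ""}
accumulate_tool_call_delta(buf, "get_weather", None)
accumulate_tool_call_delta(buf, None, '{"city": "Paris"}')
assert buf["content"] == 'get_weather({"city": "Paris"}'
assert buf["arguments"] == '{"city": "Paris"}'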
@@ -1189,7 +1199,7 @@ async def convert_openai_chat_completion_stream(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.progress,
                     delta=ToolCallDelta(
-                        tool_call=buffer["content"],
+                        tool_call=buffer["content"],  # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall]
                         parse_status=ToolCallParseStatus.failed,
                     ),
                     stop_reason=stop_reason,
@@ -1250,7 +1260,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
         top_p: float | None = None,
         user: str | None = None,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        messages = openai_messages_to_messages(messages)
+        messages = openai_messages_to_messages(messages)  # type: ignore[assignment] # converted from OpenAI to LlamaStack message format
         response_format = _convert_openai_request_response_format(response_format)
         sampling_params = _convert_openai_sampling_params(
             max_tokens=max_tokens,
@@ -1259,15 +1269,15 @@ class OpenAIChatCompletionToLlamaStackMixin:
         )
         tool_config = _convert_openai_request_tool_config(tool_choice)

-        tools = _convert_openai_request_tools(tools)
+        tools = _convert_openai_request_tools(tools)  # type: ignore[assignment] # converted from OpenAI to LlamaStack tool format
         if tool_config.tool_choice == ToolChoice.none:
-            tools = []
+            tools = []  # type: ignore[assignment] # empty list narrows return type but mypy tracks broader type

         outstanding_responses = []
         # "n" is the number of completions to generate per prompt
         n = n or 1
         for _i in range(0, n):
-            response = self.chat_completion(
+            response = self.chat_completion(  # type: ignore[attr-defined] # mixin expects class to implement chat_completion
                 model_id=model,
                 messages=messages,
                 sampling_params=sampling_params,
@@ -1279,7 +1289,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
             outstanding_responses.append(response)

         if stream:
-            return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
+            return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)  # type: ignore[no-any-return] # mixin async generator return type too complex for mypy

         return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
             self, model, outstanding_responses
@@ -1295,14 +1305,16 @@ class OpenAIChatCompletionToLlamaStackMixin:
             response = await outstanding_response
             async for chunk in response:
                 event = chunk.event
-                finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
+                finish_reason = (
+                    _convert_stop_reason_to_openai_finish_reason(event.stop_reason) if event.stop_reason else None
+                )

                 if isinstance(event.delta, TextDelta):
                     text_delta = event.delta.text
                     delta = OpenAIChoiceDelta(content=text_delta)
                     yield OpenAIChatCompletionChunk(
                         id=id,
-                        choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
+                        choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],  # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
                         created=int(time.time()),
                         model=model,
                         object="chat.completion.chunk",
@@ -1310,13 +1322,17 @@ class OpenAIChatCompletionToLlamaStackMixin:
                 elif isinstance(event.delta, ToolCallDelta):
                     if event.delta.parse_status == ToolCallParseStatus.succeeded:
                         tool_call = event.delta.tool_call
+                        if isinstance(tool_call, str):
+                            continue

                         # First chunk includes full structure
                         openai_tool_call = OpenAIChoiceDeltaToolCall(
                             index=0,
                             id=tool_call.call_id,
                             function=OpenAIChoiceDeltaToolCallFunction(
-                                name=tool_call.tool_name,
+                                name=tool_call.tool_name
+                                if isinstance(tool_call.tool_name, str)
+                                else tool_call.tool_name.value,  # type: ignore[arg-type] # enum .value extraction on Union confuses mypy
                                 arguments="",
                             ),
                         )
@@ -1324,7 +1340,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
                         yield OpenAIChatCompletionChunk(
                             id=id,
                             choices=[
-                                OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
+                                OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)  # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
                             ],
                             created=int(time.time()),
                             model=model,
@@ -1341,7 +1357,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
                         yield OpenAIChatCompletionChunk(
                             id=id,
                             choices=[
-                                OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
+                                OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)  # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
                             ],
                             created=int(time.time()),
                             model=model,
@@ -1351,7 +1367,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
     async def _process_non_stream_response(
         self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
     ) -> OpenAIChatCompletion:
-        choices = []
+        choices: list[OpenAIChatCompletionChoice] = []
        for outstanding_response in outstanding_responses:
            response = await outstanding_response
            completion_message = response.completion_message
@@ -1360,14 +1376,14 @@ class OpenAIChatCompletionToLlamaStackMixin:

            choice = OpenAIChatCompletionChoice(
                index=len(choices),
-                message=message,
+                message=message,  # type: ignore[arg-type] # OpenAIChatCompletionMessage union incompatible with narrower Message type
                finish_reason=finish_reason,
            )
-            choices.append(choice)
+            choices.append(choice)  # type: ignore[arg-type] # OpenAIChatCompletionChoice type annotation mismatch

        return OpenAIChatCompletion(
            id=f"chatcmpl-{uuid.uuid4()}",
-            choices=choices,
+            choices=choices,  # type: ignore[arg-type] # list[OpenAIChatCompletionChoice] union incompatible
            created=int(time.time()),
            model=model,
            object="chat.completion",
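Note on `outstanding_responses` in the mixin above: one `chat_completion` coroutine is created per requested OpenAI choice `n`, and the coroutines are awaited in order afterwards. A self-contained sketch of that shape, with the provider call faked:

import asyncio

async def fake_chat_completion(i: int) -> str:
    await asyncio.sleep(0)  # stand-in for a provider round trip
    return f"choice-{i}"

async def gather_choices(n: int | None) -> list[str]:
    # Collect one coroutine per requested choice, then await them in order,
    # mirroring how the mixin builds and drains outstanding_responses.
    outstanding = [fake_chat_completion(i) for i in range(n or 1)]
    return [await response for response in outstanding]

assert asyncio.run(gather_choices(3)) == ["choice-0", "choice-1", "choice-2"]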