diff --git a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 3eef1f272..63e3771a8 100644
--- a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -128,7 +128,9 @@ class LiteLLMOpenAIMixin(
         return schema
 
     async def _get_params(self, request: ChatCompletionRequest) -> dict:
-        input_dict = {}
+        from typing import Any
+
+        input_dict: dict[str, Any] = {}
         input_dict["messages"] = [
             await convert_message_to_openai_dict_new(m, download_images=self.download_images)
             for m in request.messages
@@ -139,29 +141,30 @@ class LiteLLMOpenAIMixin(
                     f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
                 )
 
-            fmt = fmt.json_schema
-            name = fmt["title"]
-            del fmt["title"]
-            fmt["additionalProperties"] = False
+            # Convert to dict for manipulation
+            fmt_dict = dict(fmt.json_schema)
+            name = fmt_dict["title"]
+            del fmt_dict["title"]
+            fmt_dict["additionalProperties"] = False
 
             # Apply additionalProperties: False recursively to all objects
-            fmt = self._add_additional_properties_recursive(fmt)
+            fmt_dict = self._add_additional_properties_recursive(fmt_dict)
 
             input_dict["response_format"] = {
                 "type": "json_schema",
                 "json_schema": {
                     "name": name,
-                    "schema": fmt,
+                    "schema": fmt_dict,
                     "strict": self.json_schema_strict,
                 },
             }
         if request.tools:
             input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
-            if request.tool_config.tool_choice:
+            if request.tool_config and (tool_choice := request.tool_config.tool_choice):
                 input_dict["tool_choice"] = (
-                    request.tool_config.tool_choice.value
-                    if isinstance(request.tool_config.tool_choice, ToolChoice)
-                    else request.tool_config.tool_choice
+                    tool_choice.value
+                    if isinstance(tool_choice, ToolChoice)
+                    else tool_choice
                 )
 
         return {
@@ -176,10 +179,10 @@ class LiteLLMOpenAIMixin(
     def get_api_key(self) -> str:
         provider_data = self.get_request_provider_data()
         key_field = self.provider_data_api_key_field
-        if provider_data and getattr(provider_data, key_field, None):
-            api_key = getattr(provider_data, key_field)
-        else:
-            api_key = self.api_key_from_config
+        if provider_data and key_field and (api_key := getattr(provider_data, key_field, None)):
+            return str(api_key)  # type: ignore[no-any-return]
+
+        api_key = self.api_key_from_config
         if not api_key:
             raise ValueError(
                 "API key is not set. Please provide a valid API key in the "
@@ -192,7 +195,12 @@ class LiteLLMOpenAIMixin(
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
+        if not self.model_store:
+            raise ValueError("Model store is not initialized")
+
         model_obj = await self.model_store.get_model(params.model)
+        # Fallback to params.model ensures provider_resource_id is always str
+        provider_resource_id: str = (model_obj.provider_resource_id if model_obj else None) or params.model
 
         # Convert input to list if it's a string
         input_list = [params.input] if isinstance(params.input, str) else params.input
@@ -200,7 +208,7 @@ class LiteLLMOpenAIMixin(
         # Call litellm embedding function
         # litellm.drop_params = True
         response = litellm.embedding(
-            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            model=self.get_litellm_model_name(provider_resource_id),
             input=input_list,
             api_key=self.get_api_key(),
             api_base=self.api_base,
@@ -217,7 +225,7 @@ class LiteLLMOpenAIMixin(
 
         return OpenAIEmbeddingsResponse(
             data=data,
-            model=model_obj.provider_resource_id,
+            model=provider_resource_id,
             usage=usage,
         )
 
@@ -225,10 +233,15 @@ class LiteLLMOpenAIMixin(
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
+        if not self.model_store:
+            raise ValueError("Model store is not initialized")
+
         model_obj = await self.model_store.get_model(params.model)
+        # Fallback to params.model ensures provider_resource_id is always str
+        provider_resource_id: str = (model_obj.provider_resource_id if model_obj else None) or params.model
 
         request_params = await prepare_openai_completion_params(
-            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            model=self.get_litellm_model_name(provider_resource_id),
             prompt=params.prompt,
             best_of=params.best_of,
             echo=params.echo,
@@ -249,7 +262,8 @@ class LiteLLMOpenAIMixin(
             api_key=self.get_api_key(),
             api_base=self.api_base,
         )
-        return await litellm.atext_completion(**request_params)
+        # LiteLLM returns compatible type but mypy can't verify external library
+        return await litellm.atext_completion(**request_params)  # type: ignore[no-any-return]
 
     async def openai_chat_completion(
         self,
@@ -265,10 +279,15 @@ class LiteLLMOpenAIMixin(
         elif "include_usage" not in stream_options:
             stream_options = {**stream_options, "include_usage": True}
 
+        if not self.model_store:
+            raise ValueError("Model store is not initialized")
+
         model_obj = await self.model_store.get_model(params.model)
+        # Fallback to params.model ensures provider_resource_id is always str
+        provider_resource_id: str = (model_obj.provider_resource_id if model_obj else None) or params.model
 
         request_params = await prepare_openai_completion_params(
-            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            model=self.get_litellm_model_name(provider_resource_id),
             messages=params.messages,
             frequency_penalty=params.frequency_penalty,
             function_call=params.function_call,
@@ -294,7 +313,8 @@ class LiteLLMOpenAIMixin(
             api_key=self.get_api_key(),
             api_base=self.api_base,
         )
-        return await litellm.acompletion(**request_params)
+        # LiteLLM returns compatible type but mypy can't verify external library
+        return await litellm.acompletion(**request_params)  # type: ignore[no-any-return]
 
     async def check_model_availability(self, model: str) -> bool:
         """
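
A note for reviewers on the response-format hunk: dict(fmt.json_schema) copies the schema before it is mutated, and _add_additional_properties_recursive (defined elsewhere in this mixin; only its call site is visible above) then closes off nested object schemas. A minimal sketch of what such a traversal can look like — the function name, the key list, and the recursion points below are assumptions, not the mixin's actual implementation:

from typing import Any


def add_additional_properties_recursive(schema: dict[str, Any]) -> dict[str, Any]:
    """Sketch only: close off every object node in a JSON schema."""
    if schema.get("type") == "object":
        # Strict mode rejects any property the model invents.
        schema["additionalProperties"] = False
    # Assumed recursion points; the real helper may cover more keywords.
    for key in ("properties", "$defs", "definitions"):
        for sub_schema in schema.get(key, {}).values():
            if isinstance(sub_schema, dict):
                add_additional_properties_recursive(sub_schema)
    items = schema.get("items")
    if isinstance(items, dict):
        add_additional_properties_recursive(items)
    return schema

Copying first matters because fmt.json_schema belongs to the request object; per the diff's own "Convert to dict for manipulation" comment, mutating a copy keeps _get_params from altering shared state.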
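The get_api_key and tool_choice hunks lean on the same walrus-operator narrowing idiom: test and bind in one expression, so the body never re-reads a possibly-None attribute. In miniature (ToolConfig and pick_tool_choice are invented for illustration):

class ToolConfig:
    # Invented stand-in for the request's tool_config object.
    tool_choice: str | None = None


def pick_tool_choice(tool_config: ToolConfig | None) -> str | None:
    # Test and bind at once: tool_choice is guaranteed non-None inside
    # the branch, which is what lets mypy narrow the type.
    if tool_config and (tool_choice := tool_config.tool_choice):
        return tool_choice
    return None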
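The same three-line preamble (guard the optional model_store, look the model up, fall back to params.model) is now repeated in openai_embeddings, openai_completion, and openai_chat_completion. A hypothetical standalone helper showing the pattern in isolation — resolve_provider_model_id and the two Protocol stubs are invented here, not part of llama-stack:

from typing import Protocol


class Model(Protocol):
    # Invented stub; the real type lives in llama-stack's model registry.
    provider_resource_id: str | None


class ModelStore(Protocol):
    async def get_model(self, model_id: str) -> Model | None: ...


async def resolve_provider_model_id(store: ModelStore | None, requested: str) -> str:
    """Hypothetical helper mirroring the preamble added in each method."""
    if not store:
        raise ValueError("Model store is not initialized")
    model_obj = await store.get_model(requested)
    # Fall back to the requested name so the result is always a plain str.
    return (model_obj.provider_resource_id if model_obj else None) or requested

Factoring the preamble out this way would also deduplicate the repeated fallback comment, though the diff keeps the three call sites self-contained.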