diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index bf11205f6d..9c31159517 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -1,10 +1,11 @@ # What is this? ## Helper utilities -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, List, Optional, Union import httpx from litellm._logging import verbose_logger +from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam if TYPE_CHECKING: from opentelemetry.trace import Span as _Span @@ -53,17 +54,18 @@ def map_finish_reason( return finish_reason -def remove_index_from_tool_calls(messages, tool_calls): - for tool_call in tool_calls: - if "index" in tool_call: - tool_call.pop("index") - - for message in messages: - if "tool_calls" in message: - tool_calls = message["tool_calls"] - for tool_call in tool_calls: - if "index" in tool_call: - tool_call.pop("index") +def remove_index_from_tool_calls( + messages: Optional[List[AllMessageValues]], +): + if messages is not None: + for message in messages: + _tool_calls = message.get("tool_calls") + if _tool_calls is not None and isinstance(_tool_calls, list): + for tool_call in _tool_calls: + if ( + isinstance(tool_call, dict) and "index" in tool_call + ): # Type guard to ensure it's a dict + tool_call.pop("index", None) return diff --git a/litellm/litellm_core_utils/exception_mapping_utils.py b/litellm/litellm_core_utils/exception_mapping_utils.py index 3d898fe15b..07f517fec7 100644 --- a/litellm/litellm_core_utils/exception_mapping_utils.py +++ b/litellm/litellm_core_utils/exception_mapping_utils.py @@ -148,11 +148,10 @@ def exception_type( # type: ignore # noqa: PLR0915 original_exception=original_exception ) try: + error_str = str(original_exception) if model: if hasattr(original_exception, "message"): error_str = str(original_exception.message) - else: - error_str = str(original_exception) if isinstance(original_exception, BaseException): exception_type = type(original_exception).__name__ else: diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index fa8a6cee1d..8e3d7d238e 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -741,6 +741,7 @@ class AnthropicConfig(BaseConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> Dict: if api_key is None: raise litellm.AuthenticationError( diff --git a/litellm/llms/anthropic/completion/transformation.py b/litellm/llms/anthropic/completion/transformation.py index a94bac0383..e2510d6a98 100644 --- a/litellm/llms/anthropic/completion/transformation.py +++ b/litellm/llms/anthropic/completion/transformation.py @@ -85,6 +85,7 @@ class AnthropicTextConfig(BaseConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: if api_key is None: raise ValueError( diff --git a/litellm/llms/azure/chat/gpt_transformation.py b/litellm/llms/azure/chat/gpt_transformation.py index 23353ab0c8..49c3487e65 100644 --- a/litellm/llms/azure/chat/gpt_transformation.py +++ b/litellm/llms/azure/chat/gpt_transformation.py @@ -283,6 +283,7 @@ class AzureOpenAIConfig(BaseConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: raise NotImplementedError( "Azure 
OpenAI has custom logic for validating environment, as it uses the OpenAI SDK." diff --git a/litellm/llms/azure_ai/chat/transformation.py b/litellm/llms/azure_ai/chat/transformation.py index 0523a7e5ef..5c6f004e0e 100644 --- a/litellm/llms/azure_ai/chat/transformation.py +++ b/litellm/llms/azure_ai/chat/transformation.py @@ -1,4 +1,7 @@ -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Tuple, cast + +import httpx +from httpx import Response import litellm from litellm._logging import verbose_logger @@ -6,13 +9,81 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import ( _audio_or_image_in_message_content, convert_content_list_to_str, ) +from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj +from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error from litellm.llms.openai.openai import OpenAIConfig from litellm.secret_managers.main import get_secret_str -from litellm.types.llms.openai import AllMessageValues -from litellm.types.utils import ProviderField +from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam +from litellm.types.utils import ModelResponse, ProviderField +from litellm.utils import _add_path_to_api_base class AzureAIStudioConfig(OpenAIConfig): + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + if api_base and "services.ai.azure.com" in api_base: + headers["api-key"] = api_key + else: + headers["Authorization"] = f"Bearer {api_key}" + + return headers + + def get_complete_url( + self, + api_base: str, + model: str, + optional_params: dict, + stream: Optional[bool] = None, + ) -> str: + """ + Constructs a complete URL for the API request. + + Args: + - api_base: Base URL, e.g., + "https://litellm8397336933.services.ai.azure.com" + OR + "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview" + - model: Model name. + - optional_params: Additional query parameters, including "api_version". + - stream: If streaming is required (optional). 
+ + Returns: + - A complete URL string, e.g., + "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview" + """ + original_url = httpx.URL(api_base) + + # Extract api_version or use default + api_version = cast(Optional[str], optional_params.get("api_version")) + + # Check if 'api-version' is already present + if "api-version" not in original_url.params and api_version: + # Add api_version to optional_params + original_url.params["api-version"] = api_version + + # Add the path to the base URL + if "services.ai.azure.com" in api_base: + new_url = _add_path_to_api_base( + api_base=api_base, ending_path="/models/chat/completions" + ) + else: + new_url = _add_path_to_api_base( + api_base=api_base, ending_path="/chat/completions" + ) + + # Convert optional_params to query parameters + query_params = original_url.params + final_url = httpx.URL(new_url).copy_with(params=query_params) + + return str(final_url) + def get_required_params(self) -> List[ProviderField]: """For a given provider, return it's required fields with a description""" return [ @@ -62,8 +133,6 @@ class AzureAIStudioConfig(OpenAIConfig): ): return True - if api_base and "services.ai.azure" in api_base: - return True except Exception: return False return False @@ -86,3 +155,81 @@ class AzureAIStudioConfig(OpenAIConfig): ) custom_llm_provider = "azure" return api_base, dynamic_api_key, custom_llm_provider + + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, + ) -> dict: + extra_body = optional_params.pop("extra_body", {}) + if extra_body and isinstance(extra_body, dict): + optional_params.update(extra_body) + optional_params.pop("max_retries", None) + return super().transform_request( + model, messages, optional_params, litellm_params, headers + ) + + def transform_response( + self, + model: str, + raw_response: Response, + model_response: ModelResponse, + logging_obj: LiteLLMLoggingObj, + request_data: dict, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + encoding: Any, + api_key: Optional[str] = None, + json_mode: Optional[bool] = None, + ) -> ModelResponse: + model_response.model = f"azure_ai/{model}" + return super().transform_response( + model=model, + raw_response=raw_response, + model_response=model_response, + logging_obj=logging_obj, + request_data=request_data, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=encoding, + api_key=api_key, + json_mode=json_mode, + ) + + def should_retry_llm_api_inside_llm_translation_on_http_error( + self, e: httpx.HTTPStatusError, litellm_params: dict + ) -> bool: + should_drop_params = litellm_params.get("drop_params") or litellm.drop_params + error_text = e.response.text + if should_drop_params and "Extra inputs are not permitted" in error_text: + return True + elif ( + "unknown field: parameter index is not a valid field" in error_text + ): # remove index from tool calls + return True + return super().should_retry_llm_api_inside_llm_translation_on_http_error( + e=e, litellm_params=litellm_params + ) + + @property + def max_retry_on_unprocessable_entity_error(self) -> int: + return 2 + + def transform_request_on_unprocessable_entity_error( + self, e: httpx.HTTPStatusError, request_data: dict + ) -> dict: + _messages = cast(Optional[List[AllMessageValues]], request_data.get("messages")) + if ( + "unknown field: parameter index is not a valid field" in 
e.response.text + and _messages is not None + ): + litellm.remove_index_from_tool_calls( + messages=_messages, + ) + data = drop_params_from_unprocessable_entity_error(e=e, data=request_data) + return data diff --git a/litellm/llms/base_llm/base_model_iterator.py b/litellm/llms/base_llm/base_model_iterator.py index 961941e7e0..67b1466c2a 100644 --- a/litellm/llms/base_llm/base_model_iterator.py +++ b/litellm/llms/base_llm/base_model_iterator.py @@ -1,8 +1,8 @@ import json from abc import abstractmethod -from typing import Optional +from typing import Optional, Union -from litellm.types.utils import GenericStreamingChunk +from litellm.types.utils import GenericStreamingChunk, ModelResponseStream class BaseModelResponseIterator: @@ -13,7 +13,9 @@ class BaseModelResponseIterator: self.response_iterator = self.streaming_response self.json_mode = json_mode - def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: + def chunk_parser( + self, chunk: dict + ) -> Union[GenericStreamingChunk, ModelResponseStream]: return GenericStreamingChunk( text="", is_finished=False, @@ -27,7 +29,9 @@ class BaseModelResponseIterator: def __iter__(self): return self - def _handle_string_chunk(self, str_line: str) -> GenericStreamingChunk: + def _handle_string_chunk( + self, str_line: str + ) -> Union[GenericStreamingChunk, ModelResponseStream]: # chunk is a str at this point if "[DONE]" in str_line: return GenericStreamingChunk( diff --git a/litellm/llms/base_llm/chat/transformation.py b/litellm/llms/base_llm/chat/transformation.py index 363883579b..2d96451239 100644 --- a/litellm/llms/base_llm/chat/transformation.py +++ b/litellm/llms/base_llm/chat/transformation.py @@ -82,6 +82,33 @@ class BaseConfig(ABC): """ return False + def should_retry_llm_api_inside_llm_translation_on_http_error( + self, e: httpx.HTTPStatusError, litellm_params: dict + ) -> bool: + """ + Returns True if the model/provider should retry the LLM API on UnprocessableEntityError + + Overriden by azure ai - where different models support different parameters + """ + return False + + def transform_request_on_unprocessable_entity_error( + self, e: httpx.HTTPStatusError, request_data: dict + ) -> dict: + """ + Transform the request data on UnprocessableEntityError + """ + return request_data + + @property + def max_retry_on_unprocessable_entity_error(self) -> int: + """ + Returns the max retry count for UnprocessableEntityError + + Used if `should_retry_llm_api_inside_llm_translation_on_http_error` is True + """ + return 0 + @abstractmethod def get_supported_openai_params(self, model: str) -> list: pass @@ -104,6 +131,7 @@ class BaseConfig(ABC): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: pass diff --git a/litellm/llms/bedrock/common_utils.py b/litellm/llms/bedrock/common_utils.py index c92845d8b5..531b202f89 100644 --- a/litellm/llms/bedrock/common_utils.py +++ b/litellm/llms/bedrock/common_utils.py @@ -115,6 +115,7 @@ class AmazonInvokeMixin: messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: raise NotImplementedError( "validate_environment not implemented for config. 
Done in invoke_handler.py" diff --git a/litellm/llms/clarifai/chat/transformation.py b/litellm/llms/clarifai/chat/transformation.py index f7ab00ac31..299dd8637c 100644 --- a/litellm/llms/clarifai/chat/transformation.py +++ b/litellm/llms/clarifai/chat/transformation.py @@ -119,6 +119,7 @@ class ClarifaiConfig(BaseConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: headers = { "accept": "application/json", diff --git a/litellm/llms/cloudflare/chat/transformation.py b/litellm/llms/cloudflare/chat/transformation.py index 59ba870de5..ba1e0697ed 100644 --- a/litellm/llms/cloudflare/chat/transformation.py +++ b/litellm/llms/cloudflare/chat/transformation.py @@ -60,6 +60,7 @@ class CloudflareChatConfig(BaseConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: if api_key is None: raise ValueError( diff --git a/litellm/llms/cohere/chat/transformation.py b/litellm/llms/cohere/chat/transformation.py index 464ef1f268..1d68735224 100644 --- a/litellm/llms/cohere/chat/transformation.py +++ b/litellm/llms/cohere/chat/transformation.py @@ -116,6 +116,7 @@ class CohereChatConfig(BaseConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: return cohere_validate_environment( headers=headers, diff --git a/litellm/llms/cohere/completion/transformation.py b/litellm/llms/cohere/completion/transformation.py index 95faa169a5..7c01523571 100644 --- a/litellm/llms/cohere/completion/transformation.py +++ b/litellm/llms/cohere/completion/transformation.py @@ -102,6 +102,7 @@ class CohereTextConfig(BaseConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: return cohere_validate_environment( headers=headers, diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 984d703a4f..c7ba9cd096 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -8,7 +8,7 @@ import litellm import litellm.litellm_core_utils import litellm.types import litellm.types.utils -from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException +from litellm.llms.base_llm.chat.transformation import BaseConfig from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig from litellm.llms.custom_httpx.http_handler import ( @@ -30,6 +30,114 @@ else: class BaseLLMHTTPHandler: + + async def _make_common_async_call( + self, + async_httpx_client: AsyncHTTPHandler, + provider_config: BaseConfig, + api_base: str, + headers: dict, + data: dict, + timeout: Union[float, httpx.Timeout], + litellm_params: dict, + stream: bool = False, + ) -> httpx.Response: + """Common implementation across stream + non-stream calls. 
Meant to ensure consistent error-handling.""" + max_retry_on_unprocessable_entity_error = ( + provider_config.max_retry_on_unprocessable_entity_error + ) + + response: Optional[httpx.Response] = None + for i in range(max(max_retry_on_unprocessable_entity_error, 1)): + try: + response = await async_httpx_client.post( + url=api_base, + headers=headers, + data=json.dumps(data), + timeout=timeout, + stream=stream, + ) + except httpx.HTTPStatusError as e: + hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error + should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error( + e=e, litellm_params=litellm_params + ) + if should_retry and not hit_max_retry: + data = ( + provider_config.transform_request_on_unprocessable_entity_error( + e=e, request_data=data + ) + ) + continue + else: + raise self._handle_error(e=e, provider_config=provider_config) + except Exception as e: + raise self._handle_error(e=e, provider_config=provider_config) + break + + if response is None: + raise provider_config.get_error_class( + error_message="No response from the API", + status_code=422, # don't retry on this error + headers={}, + ) + + return response + + def _make_common_sync_call( + self, + sync_httpx_client: HTTPHandler, + provider_config: BaseConfig, + api_base: str, + headers: dict, + data: dict, + timeout: Union[float, httpx.Timeout], + litellm_params: dict, + stream: bool = False, + ) -> httpx.Response: + + max_retry_on_unprocessable_entity_error = ( + provider_config.max_retry_on_unprocessable_entity_error + ) + + response: Optional[httpx.Response] = None + + for i in range(max(max_retry_on_unprocessable_entity_error, 1)): + try: + response = sync_httpx_client.post( + url=api_base, + headers=headers, + data=json.dumps(data), + timeout=timeout, + stream=stream, + ) + except httpx.HTTPStatusError as e: + hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error + should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error( + e=e, litellm_params=litellm_params + ) + if should_retry and not hit_max_retry: + data = ( + provider_config.transform_request_on_unprocessable_entity_error( + e=e, request_data=data + ) + ) + continue + else: + raise self._handle_error(e=e, provider_config=provider_config) + except Exception as e: + raise self._handle_error(e=e, provider_config=provider_config) + break + + if response is None: + raise provider_config.get_error_class( + error_message="No response from the API", + status_code=422, # don't retry on this error + headers={}, + ) + + return response + async def async_completion( self, custom_llm_provider: str, @@ -55,15 +163,16 @@ class BaseLLMHTTPHandler: else: async_httpx_client = client - try: - response = await async_httpx_client.post( - url=api_base, - headers=headers, - data=json.dumps(data), - timeout=timeout, - ) - except Exception as e: - raise self._handle_error(e=e, provider_config=provider_config) + response = await self._make_common_async_call( + async_httpx_client=async_httpx_client, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=data, + timeout=timeout, + litellm_params=litellm_params, + stream=False, + ) return provider_config.transform_response( model=model, raw_response=response, @@ -93,7 +202,7 @@ class BaseLLMHTTPHandler: stream: Optional[bool] = False, fake_stream: bool = False, api_key: Optional[str] = None, - headers={}, + headers: Optional[dict] = {}, client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ): provider_config = 
ProviderConfigManager.get_provider_chat_config( @@ -102,10 +211,11 @@ class BaseLLMHTTPHandler: # get config from model, custom llm provider headers = provider_config.validate_environment( api_key=api_key, - headers=headers, + headers=headers or {}, model=model, messages=messages, optional_params=optional_params, + api_base=api_base, ) api_base = provider_config.get_complete_url( @@ -154,6 +264,7 @@ class BaseLLMHTTPHandler: if client is not None and isinstance(client, AsyncHTTPHandler) else None ), + litellm_params=litellm_params, ) else: @@ -186,7 +297,7 @@ class BaseLLMHTTPHandler: provider_config=provider_config, api_base=api_base, headers=headers, # type: ignore - data=json.dumps(data), + data=data, model=model, messages=messages, logging_obj=logging_obj, @@ -197,6 +308,7 @@ class BaseLLMHTTPHandler: if client is not None and isinstance(client, HTTPHandler) else None ), + litellm_params=litellm_params, ) return CustomStreamWrapper( completion_stream=completion_stream, @@ -210,19 +322,15 @@ class BaseLLMHTTPHandler: else: sync_httpx_client = client - try: - response = sync_httpx_client.post( - url=api_base, - headers=headers, - data=json.dumps(data), - timeout=timeout, - ) - except Exception as e: - raise self._handle_error( - e=e, - provider_config=provider_config, - ) - + response = self._make_common_sync_call( + sync_httpx_client=sync_httpx_client, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=data, + timeout=timeout, + litellm_params=litellm_params, + ) return provider_config.transform_response( model=model, raw_response=response, @@ -241,44 +349,33 @@ class BaseLLMHTTPHandler: provider_config: BaseConfig, api_base: str, headers: dict, - data: str, + data: dict, model: str, messages: list, logging_obj, - timeout: Optional[Union[float, httpx.Timeout]], + litellm_params: dict, + timeout: Union[float, httpx.Timeout], fake_stream: bool = False, client: Optional[HTTPHandler] = None, - ) -> Tuple[Any, httpx.Headers]: + ) -> Tuple[Any, dict]: if client is None or not isinstance(client, HTTPHandler): sync_httpx_client = _get_httpx_client() else: sync_httpx_client = client - try: - stream = True - if fake_stream is True: - stream = False - response = sync_httpx_client.post( - api_base, headers=headers, data=data, timeout=timeout, stream=stream - ) - except httpx.HTTPStatusError as e: - raise self._handle_error( - e=e, - provider_config=provider_config, - ) - except Exception as e: - for exception in litellm.LITELLM_EXCEPTION_TYPES: - if isinstance(e, exception): - raise e - raise self._handle_error( - e=e, - provider_config=provider_config, - ) + stream = True + if fake_stream is True: + stream = False - if response.status_code != 200: - raise BaseLLMException( - status_code=response.status_code, - message=str(response.read()), - ) + response = self._make_common_sync_call( + sync_httpx_client=sync_httpx_client, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=data, + timeout=timeout, + litellm_params=litellm_params, + stream=stream, + ) if fake_stream is True: completion_stream = provider_config.get_model_response_iterator( @@ -297,7 +394,7 @@ class BaseLLMHTTPHandler: additional_args={"complete_input_dict": data}, ) - return completion_stream, response.headers + return completion_stream, dict(response.headers) async def acompletion_stream_function( self, @@ -310,6 +407,7 @@ class BaseLLMHTTPHandler: timeout: Union[float, httpx.Timeout], logging_obj: LiteLLMLoggingObj, data: dict, + litellm_params: dict, fake_stream: bool 
= False, client: Optional[AsyncHTTPHandler] = None, ): @@ -318,12 +416,13 @@ class BaseLLMHTTPHandler: provider_config=provider_config, api_base=api_base, headers=headers, - data=json.dumps(data), + data=data, messages=messages, logging_obj=logging_obj, timeout=timeout, fake_stream=fake_stream, client=client, + litellm_params=litellm_params, ) streamwrapper = CustomStreamWrapper( completion_stream=completion_stream, @@ -339,10 +438,11 @@ class BaseLLMHTTPHandler: provider_config: BaseConfig, api_base: str, headers: dict, - data: str, + data: dict, messages: list, logging_obj: LiteLLMLoggingObj, - timeout: Optional[Union[float, httpx.Timeout]], + timeout: Union[float, httpx.Timeout], + litellm_params: dict, fake_stream: bool = False, client: Optional[AsyncHTTPHandler] = None, ) -> Tuple[Any, httpx.Headers]: @@ -355,29 +455,18 @@ class BaseLLMHTTPHandler: stream = True if fake_stream is True: stream = False - try: - response = await async_httpx_client.post( - api_base, headers=headers, data=data, stream=stream, timeout=timeout - ) - except httpx.HTTPStatusError as e: - raise self._handle_error( - e=e, - provider_config=provider_config, - ) - except Exception as e: - for exception in litellm.LITELLM_EXCEPTION_TYPES: - if isinstance(e, exception): - raise e - raise self._handle_error( - e=e, - provider_config=provider_config, - ) - if response.status_code != 200: - raise BaseLLMException( - status_code=response.status_code, - message=str(response.read()), - ) + response = await self._make_common_async_call( + async_httpx_client=async_httpx_client, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=data, + timeout=timeout, + litellm_params=litellm_params, + stream=stream, + ) + if fake_stream is True: completion_stream = provider_config.get_model_response_iterator( streaming_response=response.json(), sync_stream=False diff --git a/litellm/llms/deepgram/audio_transcription/transformation.py b/litellm/llms/deepgram/audio_transcription/transformation.py index c5ee148265..5464888bbe 100644 --- a/litellm/llms/deepgram/audio_transcription/transformation.py +++ b/litellm/llms/deepgram/audio_transcription/transformation.py @@ -118,6 +118,7 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: api_key = api_key or get_secret_str("DEEPGRAM_API_KEY") return { diff --git a/litellm/llms/fireworks_ai/common_utils.py b/litellm/llms/fireworks_ai/common_utils.py index ca5d792dac..293403b133 100644 --- a/litellm/llms/fireworks_ai/common_utils.py +++ b/litellm/llms/fireworks_ai/common_utils.py @@ -42,6 +42,7 @@ class FireworksAIMixin: messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: api_key = self._get_api_key(api_key) if api_key is None: diff --git a/litellm/llms/huggingface/chat/handler.py b/litellm/llms/huggingface/chat/handler.py index df3140e104..e9b40be6a7 100644 --- a/litellm/llms/huggingface/chat/handler.py +++ b/litellm/llms/huggingface/chat/handler.py @@ -724,12 +724,14 @@ class Huggingface(BaseLLM): token_logprob = token["logprob"] # Add the token information to the 'token_info' list - _logprob.tokens.append(token_text) - _logprob.token_logprobs.append(token_logprob) + cast(List[str], _logprob.tokens).append(token_text) + cast(List[float], _logprob.token_logprobs).append(token_logprob) # stub this to work with llm eval harness 
top_alt_tokens = {"": -1.0, "": -2.0, "": -3.0} # noqa: F601 - _logprob.top_logprobs.append(top_alt_tokens) + cast(List[Dict[str, float]], _logprob.top_logprobs).append( + top_alt_tokens + ) # For each element in the 'tokens' list, extract the relevant information for i, token in enumerate(response_details["tokens"]): @@ -751,13 +753,15 @@ class Huggingface(BaseLLM): top_alt_tokens[text] = logprob # Add the token information to the 'token_info' list - _logprob.tokens.append(token_text) - _logprob.token_logprobs.append(token_logprob) - _logprob.top_logprobs.append(top_alt_tokens) + cast(List[str], _logprob.tokens).append(token_text) + cast(List[float], _logprob.token_logprobs).append(token_logprob) + cast(List[Dict[str, float]], _logprob.top_logprobs).append( + top_alt_tokens + ) # Add the text offset of the token # This is computed as the sum of the lengths of all previous tokens - _logprob.text_offset.append( + cast(List[int], _logprob.text_offset).append( sum(len(t["text"]) for t in response_details["tokens"][:i]) ) diff --git a/litellm/llms/huggingface/chat/transformation.py b/litellm/llms/huggingface/chat/transformation.py index 2d3fa46caf..2f9824b677 100644 --- a/litellm/llms/huggingface/chat/transformation.py +++ b/litellm/llms/huggingface/chat/transformation.py @@ -356,6 +356,7 @@ class HuggingfaceChatConfig(BaseConfig): messages: List[AllMessageValues], optional_params: Dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> Dict: default_headers = { "content-type": "application/json", diff --git a/litellm/llms/nlp_cloud/chat/transformation.py b/litellm/llms/nlp_cloud/chat/transformation.py index 42bef0f4e8..35ced50242 100644 --- a/litellm/llms/nlp_cloud/chat/transformation.py +++ b/litellm/llms/nlp_cloud/chat/transformation.py @@ -94,6 +94,7 @@ class NLPCloudConfig(BaseConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: headers = { "accept": "application/json", diff --git a/litellm/llms/ollama/completion/transformation.py b/litellm/llms/ollama/completion/transformation.py index 9b4bf48e97..fcd198b01a 100644 --- a/litellm/llms/ollama/completion/transformation.py +++ b/litellm/llms/ollama/completion/transformation.py @@ -347,6 +347,7 @@ class OllamaConfig(BaseConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: return headers diff --git a/litellm/llms/oobabooga/chat/transformation.py b/litellm/llms/oobabooga/chat/transformation.py index 02283f93e2..6fd56f934e 100644 --- a/litellm/llms/oobabooga/chat/transformation.py +++ b/litellm/llms/oobabooga/chat/transformation.py @@ -89,6 +89,7 @@ class OobaboogaConfig(OpenAIGPTConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: headers = { "accept": "application/json", diff --git a/litellm/llms/openai/chat/gpt_transformation.py b/litellm/llms/openai/chat/gpt_transformation.py index 7b732a5557..599250ab6b 100644 --- a/litellm/llms/openai/chat/gpt_transformation.py +++ b/litellm/llms/openai/chat/gpt_transformation.py @@ -181,6 +181,7 @@ class OpenAIGPTConfig(BaseConfig): Returns: dict: The transformed request. Sent as the body of the API call. 
""" + messages = self._transform_messages(messages=messages, model=model) return { "model": model, "messages": messages, @@ -225,5 +226,6 @@ class OpenAIGPTConfig(BaseConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: raise NotImplementedError diff --git a/litellm/llms/openai/common_utils.py b/litellm/llms/openai/common_utils.py index 87857f7ced..98a55b4bd3 100644 --- a/litellm/llms/openai/common_utils.py +++ b/litellm/llms/openai/common_utils.py @@ -45,7 +45,8 @@ class OpenAIError(BaseLLMException): ####### Error Handling Utils for OpenAI API ####################### ################################################################### def drop_params_from_unprocessable_entity_error( - e: openai.UnprocessableEntityError, data: Dict[str, Any] + e: Union[openai.UnprocessableEntityError, httpx.HTTPStatusError], + data: Dict[str, Any], ) -> Dict[str, Any]: """ Helper function to read OpenAI UnprocessableEntityError and drop the params that raised an error from the error message. @@ -58,14 +59,25 @@ def drop_params_from_unprocessable_entity_error( Dict[str, Any]: A new dictionary with invalid parameters removed """ invalid_params: List[str] = [] - if e.body is not None and isinstance(e.body, dict) and e.body.get("message"): - message = e.body.get("message", {}) + if isinstance(e, httpx.HTTPStatusError): + error_json = e.response.json() + error_message = error_json.get("error", {}) + error_body = error_message + else: + error_body = e.body + if ( + error_body is not None + and isinstance(error_body, dict) + and error_body.get("message") + ): + message = error_body.get("message", {}) if isinstance(message, str): try: message = json.loads(message) except json.JSONDecodeError: message = {"detail": message} detail = message.get("detail") + if isinstance(detail, List) and len(detail) > 0 and isinstance(detail[0], dict): for error_dict in detail: if ( @@ -76,4 +88,5 @@ def drop_params_from_unprocessable_entity_error( invalid_params.append(error_dict["loc"][1]) new_data = {k: v for k, v in data.items() if k not in invalid_params} + return new_data diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py index 2ec9037e32..a7ab3a72e0 100644 --- a/litellm/llms/openai/openai.py +++ b/litellm/llms/openai/openai.py @@ -2,9 +2,11 @@ import hashlib import types from typing import ( Any, + AsyncIterator, Callable, Coroutine, Iterable, + Iterator, List, Literal, Optional, @@ -24,10 +26,16 @@ import litellm from litellm import LlmProviders from litellm._logging import verbose_logger from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS -from litellm.types.utils import EmbeddingResponse, ImageResponse, ModelResponse +from litellm.types.utils import ( + EmbeddingResponse, + ImageResponse, + ModelResponse, + ModelResponseStream, +) from litellm.utils import ( CustomStreamWrapper, ProviderConfigManager, @@ -36,7 +44,6 @@ from litellm.utils import ( from ...types.llms.openai import * from ..base import BaseLLM -from .chat.gpt_transformation import OpenAIGPTConfig from .common_utils import OpenAIError, drop_params_from_unprocessable_entity_error @@ -232,6 +239,7 @@ class 
OpenAIConfig(BaseConfig): litellm_params: dict, headers: dict, ) -> dict: + messages = self._transform_messages(messages=messages, model=model) return {"model": model, "messages": messages, **optional_params} def transform_response( @@ -248,10 +256,21 @@ class OpenAIConfig(BaseConfig): api_key: Optional[str] = None, json_mode: Optional[bool] = None, ) -> ModelResponse: - raise NotImplementedError( - "OpenAI handler does this transformation as it uses the OpenAI SDK." + + logging_obj.post_call(original_response=raw_response.text) + logging_obj.model_call_details["response_headers"] = raw_response.headers + final_response_obj = cast( + ModelResponse, + convert_to_model_response_object( + response_object=raw_response.json(), + model_response_object=model_response, + hidden_params={"headers": raw_response.headers}, + _response_headers=dict(raw_response.headers), + ), ) + return final_response_obj + def validate_environment( self, headers: dict, @@ -259,12 +278,37 @@ class OpenAIConfig(BaseConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: - raise NotImplementedError( - "OpenAI handler does this validation as it uses the OpenAI SDK." + return { + "Authorization": f"Bearer {api_key}", + **headers, + } + + def get_model_response_iterator( + self, + streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse], + sync_stream: bool, + json_mode: Optional[bool] = False, + ) -> Any: + return OpenAIChatCompletionResponseIterator( + streaming_response=streaming_response, + sync_stream=sync_stream, + json_mode=json_mode, ) +class OpenAIChatCompletionResponseIterator(BaseModelResponseIterator): + def chunk_parser(self, chunk: dict) -> ModelResponseStream: + """ + {'choices': [{'delta': {'content': '', 'role': 'assistant'}, 'finish_reason': None, 'index': 0, 'logprobs': None}], 'created': 1735763082, 'id': 'a83a2b0fbfaf4aab9c2c93cb8ba346d7', 'model': 'mistral-large', 'object': 'chat.completion.chunk'} + """ + try: + return ModelResponseStream(**chunk) + except Exception as e: + raise e + + class OpenAIChatCompletion(BaseLLM): def __init__(self) -> None: @@ -473,14 +517,6 @@ class OpenAIChatCompletion(BaseLLM): if custom_llm_provider is not None and custom_llm_provider != "openai": model_response.model = f"{custom_llm_provider}/{model}" - if messages is not None and provider_config is not None: - if isinstance(provider_config, OpenAIGPTConfig) or isinstance( - provider_config, OpenAIConfig - ): # [TODO]: remove. no longer needed as .transform_request can just handle this. 
- messages = provider_config._transform_messages( - messages=messages, model=model - ) - for _ in range( 2 ): # if call fails due to alternating messages, retry with reformatted message @@ -647,12 +683,10 @@ class OpenAIChatCompletion(BaseLLM): new_messages = messages new_messages.append({"role": "user", "content": ""}) messages = new_messages - elif ( - "unknown field: parameter index is not a valid field" in str(e) - ) and "tools" in data: - litellm.remove_index_from_tool_calls( - tool_calls=data["tools"], messages=messages - ) + elif "unknown field: parameter index is not a valid field" in str( + e + ): + litellm.remove_index_from_tool_calls(messages=messages) else: raise e except OpenAIError as e: diff --git a/litellm/llms/petals/completion/transformation.py b/litellm/llms/petals/completion/transformation.py index 79792c1f65..dec3f69416 100644 --- a/litellm/llms/petals/completion/transformation.py +++ b/litellm/llms/petals/completion/transformation.py @@ -132,5 +132,6 @@ class PetalsConfig(BaseConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: return {} diff --git a/litellm/llms/predibase/chat/transformation.py b/litellm/llms/predibase/chat/transformation.py index 452c6f8cd5..b9ca0ff693 100644 --- a/litellm/llms/predibase/chat/transformation.py +++ b/litellm/llms/predibase/chat/transformation.py @@ -164,6 +164,7 @@ class PredibaseConfig(BaseConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: if api_key is None: raise ValueError( diff --git a/litellm/llms/replicate/chat/transformation.py b/litellm/llms/replicate/chat/transformation.py index 1e8e2579ef..310193ea66 100644 --- a/litellm/llms/replicate/chat/transformation.py +++ b/litellm/llms/replicate/chat/transformation.py @@ -309,6 +309,7 @@ class ReplicateConfig(BaseConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: headers = { "Authorization": f"Token {api_key}", diff --git a/litellm/llms/sagemaker/completion/transformation.py b/litellm/llms/sagemaker/completion/transformation.py index a2d2c34f9b..4ee4d2ce6a 100644 --- a/litellm/llms/sagemaker/completion/transformation.py +++ b/litellm/llms/sagemaker/completion/transformation.py @@ -260,6 +260,7 @@ class SagemakerConfig(BaseConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: headers = {"Content-Type": "application/json"} diff --git a/litellm/llms/triton/completion/transformation.py b/litellm/llms/triton/completion/transformation.py index 10223453de..0cd6940063 100644 --- a/litellm/llms/triton/completion/transformation.py +++ b/litellm/llms/triton/completion/transformation.py @@ -48,6 +48,7 @@ class TritonConfig(BaseConfig): messages: List[AllMessageValues], optional_params: Dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> Dict: return {"Content-Type": "application/json"} diff --git a/litellm/llms/triton/embedding/transformation.py b/litellm/llms/triton/embedding/transformation.py index 85857a5610..4744ec0834 100644 --- a/litellm/llms/triton/embedding/transformation.py +++ b/litellm/llms/triton/embedding/transformation.py @@ -43,6 +43,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = 
None, ) -> dict: return {} diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py index 0a51870cfe..294c815016 100644 --- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py @@ -808,6 +808,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): messages: List[AllMessageValues], optional_params: Dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> Dict: default_headers = { "Content-Type": "application/json", diff --git a/litellm/llms/voyage/embedding/transformation.py b/litellm/llms/voyage/embedding/transformation.py index 3d969223a5..623dfe73af 100644 --- a/litellm/llms/voyage/embedding/transformation.py +++ b/litellm/llms/voyage/embedding/transformation.py @@ -82,6 +82,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig): messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> dict: if api_key is None: api_key = ( diff --git a/litellm/llms/watsonx/common_utils.py b/litellm/llms/watsonx/common_utils.py index 50fefc4da8..b8340503d3 100644 --- a/litellm/llms/watsonx/common_utils.py +++ b/litellm/llms/watsonx/common_utils.py @@ -166,6 +166,7 @@ class IBMWatsonXMixin: messages: List[AllMessageValues], optional_params: Dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, ) -> Dict: default_headers = { "Content-Type": "application/json", diff --git a/litellm/main.py b/litellm/main.py index 537ee78bea..9f08f9f26c 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1122,6 +1122,7 @@ def completion( # type: ignore # noqa: PLR0915 custom_prompt_dict=custom_prompt_dict, litellm_metadata=kwargs.get("litellm_metadata"), disable_add_transform_inline_image_block=disable_add_transform_inline_image_block, + drop_params=kwargs.get("drop_params"), ) logging.update_environment_variables( model=model, @@ -1347,39 +1348,28 @@ def completion( # type: ignore # noqa: PLR0915 if extra_headers is not None: optional_params["extra_headers"] = extra_headers - ## LOAD CONFIG - if set - config = litellm.AzureAIStudioConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - ## FOR COHERE if "command-r" in model: # make sure tool call in messages are str messages = stringify_json_tool_call_content(messages=messages) ## COMPLETION CALL try: - response = openai_chat_completions.completion( + response = base_llm_http_handler.completion( model=model, messages=messages, headers=headers, model_response=model_response, - print_verbose=print_verbose, api_key=api_key, api_base=api_base, acompletion=acompletion, logging_obj=logging, optional_params=optional_params, litellm_params=litellm_params, - logger_fn=logger_fn, timeout=timeout, # type: ignore - custom_prompt_dict=custom_prompt_dict, client=client, # pass AsyncOpenAI, OpenAI client - organization=organization, custom_llm_provider=custom_llm_provider, - drop_params=non_default_params.get("drop_params"), + encoding=encoding, + stream=stream, ) except Exception as e: ## LOGGING - log the original exception returned diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index a4a8618082..69815102f6 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ 
b/litellm/proxy/_new_secret_config.yaml @@ -29,4 +29,4 @@ litellm_settings: failure_callback: ["langfuse"] langfuse_public_key: os.environ/LANGFUSE_PROJECT3_PUBLIC langfuse_secret: os.environ/LANGFUSE_PROJECT3_SECRET - langfuse_host: os.environ/LANGFUSE_HOST \ No newline at end of file + langfuse_host: os.environ/LANGFUSE_HOST diff --git a/litellm/types/utils.py b/litellm/types/utils.py index dd5d1b5ec8..a1ef3e6e56 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1074,10 +1074,10 @@ class EmbeddingResponse(OpenAIObject): class Logprobs(OpenAIObject): - text_offset: List[int] - token_logprobs: List[Union[float, None]] - tokens: List[str] - top_logprobs: List[Union[Dict[str, float], None]] + text_offset: Optional[List[int]] + token_logprobs: Optional[List[Union[float, None]]] + tokens: Optional[List[str]] + top_logprobs: Optional[List[Union[Dict[str, float], None]]] class TextChoices(OpenAIObject): diff --git a/litellm/utils.py b/litellm/utils.py index 17f5b74405..7e9287dc89 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2002,6 +2002,7 @@ def get_litellm_params( custom_prompt_dict: Optional[dict] = None, litellm_metadata: Optional[dict] = None, disable_add_transform_inline_image_block: Optional[bool] = None, + drop_params: Optional[bool] = None, ): litellm_params = { "acompletion": acompletion, @@ -2035,6 +2036,7 @@ def get_litellm_params( "custom_prompt_dict": custom_prompt_dict, "litellm_metadata": litellm_metadata, "disable_add_transform_inline_image_block": disable_add_transform_inline_image_block, + "drop_params": drop_params, } return litellm_params @@ -6345,3 +6347,44 @@ def extract_duration_from_srt_or_vtt(srt_or_vtt_content: str) -> Optional[float] durations.append(total_seconds) return max(durations) if durations else None + + +import httpx + + +def _add_path_to_api_base(api_base: str, ending_path: str) -> str: + """ + Adds an ending path to an API base URL while preventing duplicate path segments. 
+ + Args: + api_base: Base URL string + ending_path: Path to append to the base URL + + Returns: + Modified URL string with proper path handling + """ + original_url = httpx.URL(api_base) + base_url = original_url.copy_with(params={}) # Removes query params + base_path = original_url.path.rstrip("/") + end_path = ending_path.lstrip("/") + + # Split paths into segments + base_segments = [s for s in base_path.split("/") if s] + end_segments = [s for s in end_path.split("/") if s] + + # Find overlapping segments from the end of base_path and start of ending_path + final_segments = [] + for i in range(len(base_segments)): + if base_segments[i:] == end_segments[: len(base_segments) - i]: + final_segments = base_segments[:i] + end_segments + break + else: + # No overlap found, just combine all segments + final_segments = base_segments + end_segments + + # Construct the new path + modified_path = "/" + "/".join(final_segments) + modified_url = base_url.copy_with(path=modified_path) + + # Re-add the original query parameters + return str(modified_url.copy_with(params=original_url.params)) diff --git a/tests/llm_translation/test_azure_ai.py b/tests/llm_translation/test_azure_ai.py index f765a368f5..e6741c4cbf 100644 --- a/tests/llm_translation/test_azure_ai.py +++ b/tests/llm_translation/test_azure_ai.py @@ -28,6 +28,7 @@ from unittest.mock import MagicMock, patch import pytest import litellm +from litellm import completion @pytest.mark.parametrize( @@ -51,18 +52,13 @@ async def test_azure_ai_with_image_url(): Test that Azure AI studio can handle image_url passed when content is a list containing both text and image_url """ - from openai import AsyncOpenAI + from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler litellm.set_verbose = True - client = AsyncOpenAI( - api_key="fake-api-key", - base_url="https://Phi-3-5-vision-instruct-dcvov.eastus2.models.ai.azure.com", - ) + client = AsyncHTTPHandler() - with patch.object( - client.chat.completions.with_raw_response, "create" - ) as mock_client: + with patch.object(client, "post") as mock_client: try: await litellm.acompletion( model="azure_ai/Phi-3-5-vision-instruct-dcvov", @@ -94,8 +90,9 @@ async def test_azure_ai_with_image_url(): # Verify the request was made mock_client.assert_called_once() + print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}") # Check the request body - request_body = mock_client.call_args.kwargs + request_body = json.loads(mock_client.call_args.kwargs["data"]) assert request_body["model"] == "Phi-3-5-vision-instruct-dcvov" assert request_body["messages"] == [ { @@ -111,3 +108,79 @@ async def test_azure_ai_with_image_url(): ], } ] + + +@pytest.mark.parametrize( + "api_base, expected_url", + [ + ( + "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview", + "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview", + ), + ( + "https://litellm8397336933.services.ai.azure.com/models/chat/completions", + "https://litellm8397336933.services.ai.azure.com/models/chat/completions", + ), + ( + "https://litellm8397336933.services.ai.azure.com/models", + "https://litellm8397336933.services.ai.azure.com/models/chat/completions", + ), + ( + "https://litellm8397336933.services.ai.azure.com", + "https://litellm8397336933.services.ai.azure.com/models/chat/completions", + ), + ], +) +def test_azure_ai_services_handler(api_base, expected_url): + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + 
litellm.set_verbose = True + + client = HTTPHandler() + + with patch.object(client, "post") as mock_client: + try: + response = litellm.completion( + model="azure_ai/Meta-Llama-3.1-70B-Instruct", + messages=[{"role": "user", "content": "Hello, how are you?"}], + api_key="my-fake-api-key", + api_base=api_base, + client=client, + ) + + print(response) + + except Exception as e: + print(f"Error: {e}") + + mock_client.assert_called_once() + assert mock_client.call_args.kwargs["headers"]["api-key"] == "my-fake-api-key" + assert mock_client.call_args.kwargs["url"] == expected_url + + +def test_completion_azure_ai_command_r(): + try: + import os + + litellm.set_verbose = True + + os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "") + os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "") + + response = completion( + model="azure_ai/command-r-plus", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is the meaning of life?"} + ], + } + ], + ) # type: ignore + + assert "azure_ai" in response.model + except litellm.Timeout as e: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py index 0952c53833..4b7c2aabd1 100644 --- a/tests/local_testing/test_completion.py +++ b/tests/local_testing/test_completion.py @@ -132,34 +132,6 @@ def test_null_role_response(): assert response.choices[0].message.role == "assistant" -def test_completion_azure_ai_command_r(): - try: - import os - - litellm.set_verbose = True - - os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "") - os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "") - - response = completion( - model="azure_ai/command-r-plus", - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": "What is the meaning of life?"} - ], - } - ], - ) # type: ignore - - assert "azure_ai" in response.model - except litellm.Timeout as e: - pass - except Exception as e: - pytest.fail(f"Error occurred: {e}") - - @pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.asyncio async def test_completion_azure_ai_mistral_invalid_params(sync_mode): diff --git a/tests/local_testing/test_get_llm_provider.py b/tests/local_testing/test_get_llm_provider.py index c713eaa3c3..99dcaa1ddc 100644 --- a/tests/local_testing/test_get_llm_provider.py +++ b/tests/local_testing/test_get_llm_provider.py @@ -199,4 +199,4 @@ def test_azure_global_standard_get_llm_provider(): api_base="https://my-deployment-francecentral.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview", api_key="fake-api-key", ) - assert custom_llm_provider == "azure" + assert custom_llm_provider == "azure_ai" diff --git a/tests/local_testing/test_streaming.py b/tests/local_testing/test_streaming.py index 702008c4e1..2a2e9d7fd9 100644 --- a/tests/local_testing/test_streaming.py +++ b/tests/local_testing/test_streaming.py @@ -2954,6 +2954,7 @@ def test_azure_streaming_and_function_calling(): async def test_completion_azure_ai_mistral_invalid_params(sync_mode): try: import os + from litellm import stream_chunk_builder litellm.set_verbose = True @@ -2968,15 +2969,21 @@ async def test_completion_azure_ai_mistral_invalid_params(sync_mode): "drop_params": True, "stream": True, } + chunks = [] if sync_mode: - response: litellm.ModelResponse = completion(**data) # type: ignore + response = completion(**data) # type: ignore for chunk in response: print(chunk) + 
chunks.append(chunk) else: - response: litellm.ModelResponse = await litellm.acompletion(**data) # type: ignore + response = await litellm.acompletion(**data) # type: ignore async for chunk in response: print(chunk) + chunks.append(chunk) + print(f"chunks: {chunks}") + response = stream_chunk_builder(chunks=chunks) + assert response.choices[0].message.content is not None except litellm.Timeout as e: pass except Exception as e: diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py index 28cff52e95..76970a7435 100644 --- a/tests/local_testing/test_utils.py +++ b/tests/local_testing/test_utils.py @@ -1252,3 +1252,19 @@ def test_fireworks_ai_document_inlining(): assert supports_pdf_input("fireworks_ai/llama-3.1-8b-instruct") is True assert supports_vision("fireworks_ai/llama-3.1-8b-instruct") is True + + +def test_logprobs_type(): + from litellm.types.utils import Logprobs + + logprobs = { + "text_offset": None, + "token_logprobs": None, + "tokens": None, + "top_logprobs": None, + } + logprobs = Logprobs(**logprobs) + assert logprobs.text_offset is None + assert logprobs.token_logprobs is None + assert logprobs.tokens is None + assert logprobs.top_logprobs is None
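
A minimal standalone sketch of the retry pattern this patch introduces in `_make_common_sync_call` / `_make_common_async_call`: on an UnprocessableEntityError-style failure, the provider config is asked whether the request can be rewritten (e.g. Azure AI rejecting extra params or the `index` field inside tool_calls), and the body is transformed and re-sent before the error is surfaced. The sketch uses plain httpx rather than litellm's HTTPHandler, and the helper name `post_with_translation_retry` plus its callback parameters are illustrative only, not part of the patch; litellm's handler raises `httpx.HTTPStatusError` internally, so the explicit `raise_for_status()` below stands in for that. On the URL side, the new `_add_path_to_api_base` helper deduplicates overlapping path segments, so both "https://litellm8397336933.services.ai.azure.com" and the same base ending in "/models" resolve to ".../models/chat/completions", as asserted in test_azure_ai_services_handler.

# Sketch only -- mirrors the retry-on-422 loop added in llm_http_handler.py,
# with assumed names; not the actual litellm implementation.
import json
from typing import Callable, Dict, Optional

import httpx


def post_with_translation_retry(
    url: str,
    headers: Dict[str, str],
    data: dict,
    should_retry: Callable[[httpx.HTTPStatusError], bool],
    transform_request: Callable[[httpx.HTTPStatusError, dict], dict],
    max_retries: int = 2,  # mirrors max_retry_on_unprocessable_entity_error in AzureAIStudioConfig
) -> httpx.Response:
    response: Optional[httpx.Response] = None
    with httpx.Client() as client:
        for attempt in range(max(max_retries, 1)):
            try:
                response = client.post(url, headers=headers, content=json.dumps(data))
                # litellm's HTTPHandler raises httpx.HTTPStatusError for 4xx/5xx;
                # with raw httpx we trigger the same exception type explicitly.
                response.raise_for_status()
                break
            except httpx.HTTPStatusError as e:
                hit_max_retry = attempt + 1 == max_retries
                if should_retry(e) and not hit_max_retry:
                    # e.g. drop params flagged "Extra inputs are not permitted",
                    # or strip "index" from tool_calls, then try once more.
                    data = transform_request(e, data)
                    continue
                raise
    if response is None:
        raise RuntimeError("No response from the API")
    return response


# Example wiring for the Azure AI behavior exercised in the new tests:
#   should_retry=lambda e: (
#       "Extra inputs are not permitted" in e.response.text
#       or "unknown field: parameter index is not a valid field" in e.response.text
#   )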