Litellm dev 12 30 2024 p2 (#7495)

* test(azure_openai_o1.py): initial commit with tests for the Azure OpenAI o1-preview model

* fix(base_llm_unit_tests.py): handle Azure o1-preview response format tests

skipped for now, since o1 on Azure doesn't support tool calling yet

* fix: initial commit of Azure o1 handler using the OpenAI caller

simplifies calling and lets the fake-streaming logic already implemented for OpenAI work as-is

* feat(azure/o1_handler.py): fake o1 streaming for Azure o1 models

Azure does not currently support streaming for o1

* feat(o1_transformation.py): support overriding 'should_fake_stream' on azure/o1 via a 'supports_native_streaming' param on model info

lets users toggle this on once Azure enables o1 streaming, without waiting for a litellm version bump (a minimal Router sketch follows the commit list below)

* style(router.py): remove 'give feedback/get help' messaging when router is used

Prevents noisy messaging

Closes https://github.com/BerriAI/litellm/issues/5942

* fix(types/utils.py): handle None logprobs

Fixes https://github.com/BerriAI/litellm/issues/328

* fix(exception_mapping_utils.py): fix potentially unbound 'error_str' variable

* refactor(azure_ai/): move to the openai_like chat completion handler

allows easy swapping of API base URLs (e.g. services.ai.azure.com)

Fixes https://github.com/BerriAI/litellm/issues/7275

* refactor(azure_ai/): move to base llm http handler

* fix(azure_ai/): handle differing api endpoints

* fix(azure_ai/): make sure all unit tests are passing

* fix: fix linting errors

* fix: fix linting errors

* fix: fix linting error

* fix: fix linting errors

* fix(azure_ai/transformation.py): handle extra body param

* fix(azure_ai/transformation.py): fix max retries param handling

* fix: fix test

* test(test_azure_o1.py): fix test

* fix(llm_http_handler.py): support handling azure ai unprocessable entity error

* fix(llm_http_handler.py): handle sync invalid param error for azure ai

* fix(azure_ai/): streaming support with base_llm_http_handler

* fix(llm_http_handler.py): working sync stream calls with unprocessable entity handling for azure ai

* fix: fix linting errors

* fix(llm_http_handler.py): fix linting error

* fix(azure_ai/): handle cohere tool call invalid index param error
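
For the 'supports_native_streaming' toggle on azure/o1 mentioned above, a minimal Router sketch (not part of this commit; the exact placement of the flag under "model_info" is an assumption based on the commit description):

# Hypothetical sketch: opt an azure/o1 deployment back into native streaming
# once Azure supports it, instead of litellm's fake (buffered) streaming.
import os
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "azure-o1",
            "litellm_params": {
                "model": "azure/o1-preview",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_base": os.getenv("AZURE_API_BASE"),
                "api_version": os.getenv("AZURE_API_VERSION"),
            },
            # assumption: this flag on model_info disables the fake-stream path
            "model_info": {"supports_native_streaming": True},
        }
    ]
)

response = router.completion(
    model="azure-o1",
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
)
for chunk in response:
    print(chunk)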
Krish Dholakia 2025-01-01 18:57:29 -08:00 committed by GitHub
parent b5e14ef52a
commit b0f570ee16
42 changed files with 638 additions and 192 deletions

View file

@@ -1,10 +1,11 @@
 # What is this?
 ## Helper utilities
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Union
 import httpx
 from litellm._logging import verbose_logger
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -53,17 +54,18 @@ def map_finish_reason(
     return finish_reason
-def remove_index_from_tool_calls(messages, tool_calls):
-    for tool_call in tool_calls:
-        if "index" in tool_call:
-            tool_call.pop("index")
-    for message in messages:
-        if "tool_calls" in message:
-            tool_calls = message["tool_calls"]
-            for tool_call in tool_calls:
-                if "index" in tool_call:
-                    tool_call.pop("index")
+def remove_index_from_tool_calls(
+    messages: Optional[List[AllMessageValues]],
+):
+    if messages is not None:
+        for message in messages:
+            _tool_calls = message.get("tool_calls")
+            if _tool_calls is not None and isinstance(_tool_calls, list):
+                for tool_call in _tool_calls:
+                    if (
+                        isinstance(tool_call, dict) and "index" in tool_call
+                    ):  # Type guard to ensure it's a dict
+                        tool_call.pop("index", None)
     return
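
A quick usage sketch of the reworked helper (not part of the diff above; message shape abbreviated):

# Strips the "index" key that some providers (e.g. Cohere on Azure AI)
# reject inside assistant tool_calls entries. Modifies messages in place.
import litellm

messages = [
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "index": 0,  # rejected with "unknown field: parameter index is not a valid field"
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }
        ],
    }
]

litellm.remove_index_from_tool_calls(messages=messages)
assert "index" not in messages[0]["tool_calls"][0]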

View file

@@ -148,11 +148,10 @@ def exception_type(  # type: ignore  # noqa: PLR0915
             original_exception=original_exception
         )
     try:
+        error_str = str(original_exception)
         if model:
             if hasattr(original_exception, "message"):
                 error_str = str(original_exception.message)
-            else:
-                error_str = str(original_exception)
         if isinstance(original_exception, BaseException):
             exception_type = type(original_exception).__name__
         else:

View file

@@ -741,6 +741,7 @@ class AnthropicConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> Dict:
         if api_key is None:
             raise litellm.AuthenticationError(

View file

@@ -85,6 +85,7 @@ class AnthropicTextConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         if api_key is None:
             raise ValueError(

View file

@@ -283,6 +283,7 @@ class AzureOpenAIConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         raise NotImplementedError(
             "Azure OpenAI has custom logic for validating environment, as it uses the OpenAI SDK."

View file

@@ -1,4 +1,7 @@
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple, cast
+
+import httpx
+from httpx import Response
 import litellm
 from litellm._logging import verbose_logger
@@ -6,13 +9,81 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
     _audio_or_image_in_message_content,
     convert_content_list_to_str,
 )
+from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
+from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error
 from litellm.llms.openai.openai import OpenAIConfig
 from litellm.secret_managers.main import get_secret_str
-from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import ProviderField
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam
+from litellm.types.utils import ModelResponse, ProviderField
+from litellm.utils import _add_path_to_api_base
 class AzureAIStudioConfig(OpenAIConfig):
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        if api_base and "services.ai.azure.com" in api_base:
+            headers["api-key"] = api_key
+        else:
+            headers["Authorization"] = f"Bearer {api_key}"
+        return headers
+
+    def get_complete_url(
+        self,
+        api_base: str,
+        model: str,
+        optional_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        """
+        Constructs a complete URL for the API request.
+
+        Args:
+        - api_base: Base URL, e.g.,
+            "https://litellm8397336933.services.ai.azure.com"
+            OR
+            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
+        - model: Model name.
+        - optional_params: Additional query parameters, including "api_version".
+        - stream: If streaming is required (optional).
+
+        Returns:
+        - A complete URL string, e.g.,
+          "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
+        """
+        original_url = httpx.URL(api_base)
+
+        # Extract api_version or use default
+        api_version = cast(Optional[str], optional_params.get("api_version"))
+
+        # Check if 'api-version' is already present
+        if "api-version" not in original_url.params and api_version:
+            # Add api_version to optional_params
+            original_url.params["api-version"] = api_version
+
+        # Add the path to the base URL
+        if "services.ai.azure.com" in api_base:
+            new_url = _add_path_to_api_base(
+                api_base=api_base, ending_path="/models/chat/completions"
+            )
+        else:
+            new_url = _add_path_to_api_base(
+                api_base=api_base, ending_path="/chat/completions"
+            )
+
+        # Convert optional_params to query parameters
+        query_params = original_url.params
+        final_url = httpx.URL(new_url).copy_with(params=query_params)
+
+        return str(final_url)
+
     def get_required_params(self) -> List[ProviderField]:
         """For a given provider, return it's required fields with a description"""
         return [
@@ -62,8 +133,6 @@ class AzureAIStudioConfig(OpenAIConfig):
         ):
             return True
-        if api_base and "services.ai.azure" in api_base:
-            return True
     except Exception:
         return False
     return False
@@ -86,3 +155,81 @@
             )
             custom_llm_provider = "azure"
         return api_base, dynamic_api_key, custom_llm_provider
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        extra_body = optional_params.pop("extra_body", {})
+        if extra_body and isinstance(extra_body, dict):
+            optional_params.update(extra_body)
+        optional_params.pop("max_retries", None)
+        return super().transform_request(
+            model, messages, optional_params, litellm_params, headers
+        )
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        model_response.model = f"azure_ai/{model}"
+        return super().transform_response(
+            model=model,
+            raw_response=raw_response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            request_data=request_data,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=encoding,
+            api_key=api_key,
+            json_mode=json_mode,
+        )
+
+    def should_retry_llm_api_inside_llm_translation_on_http_error(
+        self, e: httpx.HTTPStatusError, litellm_params: dict
+    ) -> bool:
+        should_drop_params = litellm_params.get("drop_params") or litellm.drop_params
+        error_text = e.response.text
+        if should_drop_params and "Extra inputs are not permitted" in error_text:
+            return True
+        elif (
+            "unknown field: parameter index is not a valid field" in error_text
+        ):  # remove index from tool calls
+            return True
+        return super().should_retry_llm_api_inside_llm_translation_on_http_error(
+            e=e, litellm_params=litellm_params
+        )
+
+    @property
+    def max_retry_on_unprocessable_entity_error(self) -> int:
+        return 2
+
+    def transform_request_on_unprocessable_entity_error(
+        self, e: httpx.HTTPStatusError, request_data: dict
+    ) -> dict:
+        _messages = cast(Optional[List[AllMessageValues]], request_data.get("messages"))
+        if (
+            "unknown field: parameter index is not a valid field" in e.response.text
+            and _messages is not None
+        ):
+            litellm.remove_index_from_tool_calls(
+                messages=_messages,
+            )
+        data = drop_params_from_unprocessable_entity_error(e=e, data=request_data)
+        return data
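
Putting the new validate_environment and get_complete_url behaviour together, a hedged usage sketch (not part of the diff; endpoint and model name are placeholders):

# For api_base hosts containing "services.ai.azure.com", the config sends the
# key as an "api-key" header and appends "/models/chat/completions" to the URL.
import litellm

response = litellm.completion(
    model="azure_ai/Meta-Llama-3.1-70B-Instruct",   # placeholder deployment name
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    api_key="my-azure-ai-key",                       # placeholder key
    api_base="https://my-resource.services.ai.azure.com",
)
print(response.model)  # reported as "azure_ai/<model>" by transform_response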

View file

@@ -1,8 +1,8 @@
 import json
 from abc import abstractmethod
-from typing import Optional
+from typing import Optional, Union
-from litellm.types.utils import GenericStreamingChunk
+from litellm.types.utils import GenericStreamingChunk, ModelResponseStream
 class BaseModelResponseIterator:
@@ -13,7 +13,9 @@ class BaseModelResponseIterator:
         self.response_iterator = self.streaming_response
         self.json_mode = json_mode
-    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+    def chunk_parser(
+        self, chunk: dict
+    ) -> Union[GenericStreamingChunk, ModelResponseStream]:
         return GenericStreamingChunk(
             text="",
             is_finished=False,
@@ -27,7 +29,9 @@ class BaseModelResponseIterator:
     def __iter__(self):
         return self
-    def _handle_string_chunk(self, str_line: str) -> GenericStreamingChunk:
+    def _handle_string_chunk(
+        self, str_line: str
+    ) -> Union[GenericStreamingChunk, ModelResponseStream]:
         # chunk is a str at this point
         if "[DONE]" in str_line:
             return GenericStreamingChunk(

View file

@@ -82,6 +82,33 @@ class BaseConfig(ABC):
         """
         return False
+    def should_retry_llm_api_inside_llm_translation_on_http_error(
+        self, e: httpx.HTTPStatusError, litellm_params: dict
+    ) -> bool:
+        """
+        Returns True if the model/provider should retry the LLM API on UnprocessableEntityError
+
+        Overriden by azure ai - where different models support different parameters
+        """
+        return False
+
+    def transform_request_on_unprocessable_entity_error(
+        self, e: httpx.HTTPStatusError, request_data: dict
+    ) -> dict:
+        """
+        Transform the request data on UnprocessableEntityError
+        """
+        return request_data
+
+    @property
+    def max_retry_on_unprocessable_entity_error(self) -> int:
+        """
+        Returns the max retry count for UnprocessableEntityError
+        Used if `should_retry_llm_api_inside_llm_translation_on_http_error` is True
+        """
+        return 0
+
     @abstractmethod
     def get_supported_openai_params(self, model: str) -> list:
         pass
@@ -104,6 +131,7 @@ class BaseConfig(ABC):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         pass
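
A toy, self-contained sketch of how these hooks are meant to compose with the handler's retry loop (stand-in names only, not the real litellm types; the real loop lives in BaseLLMHTTPHandler):

from typing import Optional


class FakeUnprocessableEntityError(Exception):
    def __init__(self, bad_param: str):
        self.bad_param = bad_param


class ToyProviderConfig:
    # mirrors BaseConfig.max_retry_on_unprocessable_entity_error
    max_retry_on_unprocessable_entity_error: int = 2

    def should_retry_llm_api_inside_llm_translation_on_http_error(self, e) -> bool:
        return True  # e.g. drop_params is enabled

    def transform_request_on_unprocessable_entity_error(self, e, request_data: dict) -> dict:
        request_data.pop(e.bad_param, None)  # drop the rejected parameter
        return request_data


def fake_post(data: dict) -> dict:
    if "max_retries" in data:
        raise FakeUnprocessableEntityError("max_retries")
    return {"ok": True, "echo": data}


def call_with_retries(config: ToyProviderConfig, data: dict) -> Optional[dict]:
    response = None
    for i in range(max(config.max_retry_on_unprocessable_entity_error, 1)):
        try:
            response = fake_post(data)
        except FakeUnprocessableEntityError as e:
            hit_max_retry = i + 1 == config.max_retry_on_unprocessable_entity_error
            if config.should_retry_llm_api_inside_llm_translation_on_http_error(e) and not hit_max_retry:
                data = config.transform_request_on_unprocessable_entity_error(e, data)
                continue
            raise
        break
    return response


print(call_with_retries(ToyProviderConfig(), {"model": "m", "max_retries": 2}))
# -> {'ok': True, 'echo': {'model': 'm'}}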

View file

@@ -115,6 +115,7 @@ class AmazonInvokeMixin:
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         raise NotImplementedError(
             "validate_environment not implemented for config. Done in invoke_handler.py"

View file

@@ -119,6 +119,7 @@ class ClarifaiConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         headers = {
             "accept": "application/json",

View file

@@ -60,6 +60,7 @@ class CloudflareChatConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         if api_key is None:
             raise ValueError(

View file

@@ -116,6 +116,7 @@ class CohereChatConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         return cohere_validate_environment(
             headers=headers,

View file

@@ -102,6 +102,7 @@ class CohereTextConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         return cohere_validate_environment(
             headers=headers,

View file

@@ -8,7 +8,7 @@ import litellm
 import litellm.litellm_core_utils
 import litellm.types
 import litellm.types.utils
-from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
 from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.llms.custom_httpx.http_handler import (
@@ -30,6 +30,114 @@ else:
 class BaseLLMHTTPHandler:
+    async def _make_common_async_call(
+        self,
+        async_httpx_client: AsyncHTTPHandler,
+        provider_config: BaseConfig,
+        api_base: str,
+        headers: dict,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
+        stream: bool = False,
+    ) -> httpx.Response:
+        """Common implementation across stream + non-stream calls. Meant to ensure consistent error-handling."""
+        max_retry_on_unprocessable_entity_error = (
+            provider_config.max_retry_on_unprocessable_entity_error
+        )
+
+        response: Optional[httpx.Response] = None
+        for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+            try:
+                response = await async_httpx_client.post(
+                    url=api_base,
+                    headers=headers,
+                    data=json.dumps(data),
+                    timeout=timeout,
+                    stream=stream,
+                )
+            except httpx.HTTPStatusError as e:
+                hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
+                should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
+                    e=e, litellm_params=litellm_params
+                )
+                if should_retry and not hit_max_retry:
+                    data = (
+                        provider_config.transform_request_on_unprocessable_entity_error(
+                            e=e, request_data=data
+                        )
+                    )
+                    continue
+                else:
+                    raise self._handle_error(e=e, provider_config=provider_config)
+            except Exception as e:
+                raise self._handle_error(e=e, provider_config=provider_config)
+            break
+
+        if response is None:
+            raise provider_config.get_error_class(
+                error_message="No response from the API",
+                status_code=422,  # don't retry on this error
+                headers={},
+            )
+
+        return response
+
+    def _make_common_sync_call(
+        self,
+        sync_httpx_client: HTTPHandler,
+        provider_config: BaseConfig,
+        api_base: str,
+        headers: dict,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
+        stream: bool = False,
+    ) -> httpx.Response:
+        max_retry_on_unprocessable_entity_error = (
+            provider_config.max_retry_on_unprocessable_entity_error
+        )
+
+        response: Optional[httpx.Response] = None
+        for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+            try:
+                response = sync_httpx_client.post(
+                    url=api_base,
+                    headers=headers,
+                    data=json.dumps(data),
+                    timeout=timeout,
+                    stream=stream,
+                )
+            except httpx.HTTPStatusError as e:
+                hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
+                should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
+                    e=e, litellm_params=litellm_params
+                )
+                if should_retry and not hit_max_retry:
+                    data = (
+                        provider_config.transform_request_on_unprocessable_entity_error(
+                            e=e, request_data=data
+                        )
+                    )
+                    continue
+                else:
+                    raise self._handle_error(e=e, provider_config=provider_config)
+            except Exception as e:
+                raise self._handle_error(e=e, provider_config=provider_config)
+            break
+
+        if response is None:
+            raise provider_config.get_error_class(
+                error_message="No response from the API",
+                status_code=422,  # don't retry on this error
+                headers={},
+            )
+
+        return response
+
     async def async_completion(
         self,
         custom_llm_provider: str,
@@ -55,15 +163,16 @@
         else:
             async_httpx_client = client
-        try:
-            response = await async_httpx_client.post(
-                url=api_base,
-                headers=headers,
-                data=json.dumps(data),
-                timeout=timeout,
-            )
-        except Exception as e:
-            raise self._handle_error(e=e, provider_config=provider_config)
+        response = await self._make_common_async_call(
+            async_httpx_client=async_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+            stream=False,
+        )
         return provider_config.transform_response(
             model=model,
             raw_response=response,
@@ -93,7 +202,7 @@
         stream: Optional[bool] = False,
         fake_stream: bool = False,
         api_key: Optional[str] = None,
-        headers={},
+        headers: Optional[dict] = {},
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
         provider_config = ProviderConfigManager.get_provider_chat_config(
@@ -102,10 +211,11 @@
         # get config from model, custom llm provider
         headers = provider_config.validate_environment(
             api_key=api_key,
-            headers=headers,
+            headers=headers or {},
             model=model,
             messages=messages,
             optional_params=optional_params,
+            api_base=api_base,
         )
         api_base = provider_config.get_complete_url(
@@ -154,6 +264,7 @@
                     if client is not None and isinstance(client, AsyncHTTPHandler)
                     else None
                 ),
+                litellm_params=litellm_params,
             )
         else:
@@ -186,7 +297,7 @@
                 provider_config=provider_config,
                 api_base=api_base,
                 headers=headers,  # type: ignore
-                data=json.dumps(data),
+                data=data,
                 model=model,
                 messages=messages,
                 logging_obj=logging_obj,
@@ -197,6 +308,7 @@
                     if client is not None and isinstance(client, HTTPHandler)
                    else None
                 ),
+                litellm_params=litellm_params,
             )
             return CustomStreamWrapper(
                 completion_stream=completion_stream,
@@ -210,19 +322,15 @@
         else:
             sync_httpx_client = client
-        try:
-            response = sync_httpx_client.post(
-                url=api_base,
-                headers=headers,
-                data=json.dumps(data),
-                timeout=timeout,
-            )
-        except Exception as e:
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )
+        response = self._make_common_sync_call(
+            sync_httpx_client=sync_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+        )
         return provider_config.transform_response(
             model=model,
             raw_response=response,
@@ -241,43 +349,32 @@
         provider_config: BaseConfig,
         api_base: str,
         headers: dict,
-        data: str,
+        data: dict,
         model: str,
         messages: list,
         logging_obj,
-        timeout: Optional[Union[float, httpx.Timeout]],
+        litellm_params: dict,
+        timeout: Union[float, httpx.Timeout],
         fake_stream: bool = False,
         client: Optional[HTTPHandler] = None,
-    ) -> Tuple[Any, httpx.Headers]:
+    ) -> Tuple[Any, dict]:
         if client is None or not isinstance(client, HTTPHandler):
             sync_httpx_client = _get_httpx_client()
         else:
             sync_httpx_client = client
-        try:
-            stream = True
-            if fake_stream is True:
-                stream = False
-            response = sync_httpx_client.post(
-                api_base, headers=headers, data=data, timeout=timeout, stream=stream
-            )
-        except httpx.HTTPStatusError as e:
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )
-        except Exception as e:
-            for exception in litellm.LITELLM_EXCEPTION_TYPES:
-                if isinstance(e, exception):
-                    raise e
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )
-        if response.status_code != 200:
-            raise BaseLLMException(
-                status_code=response.status_code,
-                message=str(response.read()),
-            )
+        stream = True
+        if fake_stream is True:
+            stream = False
+
+        response = self._make_common_sync_call(
+            sync_httpx_client=sync_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+            stream=stream,
+        )
         if fake_stream is True:
@@ -297,7 +394,7 @@
             additional_args={"complete_input_dict": data},
         )
-        return completion_stream, response.headers
+        return completion_stream, dict(response.headers)
     async def acompletion_stream_function(
         self,
@@ -310,6 +407,7 @@
         timeout: Union[float, httpx.Timeout],
         logging_obj: LiteLLMLoggingObj,
         data: dict,
+        litellm_params: dict,
         fake_stream: bool = False,
         client: Optional[AsyncHTTPHandler] = None,
     ):
@@ -318,12 +416,13 @@
             provider_config=provider_config,
             api_base=api_base,
             headers=headers,
-            data=json.dumps(data),
+            data=data,
             messages=messages,
             logging_obj=logging_obj,
             timeout=timeout,
             fake_stream=fake_stream,
             client=client,
+            litellm_params=litellm_params,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,
@@ -339,10 +438,11 @@
         provider_config: BaseConfig,
         api_base: str,
         headers: dict,
-        data: str,
+        data: dict,
         messages: list,
         logging_obj: LiteLLMLoggingObj,
-        timeout: Optional[Union[float, httpx.Timeout]],
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
         fake_stream: bool = False,
         client: Optional[AsyncHTTPHandler] = None,
     ) -> Tuple[Any, httpx.Headers]:
@@ -355,29 +455,18 @@
         stream = True
         if fake_stream is True:
             stream = False
-        try:
-            response = await async_httpx_client.post(
-                api_base, headers=headers, data=data, stream=stream, timeout=timeout
-            )
-        except httpx.HTTPStatusError as e:
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )
-        except Exception as e:
-            for exception in litellm.LITELLM_EXCEPTION_TYPES:
-                if isinstance(e, exception):
-                    raise e
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )
-        if response.status_code != 200:
-            raise BaseLLMException(
-                status_code=response.status_code,
-                message=str(response.read()),
-            )
+
+        response = await self._make_common_async_call(
+            async_httpx_client=async_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+            stream=stream,
+        )
         if fake_stream is True:
             completion_stream = provider_config.get_model_response_iterator(
                 streaming_response=response.json(), sync_stream=False

View file

@@ -118,6 +118,7 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         api_key = api_key or get_secret_str("DEEPGRAM_API_KEY")
         return {

View file

@@ -42,6 +42,7 @@ class FireworksAIMixin:
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         api_key = self._get_api_key(api_key)
         if api_key is None:

View file

@@ -724,12 +724,14 @@ class Huggingface(BaseLLM):
                     token_logprob = token["logprob"]
                     # Add the token information to the 'token_info' list
-                    _logprob.tokens.append(token_text)
-                    _logprob.token_logprobs.append(token_logprob)
+                    cast(List[str], _logprob.tokens).append(token_text)
+                    cast(List[float], _logprob.token_logprobs).append(token_logprob)
                     # stub this to work with llm eval harness
                     top_alt_tokens = {"": -1.0, "": -2.0, "": -3.0}  # noqa: F601
-                    _logprob.top_logprobs.append(top_alt_tokens)
+                    cast(List[Dict[str, float]], _logprob.top_logprobs).append(
+                        top_alt_tokens
+                    )
                 # For each element in the 'tokens' list, extract the relevant information
                 for i, token in enumerate(response_details["tokens"]):
@@ -751,13 +753,15 @@
                     top_alt_tokens[text] = logprob
                 # Add the token information to the 'token_info' list
-                _logprob.tokens.append(token_text)
-                _logprob.token_logprobs.append(token_logprob)
-                _logprob.top_logprobs.append(top_alt_tokens)
+                cast(List[str], _logprob.tokens).append(token_text)
+                cast(List[float], _logprob.token_logprobs).append(token_logprob)
+                cast(List[Dict[str, float]], _logprob.top_logprobs).append(
+                    top_alt_tokens
+                )
                 # Add the text offset of the token
                 # This is computed as the sum of the lengths of all previous tokens
-                _logprob.text_offset.append(
+                cast(List[int], _logprob.text_offset).append(
                     sum(len(t["text"]) for t in response_details["tokens"][:i])
                 )

View file

@@ -356,6 +356,7 @@ class HuggingfaceChatConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: Dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> Dict:
         default_headers = {
             "content-type": "application/json",

View file

@@ -94,6 +94,7 @@ class NLPCloudConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         headers = {
             "accept": "application/json",

View file

@@ -347,6 +347,7 @@ class OllamaConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         return headers

View file

@@ -89,6 +89,7 @@ class OobaboogaConfig(OpenAIGPTConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         headers = {
             "accept": "application/json",

View file

@@ -181,6 +181,7 @@ class OpenAIGPTConfig(BaseConfig):
         Returns:
             dict: The transformed request. Sent as the body of the API call.
         """
+        messages = self._transform_messages(messages=messages, model=model)
         return {
             "model": model,
             "messages": messages,
@@ -225,5 +226,6 @@
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         raise NotImplementedError

View file

@@ -45,7 +45,8 @@ class OpenAIError(BaseLLMException):
 ####### Error Handling Utils for OpenAI API #######################
 ###################################################################
 def drop_params_from_unprocessable_entity_error(
-    e: openai.UnprocessableEntityError, data: Dict[str, Any]
+    e: Union[openai.UnprocessableEntityError, httpx.HTTPStatusError],
+    data: Dict[str, Any],
 ) -> Dict[str, Any]:
     """
     Helper function to read OpenAI UnprocessableEntityError and drop the params that raised an error from the error message.
@@ -58,14 +59,25 @@ def drop_params_from_unprocessable_entity_error(
         Dict[str, Any]: A new dictionary with invalid parameters removed
     """
     invalid_params: List[str] = []
-    if e.body is not None and isinstance(e.body, dict) and e.body.get("message"):
-        message = e.body.get("message", {})
+    if isinstance(e, httpx.HTTPStatusError):
+        error_json = e.response.json()
+        error_message = error_json.get("error", {})
+        error_body = error_message
+    else:
+        error_body = e.body
+    if (
+        error_body is not None
+        and isinstance(error_body, dict)
+        and error_body.get("message")
+    ):
+        message = error_body.get("message", {})
         if isinstance(message, str):
             try:
                 message = json.loads(message)
             except json.JSONDecodeError:
                 message = {"detail": message}
         detail = message.get("detail")
+
         if isinstance(detail, List) and len(detail) > 0 and isinstance(detail[0], dict):
             for error_dict in detail:
                 if (
@@ -76,4 +88,5 @@ def drop_params_from_unprocessable_entity_error(
                     invalid_params.append(error_dict["loc"][1])
     new_data = {k: v for k, v in data.items() if k not in invalid_params}
+
     return new_data
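
A hedged sketch of feeding an httpx 422 error through the updated helper (not part of the diff; the error-body shape below — an "error.message" string wrapping a FastAPI-style "detail" list — is an assumption based on the parsing logic above):

import json
import httpx
from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error

detail = [{"loc": ["body", "max_retries"], "msg": "Extra inputs are not permitted"}]
request = httpx.Request("POST", "https://example.services.ai.azure.com/models/chat/completions")
response = httpx.Response(
    422,
    json={"error": {"message": json.dumps({"detail": detail})}},
    request=request,
)
err = httpx.HTTPStatusError("422 Unprocessable Entity", request=request, response=response)

data = {"model": "my-model", "messages": [{"role": "user", "content": "hi"}], "max_retries": 2}
cleaned = drop_params_from_unprocessable_entity_error(e=err, data=data)
print(cleaned)  # expected: "max_retries" dropped from the request body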

View file

@@ -2,9 +2,11 @@ import hashlib
 import types
 from typing import (
     Any,
+    AsyncIterator,
     Callable,
     Coroutine,
     Iterable,
+    Iterator,
     List,
     Literal,
     Optional,
@@ -24,10 +26,16 @@ import litellm
 from litellm import LlmProviders
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
 from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
 from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
-from litellm.types.utils import EmbeddingResponse, ImageResponse, ModelResponse
+from litellm.types.utils import (
+    EmbeddingResponse,
+    ImageResponse,
+    ModelResponse,
+    ModelResponseStream,
+)
 from litellm.utils import (
     CustomStreamWrapper,
     ProviderConfigManager,
@@ -36,7 +44,6 @@ from litellm.utils import (
 from ...types.llms.openai import *
 from ..base import BaseLLM
-from .chat.gpt_transformation import OpenAIGPTConfig
 from .common_utils import OpenAIError, drop_params_from_unprocessable_entity_error
@@ -232,6 +239,7 @@ class OpenAIConfig(BaseConfig):
         litellm_params: dict,
         headers: dict,
     ) -> dict:
+        messages = self._transform_messages(messages=messages, model=model)
         return {"model": model, "messages": messages, **optional_params}
     def transform_response(
@@ -248,10 +256,21 @@ class OpenAIConfig(BaseConfig):
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
     ) -> ModelResponse:
-        raise NotImplementedError(
-            "OpenAI handler does this transformation as it uses the OpenAI SDK."
-        )
+        logging_obj.post_call(original_response=raw_response.text)
+        logging_obj.model_call_details["response_headers"] = raw_response.headers
+        final_response_obj = cast(
+            ModelResponse,
+            convert_to_model_response_object(
+                response_object=raw_response.json(),
+                model_response_object=model_response,
+                hidden_params={"headers": raw_response.headers},
+                _response_headers=dict(raw_response.headers),
+            ),
+        )
+
+        return final_response_obj
     def validate_environment(
         self,
         headers: dict,
@@ -259,12 +278,37 @@
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
-        raise NotImplementedError(
-            "OpenAI handler does this validation as it uses the OpenAI SDK."
-        )
+        return {
+            "Authorization": f"Bearer {api_key}",
+            **headers,
+        }
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ) -> Any:
+        return OpenAIChatCompletionResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
+
+
+class OpenAIChatCompletionResponseIterator(BaseModelResponseIterator):
+    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
+        """
+        {'choices': [{'delta': {'content': '', 'role': 'assistant'}, 'finish_reason': None, 'index': 0, 'logprobs': None}], 'created': 1735763082, 'id': 'a83a2b0fbfaf4aab9c2c93cb8ba346d7', 'model': 'mistral-large', 'object': 'chat.completion.chunk'}
+        """
+        try:
+            return ModelResponseStream(**chunk)
+        except Exception as e:
+            raise e
 class OpenAIChatCompletion(BaseLLM):
     def __init__(self) -> None:
@@ -473,14 +517,6 @@ class OpenAIChatCompletion(BaseLLM):
             if custom_llm_provider is not None and custom_llm_provider != "openai":
                 model_response.model = f"{custom_llm_provider}/{model}"
-            if messages is not None and provider_config is not None:
-                if isinstance(provider_config, OpenAIGPTConfig) or isinstance(
-                    provider_config, OpenAIConfig
-                ):  # [TODO]: remove. no longer needed as .transform_request can just handle this.
-                    messages = provider_config._transform_messages(
-                        messages=messages, model=model
-                    )
-
             for _ in range(
                 2
             ):  # if call fails due to alternating messages, retry with reformatted message
@@ -647,12 +683,10 @@
                         new_messages = messages
                         new_messages.append({"role": "user", "content": ""})
                         messages = new_messages
-                    elif (
-                        "unknown field: parameter index is not a valid field" in str(e)
-                    ) and "tools" in data:
-                        litellm.remove_index_from_tool_calls(
-                            tool_calls=data["tools"], messages=messages
-                        )
+                    elif "unknown field: parameter index is not a valid field" in str(
+                        e
+                    ):
+                        litellm.remove_index_from_tool_calls(messages=messages)
                     else:
                         raise e
             except OpenAIError as e:

View file

@@ -132,5 +132,6 @@ class PetalsConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         return {}

View file

@@ -164,6 +164,7 @@ class PredibaseConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         if api_key is None:
             raise ValueError(

View file

@@ -309,6 +309,7 @@ class ReplicateConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         headers = {
             "Authorization": f"Token {api_key}",

View file

@@ -260,6 +260,7 @@ class SagemakerConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         headers = {"Content-Type": "application/json"}

View file

@@ -48,6 +48,7 @@ class TritonConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: Dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> Dict:
         return {"Content-Type": "application/json"}

View file

@@ -43,6 +43,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         return {}

View file

@@ -808,6 +808,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
         messages: List[AllMessageValues],
         optional_params: Dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> Dict:
         default_headers = {
             "Content-Type": "application/json",

View file

@@ -82,6 +82,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         if api_key is None:
             api_key = (

View file

@@ -166,6 +166,7 @@ class IBMWatsonXMixin:
         messages: List[AllMessageValues],
         optional_params: Dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> Dict:
         default_headers = {
             "Content-Type": "application/json",

View file

@@ -1122,6 +1122,7 @@ def completion(  # type: ignore  # noqa: PLR0915
         custom_prompt_dict=custom_prompt_dict,
         litellm_metadata=kwargs.get("litellm_metadata"),
         disable_add_transform_inline_image_block=disable_add_transform_inline_image_block,
+        drop_params=kwargs.get("drop_params"),
     )
     logging.update_environment_variables(
         model=model,
@@ -1347,39 +1348,28 @@
             if extra_headers is not None:
                 optional_params["extra_headers"] = extra_headers
-            ## LOAD CONFIG - if set
-            config = litellm.AzureAIStudioConfig.get_config()
-            for k, v in config.items():
-                if (
-                    k not in optional_params
-                ):  # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
-                    optional_params[k] = v
-
             ## FOR COHERE
             if "command-r" in model:  # make sure tool call in messages are str
                 messages = stringify_json_tool_call_content(messages=messages)
             ## COMPLETION CALL
             try:
-                response = openai_chat_completions.completion(
+                response = base_llm_http_handler.completion(
                     model=model,
                     messages=messages,
                     headers=headers,
                     model_response=model_response,
-                    print_verbose=print_verbose,
                     api_key=api_key,
                     api_base=api_base,
                     acompletion=acompletion,
                     logging_obj=logging,
                     optional_params=optional_params,
                     litellm_params=litellm_params,
-                    logger_fn=logger_fn,
                     timeout=timeout,  # type: ignore
-                    custom_prompt_dict=custom_prompt_dict,
                     client=client,  # pass AsyncOpenAI, OpenAI client
-                    organization=organization,
                     custom_llm_provider=custom_llm_provider,
-                    drop_params=non_default_params.get("drop_params"),
+                    encoding=encoding,
+                    stream=stream,
                 )
             except Exception as e:
                 ## LOGGING - log the original exception returned

View file

@@ -1074,10 +1074,10 @@ class EmbeddingResponse(OpenAIObject):
 class Logprobs(OpenAIObject):
-    text_offset: List[int]
-    token_logprobs: List[Union[float, None]]
-    tokens: List[str]
-    top_logprobs: List[Union[Dict[str, float], None]]
+    text_offset: Optional[List[int]]
+    token_logprobs: Optional[List[Union[float, None]]]
+    tokens: Optional[List[str]]
+    top_logprobs: Optional[List[Union[Dict[str, float], None]]]
 class TextChoices(OpenAIObject):

View file

@@ -2002,6 +2002,7 @@ def get_litellm_params(
     custom_prompt_dict: Optional[dict] = None,
     litellm_metadata: Optional[dict] = None,
     disable_add_transform_inline_image_block: Optional[bool] = None,
+    drop_params: Optional[bool] = None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2035,6 +2036,7 @@
         "custom_prompt_dict": custom_prompt_dict,
         "litellm_metadata": litellm_metadata,
         "disable_add_transform_inline_image_block": disable_add_transform_inline_image_block,
+        "drop_params": drop_params,
     }
     return litellm_params
@@ -6345,3 +6347,44 @@ def extract_duration_from_srt_or_vtt(srt_or_vtt_content: str) -> Optional[float]
             durations.append(total_seconds)
     return max(durations) if durations else None
+
+
+import httpx
+
+
+def _add_path_to_api_base(api_base: str, ending_path: str) -> str:
+    """
+    Adds an ending path to an API base URL while preventing duplicate path segments.
+
+    Args:
+        api_base: Base URL string
+        ending_path: Path to append to the base URL
+
+    Returns:
+        Modified URL string with proper path handling
+    """
+    original_url = httpx.URL(api_base)
+    base_url = original_url.copy_with(params={})  # Removes query params
+    base_path = original_url.path.rstrip("/")
+    end_path = ending_path.lstrip("/")
+
+    # Split paths into segments
+    base_segments = [s for s in base_path.split("/") if s]
+    end_segments = [s for s in end_path.split("/") if s]
+
+    # Find overlapping segments from the end of base_path and start of ending_path
+    final_segments = []
+    for i in range(len(base_segments)):
+        if base_segments[i:] == end_segments[: len(base_segments) - i]:
+            final_segments = base_segments[:i] + end_segments
+            break
+    else:
+        # No overlap found, just combine all segments
+        final_segments = base_segments + end_segments
+
+    # Construct the new path
+    modified_path = "/" + "/".join(final_segments)
+    modified_url = base_url.copy_with(path=modified_path)
+
+    # Re-add the original query parameters
+    return str(modified_url.copy_with(params=original_url.params))
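
A quick usage sketch of the new helper (not part of the diff; mirrors the expected URLs in the azure_ai test further down):

# The shared "models" segment is not duplicated when it already ends the base URL.
from litellm.utils import _add_path_to_api_base

print(_add_path_to_api_base(
    api_base="https://litellm8397336933.services.ai.azure.com",
    ending_path="/models/chat/completions",
))
# https://litellm8397336933.services.ai.azure.com/models/chat/completions

print(_add_path_to_api_base(
    api_base="https://litellm8397336933.services.ai.azure.com/models",
    ending_path="/models/chat/completions",
))
# https://litellm8397336933.services.ai.azure.com/models/chat/completions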

View file

@@ -28,6 +28,7 @@ from unittest.mock import MagicMock, patch
 import pytest
 import litellm
+from litellm import completion
 @pytest.mark.parametrize(
@@ -51,18 +52,13 @@ async def test_azure_ai_with_image_url():
     Test that Azure AI studio can handle image_url passed when content is a list containing both text and image_url
     """
-    from openai import AsyncOpenAI
+    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
     litellm.set_verbose = True
-    client = AsyncOpenAI(
-        api_key="fake-api-key",
-        base_url="https://Phi-3-5-vision-instruct-dcvov.eastus2.models.ai.azure.com",
-    )
+    client = AsyncHTTPHandler()
-    with patch.object(
-        client.chat.completions.with_raw_response, "create"
-    ) as mock_client:
+    with patch.object(client, "post") as mock_client:
         try:
             await litellm.acompletion(
                 model="azure_ai/Phi-3-5-vision-instruct-dcvov",
@@ -94,8 +90,9 @@ async def test_azure_ai_with_image_url():
         # Verify the request was made
         mock_client.assert_called_once()
+        print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
         # Check the request body
-        request_body = mock_client.call_args.kwargs
+        request_body = json.loads(mock_client.call_args.kwargs["data"])
         assert request_body["model"] == "Phi-3-5-vision-instruct-dcvov"
         assert request_body["messages"] == [
             {
@@ -111,3 +108,79 @@ async def test_azure_ai_with_image_url():
                 ],
             }
         ]
+
+
+@pytest.mark.parametrize(
+    "api_base, expected_url",
+    [
+        (
+            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
+            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
+        ),
+        (
+            "https://litellm8397336933.services.ai.azure.com/models/chat/completions",
+            "https://litellm8397336933.services.ai.azure.com/models/chat/completions",
+        ),
+        (
+            "https://litellm8397336933.services.ai.azure.com/models",
+            "https://litellm8397336933.services.ai.azure.com/models/chat/completions",
+        ),
+        (
+            "https://litellm8397336933.services.ai.azure.com",
+            "https://litellm8397336933.services.ai.azure.com/models/chat/completions",
+        ),
+    ],
+)
+def test_azure_ai_services_handler(api_base, expected_url):
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    litellm.set_verbose = True
+    client = HTTPHandler()
+
+    with patch.object(client, "post") as mock_client:
+        try:
+            response = litellm.completion(
+                model="azure_ai/Meta-Llama-3.1-70B-Instruct",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                api_key="my-fake-api-key",
+                api_base=api_base,
+                client=client,
+            )
+            print(response)
+        except Exception as e:
+            print(f"Error: {e}")
+
+        mock_client.assert_called_once()
+        assert mock_client.call_args.kwargs["headers"]["api-key"] == "my-fake-api-key"
+        assert mock_client.call_args.kwargs["url"] == expected_url
+
+
+def test_completion_azure_ai_command_r():
+    try:
+        import os
+
+        litellm.set_verbose = True
+
+        os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
+        os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")
+
+        response = completion(
+            model="azure_ai/command-r-plus",
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is the meaning of life?"}
+                    ],
+                }
+            ],
+        )  # type: ignore
+
+        assert "azure_ai" in response.model
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")

View file

@@ -132,34 +132,6 @@ def test_null_role_response():
     assert response.choices[0].message.role == "assistant"
-def test_completion_azure_ai_command_r():
-    try:
-        import os
-
-        litellm.set_verbose = True
-
-        os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
-        os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")
-
-        response = completion(
-            model="azure_ai/command-r-plus",
-            messages=[
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": "What is the meaning of life?"}
-                    ],
-                }
-            ],
-        )  # type: ignore
-
-        assert "azure_ai" in response.model
-    except litellm.Timeout as e:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
 async def test_completion_azure_ai_mistral_invalid_params(sync_mode):

View file

@@ -199,4 +199,4 @@ def test_azure_global_standard_get_llm_provider():
         api_base="https://my-deployment-francecentral.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
         api_key="fake-api-key",
     )
-    assert custom_llm_provider == "azure"
+    assert custom_llm_provider == "azure_ai"

View file

@@ -2954,6 +2954,7 @@ def test_azure_streaming_and_function_calling():
 async def test_completion_azure_ai_mistral_invalid_params(sync_mode):
     try:
         import os
+        from litellm import stream_chunk_builder
         litellm.set_verbose = True
@@ -2968,15 +2969,21 @@ async def test_completion_azure_ai_mistral_invalid_params(sync_mode):
             "drop_params": True,
             "stream": True,
         }
+        chunks = []
         if sync_mode:
-            response: litellm.ModelResponse = completion(**data)  # type: ignore
+            response = completion(**data)  # type: ignore
             for chunk in response:
                 print(chunk)
+                chunks.append(chunk)
         else:
-            response: litellm.ModelResponse = await litellm.acompletion(**data)  # type: ignore
+            response = await litellm.acompletion(**data)  # type: ignore
             async for chunk in response:
                 print(chunk)
+                chunks.append(chunk)
+
+        print(f"chunks: {chunks}")
+        response = stream_chunk_builder(chunks=chunks)
+        assert response.choices[0].message.content is not None
     except litellm.Timeout as e:
         pass
     except Exception as e:

View file

@@ -1252,3 +1252,19 @@ def test_fireworks_ai_document_inlining():
     assert supports_pdf_input("fireworks_ai/llama-3.1-8b-instruct") is True
     assert supports_vision("fireworks_ai/llama-3.1-8b-instruct") is True
+
+
+def test_logprobs_type():
+    from litellm.types.utils import Logprobs
+
+    logprobs = {
+        "text_offset": None,
+        "token_logprobs": None,
+        "tokens": None,
+        "top_logprobs": None,
+    }
+    logprobs = Logprobs(**logprobs)
+    assert logprobs.text_offset is None
+    assert logprobs.token_logprobs is None
+    assert logprobs.tokens is None
+    assert logprobs.top_logprobs is None