Litellm dev 12 30 2024 p2 (#7495)

* test(azure_openai_o1.py): initial commit with testing for azure openai o1 preview model

* fix(base_llm_unit_tests.py): handle azure o1 preview response format tests

skip, as o1 on azure doesn't support tool calling yet

* fix: initial commit of azure o1 handler using openai caller

simplifies calling + allows the fake streaming logic already implemented for openai to just work

* feat(azure/o1_handler.py): fake o1 streaming for azure o1 models

azure does not currently support streaming for o1
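Since azure returns o1 output in one shot, the handler fetches the complete response and re-emits it as stream chunks. A minimal sketch of that idea (the chunk shape and the fake_stream_chunks name are illustrative, not litellm's actual implementation):

from typing import Iterator

def fake_stream_chunks(full_text: str, chunk_size: int = 20) -> Iterator[dict]:
    # Illustrative only: turn an already-complete response into
    # OpenAI-style streaming chunks so callers can iterate over it
    # as if the provider had streamed natively.
    for i in range(0, len(full_text), chunk_size):
        yield {
            "choices": [
                {"index": 0, "delta": {"content": full_text[i : i + chunk_size]}}
            ]
        }
    # a final empty delta mirrors a real stream's terminating chunk
    yield {"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}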

* feat(o1_transformation.py): support overriding 'should_fake_stream' on azure/o1 via 'supports_native_streaming' param on model info

enables users to toggle this on once azure allows o1 streaming, without needing to bump versions
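For illustration, a hedged sketch of a router model entry with this flag; only 'supports_native_streaming' comes from this change, the endpoint and key values below are placeholders:

model_list = [
    {
        "model_name": "o1",
        "litellm_params": {
            "model": "azure/o1",
            "api_base": "https://my-endpoint.openai.azure.com/",  # placeholder
            "api_key": "os.environ/AZURE_API_KEY",  # placeholder
        },
        # flip to True once azure ships native o1 streaming, so litellm
        # stops faking the stream without a library version bump
        "model_info": {"supports_native_streaming": False},
    }
]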

* style(router.py): remove 'give feedback/get help' messaging when router is used

Prevents noisy messaging

Closes https://github.com/BerriAI/litellm/issues/5942

* fix(types/utils.py): handle none logprobs

Fixes https://github.com/BerriAI/litellm/issues/328
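The shape of the fix, sketched with pydantic (the class and field names here are illustrative, not necessarily litellm's exact type definitions): the response types must tolerate a null logprobs instead of assuming a dict.

from typing import List, Optional
from pydantic import BaseModel

class ChoiceLogprobs(BaseModel):
    content: Optional[List[dict]] = None

class Choice(BaseModel):
    index: int = 0
    # providers may omit logprobs or return an explicit null, so the
    # field must default to None rather than assume a populated object
    logprobs: Optional[ChoiceLogprobs] = None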

* fix(exception_mapping_utils.py): fix 'error str' unbound variable error

* refactor(azure_ai/): move to openai_like chat completion handler

allows for easy swapping of api base URLs (e.g. ai.services.com)

Fixes https://github.com/BerriAI/litellm/issues/7275
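The intent, sketched with litellm's public completion call (the deployment name and host below are placeholders): the same azure_ai/ route should work regardless of which host serves the deployment, since only api_base changes.

import litellm

response = litellm.completion(
    model="azure_ai/my-deployment",  # placeholder deployment name
    api_base="https://my-resource.services.ai.azure.com/models",  # placeholder host
    api_key="...",
    messages=[{"role": "user", "content": "hello"}],
)
print(response.choices[0].message.content)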

* refactor(azure_ai/): move to base llm http handler

* fix(azure_ai/): handle differing api endpoints

* fix(azure_ai/): make sure all unit tests are passing

* fix: fix linting errors

* fix: fix linting errors

* fix: fix linting error

* fix: fix linting errors

* fix(azure_ai/transformation.py): handle extra body param
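A hedged example of what the extra-body path enables (safe_mode is a made-up provider-specific field): anything passed via extra_body should be merged into the outgoing request body rather than dropped.

import litellm

response = litellm.completion(
    model="azure_ai/my-deployment",  # placeholder
    messages=[{"role": "user", "content": "hi"}],
    extra_body={"safe_mode": True},  # hypothetical provider-specific field
)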

* fix(azure_ai/transformation.py): fix max retries param handling

* fix: fix test

* test(test_azure_o1.py): fix test

* fix(llm_http_handler.py): support handling azure ai unprocessable entity error

* fix(llm_http_handler.py): handle sync invalid param error for azure ai
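Condensed sketch of the pattern the diff below adds (the function name and 422 check here are illustrative): on an unprocessable-entity error, let the provider config rewrite the request once and retry, otherwise surface the error.

import json
import httpx

def post_with_422_retry(client, provider_config, url, headers, data, max_retries=1):
    # illustrative condensation of the retry loop in _make_common_sync_call;
    # assumes client.post raises httpx.HTTPStatusError on non-2xx responses
    for attempt in range(max(max_retries, 1)):
        try:
            return client.post(url=url, headers=headers, data=json.dumps(data))
        except httpx.HTTPStatusError as e:
            hit_max_retry = attempt + 1 == max_retries
            if e.response.status_code == 422 and not hit_max_retry:
                # e.g. drop params the provider rejects, then try again
                data = provider_config.transform_request_on_unprocessable_entity_error(
                    e=e, request_data=data
                )
                continue
            raise
    raise RuntimeError("no response from the API")  # mirrors the handler's guard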

* fix(azure_ai/): streaming support with base_llm_http_handler

* fix(llm_http_handler.py): working sync stream calls with unprocessable entity handling for azure ai

* fix: fix linting errors

* fix(llm_http_handler.py): fix linting error

* fix(azure_ai/): handle cohere tool call invalid index param error
Krish Dholakia, 2025-01-01 18:57:29 -08:00, committed by GitHub
parent 0f1b298fe0, commit 0120176541
42 changed files with 638 additions and 192 deletions


@@ -8,7 +8,7 @@ import litellm
import litellm.litellm_core_utils
import litellm.types
import litellm.types.utils
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.llms.custom_httpx.http_handler import (
@@ -30,6 +30,114 @@ else:
class BaseLLMHTTPHandler:
async def _make_common_async_call(
self,
async_httpx_client: AsyncHTTPHandler,
provider_config: BaseConfig,
api_base: str,
headers: dict,
data: dict,
timeout: Union[float, httpx.Timeout],
litellm_params: dict,
stream: bool = False,
) -> httpx.Response:
"""Common implementation across stream + non-stream calls. Meant to ensure consistent error-handling."""
max_retry_on_unprocessable_entity_error = (
provider_config.max_retry_on_unprocessable_entity_error
)
response: Optional[httpx.Response] = None
for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
try:
response = await async_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout,
stream=stream,
)
except httpx.HTTPStatusError as e:
hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
e=e, litellm_params=litellm_params
)
if should_retry and not hit_max_retry:
data = (
provider_config.transform_request_on_unprocessable_entity_error(
e=e, request_data=data
)
)
continue
else:
raise self._handle_error(e=e, provider_config=provider_config)
except Exception as e:
raise self._handle_error(e=e, provider_config=provider_config)
break
if response is None:
raise provider_config.get_error_class(
error_message="No response from the API",
status_code=422, # don't retry on this error
headers={},
)
return response
def _make_common_sync_call(
self,
sync_httpx_client: HTTPHandler,
provider_config: BaseConfig,
api_base: str,
headers: dict,
data: dict,
timeout: Union[float, httpx.Timeout],
litellm_params: dict,
stream: bool = False,
) -> httpx.Response:
max_retry_on_unprocessable_entity_error = (
provider_config.max_retry_on_unprocessable_entity_error
)
response: Optional[httpx.Response] = None
for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
try:
response = sync_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout,
stream=stream,
)
except httpx.HTTPStatusError as e:
hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
e=e, litellm_params=litellm_params
)
if should_retry and not hit_max_retry:
data = (
provider_config.transform_request_on_unprocessable_entity_error(
e=e, request_data=data
)
)
continue
else:
raise self._handle_error(e=e, provider_config=provider_config)
except Exception as e:
raise self._handle_error(e=e, provider_config=provider_config)
break
if response is None:
raise provider_config.get_error_class(
error_message="No response from the API",
status_code=422, # don't retry on this error
headers={},
)
return response
async def async_completion(
self,
custom_llm_provider: str,
@@ -55,15 +163,16 @@ class BaseLLMHTTPHandler:
else:
async_httpx_client = client
try:
response = await async_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout,
)
except Exception as e:
raise self._handle_error(e=e, provider_config=provider_config)
response = await self._make_common_async_call(
async_httpx_client=async_httpx_client,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=data,
timeout=timeout,
litellm_params=litellm_params,
stream=False,
)
return provider_config.transform_response(
model=model,
raw_response=response,
@@ -93,7 +202,7 @@ class BaseLLMHTTPHandler:
stream: Optional[bool] = False,
fake_stream: bool = False,
api_key: Optional[str] = None,
headers={},
headers: Optional[dict] = {},
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
provider_config = ProviderConfigManager.get_provider_chat_config(
@@ -102,10 +211,11 @@
# get config from model, custom llm provider
headers = provider_config.validate_environment(
api_key=api_key,
headers=headers,
headers=headers or {},
model=model,
messages=messages,
optional_params=optional_params,
api_base=api_base,
)
api_base = provider_config.get_complete_url(
@@ -154,6 +264,7 @@
if client is not None and isinstance(client, AsyncHTTPHandler)
else None
),
litellm_params=litellm_params,
)
else:
@@ -186,7 +297,7 @@
provider_config=provider_config,
api_base=api_base,
headers=headers, # type: ignore
data=json.dumps(data),
data=data,
model=model,
messages=messages,
logging_obj=logging_obj,
@@ -197,6 +308,7 @@
if client is not None and isinstance(client, HTTPHandler)
else None
),
litellm_params=litellm_params,
)
return CustomStreamWrapper(
completion_stream=completion_stream,
@@ -210,19 +322,15 @@
else:
sync_httpx_client = client
try:
response = sync_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout,
)
except Exception as e:
raise self._handle_error(
e=e,
provider_config=provider_config,
)
response = self._make_common_sync_call(
sync_httpx_client=sync_httpx_client,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=data,
timeout=timeout,
litellm_params=litellm_params,
)
return provider_config.transform_response(
model=model,
raw_response=response,
@@ -241,44 +349,33 @@
provider_config: BaseConfig,
api_base: str,
headers: dict,
data: str,
data: dict,
model: str,
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
litellm_params: dict,
timeout: Union[float, httpx.Timeout],
fake_stream: bool = False,
client: Optional[HTTPHandler] = None,
) -> Tuple[Any, httpx.Headers]:
) -> Tuple[Any, dict]:
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client()
else:
sync_httpx_client = client
try:
stream = True
if fake_stream is True:
stream = False
response = sync_httpx_client.post(
api_base, headers=headers, data=data, timeout=timeout, stream=stream
)
except httpx.HTTPStatusError as e:
raise self._handle_error(
e=e,
provider_config=provider_config,
)
except Exception as e:
for exception in litellm.LITELLM_EXCEPTION_TYPES:
if isinstance(e, exception):
raise e
raise self._handle_error(
e=e,
provider_config=provider_config,
)
stream = True
if fake_stream is True:
stream = False
if response.status_code != 200:
raise BaseLLMException(
status_code=response.status_code,
message=str(response.read()),
)
response = self._make_common_sync_call(
sync_httpx_client=sync_httpx_client,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=data,
timeout=timeout,
litellm_params=litellm_params,
stream=stream,
)
if fake_stream is True:
completion_stream = provider_config.get_model_response_iterator(
@@ -297,7 +394,7 @@ class BaseLLMHTTPHandler:
additional_args={"complete_input_dict": data},
)
return completion_stream, response.headers
return completion_stream, dict(response.headers)
async def acompletion_stream_function(
self,
@@ -310,6 +407,7 @@
timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj,
data: dict,
litellm_params: dict,
fake_stream: bool = False,
client: Optional[AsyncHTTPHandler] = None,
):
@@ -318,12 +416,13 @@
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=json.dumps(data),
data=data,
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
fake_stream=fake_stream,
client=client,
litellm_params=litellm_params,
)
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
@@ -339,10 +438,11 @@
provider_config: BaseConfig,
api_base: str,
headers: dict,
data: str,
data: dict,
messages: list,
logging_obj: LiteLLMLoggingObj,
timeout: Optional[Union[float, httpx.Timeout]],
timeout: Union[float, httpx.Timeout],
litellm_params: dict,
fake_stream: bool = False,
client: Optional[AsyncHTTPHandler] = None,
) -> Tuple[Any, httpx.Headers]:
@@ -355,29 +455,18 @@
stream = True
if fake_stream is True:
stream = False
try:
response = await async_httpx_client.post(
api_base, headers=headers, data=data, stream=stream, timeout=timeout
)
except httpx.HTTPStatusError as e:
raise self._handle_error(
e=e,
provider_config=provider_config,
)
except Exception as e:
for exception in litellm.LITELLM_EXCEPTION_TYPES:
if isinstance(e, exception):
raise e
raise self._handle_error(
e=e,
provider_config=provider_config,
)
if response.status_code != 200:
raise BaseLLMException(
status_code=response.status_code,
message=str(response.read()),
)
response = await self._make_common_async_call(
async_httpx_client=async_httpx_client,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=data,
timeout=timeout,
litellm_params=litellm_params,
stream=stream,
)
if fake_stream is True:
completion_stream = provider_config.get_model_response_iterator(
streaming_response=response.json(), sync_stream=False