Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
Litellm dev 12 30 2024 p2 (#7495)
* test(azure_openai_o1.py): initial commit with testing for azure openai o1 preview model
* fix(base_llm_unit_tests.py): handle azure o1 preview response format tests
  skip, as o1 on azure doesn't support tool calling yet
* fix: initial commit of azure o1 handler using openai caller
  simplifies calling + allows the fake streaming logic already implemented for openai to just work
* feat(azure/o1_handler.py): fake o1 streaming for azure o1 models
  azure does not currently support streaming for o1
* feat(o1_transformation.py): support overriding 'should_fake_stream' on azure/o1 via 'supports_native_streaming' param on model info
  enables user to toggle this on once azure allows o1 streaming, without needing to bump versions
* style(router.py): remove 'give feedback/get help' messaging when router is used
  prevents noisy messaging; Closes https://github.com/BerriAI/litellm/issues/5942
* fix(types/utils.py): handle none logprobs
  Fixes https://github.com/BerriAI/litellm/issues/328
* fix(exception_mapping_utils.py): fix error str unbound error
* refactor(azure_ai/): move to openai_like chat completion handler
  allows for easy swapping of API base URLs (e.g. ai.services.com); Fixes https://github.com/BerriAI/litellm/issues/7275
* refactor(azure_ai/): move to base llm http handler
* fix(azure_ai/): handle differing api endpoints
* fix(azure_ai/): make sure all unit tests are passing
* fix: fix linting errors
* fix: fix linting errors
* fix: fix linting error
* fix: fix linting errors
* fix(azure_ai/transformation.py): handle extra body param
* fix(azure_ai/transformation.py): fix max retries param handling
* fix: fix test
* test(test_azure_o1.py): fix test
* fix(llm_http_handler.py): support handling azure ai unprocessable entity error
* fix(llm_http_handler.py): handle sync invalid param error for azure ai
* fix(azure_ai/): streaming support with base_llm_http_handler
* fix(llm_http_handler.py): working sync stream calls with unprocessable entity handling for azure ai
* fix: fix linting errors
* fix(llm_http_handler.py): fix linting error
* fix(azure_ai/): handle cohere tool call invalid index param error
Parent: 0f1b298fe0
Commit: 0120176541
42 changed files with 638 additions and 192 deletions
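Two of the changes above are visible from the caller's side without reading the diff: Azure o1 deployments can now be called with stream=True (the handler fakes streaming until Azure supports it natively), and fake streaming can later be switched off via the 'supports_native_streaming' model-info flag. A minimal usage sketch; the deployment name, endpoint, and credentials below are placeholders, not values from this commit:

import os
import litellm

# Placeholder credentials/endpoint; substitute your own Azure OpenAI resource.
os.environ["AZURE_API_KEY"] = "sk-..."
os.environ["AZURE_API_BASE"] = "https://my-endpoint.openai.azure.com"
os.environ["AZURE_API_VERSION"] = "2024-12-01-preview"

# "my-o1-deployment" is a hypothetical deployment name.
response = litellm.completion(
    model="azure/my-o1-deployment",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,  # per this commit, streaming for azure o1 is faked client-side
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")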
@@ -8,7 +8,7 @@ import litellm
 import litellm.litellm_core_utils
 import litellm.types
 import litellm.types.utils
-from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
 from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.llms.custom_httpx.http_handler import (
@@ -30,6 +30,114 @@ else:


 class BaseLLMHTTPHandler:
+
+    async def _make_common_async_call(
+        self,
+        async_httpx_client: AsyncHTTPHandler,
+        provider_config: BaseConfig,
+        api_base: str,
+        headers: dict,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
+        stream: bool = False,
+    ) -> httpx.Response:
+        """Common implementation across stream + non-stream calls. Meant to ensure consistent error-handling."""
+        max_retry_on_unprocessable_entity_error = (
+            provider_config.max_retry_on_unprocessable_entity_error
+        )
+
+        response: Optional[httpx.Response] = None
+        for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+            try:
+                response = await async_httpx_client.post(
+                    url=api_base,
+                    headers=headers,
+                    data=json.dumps(data),
+                    timeout=timeout,
+                    stream=stream,
+                )
+            except httpx.HTTPStatusError as e:
+                hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
+                should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
+                    e=e, litellm_params=litellm_params
+                )
+                if should_retry and not hit_max_retry:
+                    data = (
+                        provider_config.transform_request_on_unprocessable_entity_error(
+                            e=e, request_data=data
+                        )
+                    )
+                    continue
+                else:
+                    raise self._handle_error(e=e, provider_config=provider_config)
+            except Exception as e:
+                raise self._handle_error(e=e, provider_config=provider_config)
+            break
+
+        if response is None:
+            raise provider_config.get_error_class(
+                error_message="No response from the API",
+                status_code=422,  # don't retry on this error
+                headers={},
+            )
+
+        return response
+
+    def _make_common_sync_call(
+        self,
+        sync_httpx_client: HTTPHandler,
+        provider_config: BaseConfig,
+        api_base: str,
+        headers: dict,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
+        stream: bool = False,
+    ) -> httpx.Response:
+
+        max_retry_on_unprocessable_entity_error = (
+            provider_config.max_retry_on_unprocessable_entity_error
+        )
+
+        response: Optional[httpx.Response] = None
+
+        for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+            try:
+                response = sync_httpx_client.post(
+                    url=api_base,
+                    headers=headers,
+                    data=json.dumps(data),
+                    timeout=timeout,
+                    stream=stream,
+                )
+            except httpx.HTTPStatusError as e:
+                hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
+                should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
+                    e=e, litellm_params=litellm_params
+                )
+                if should_retry and not hit_max_retry:
+                    data = (
+                        provider_config.transform_request_on_unprocessable_entity_error(
+                            e=e, request_data=data
+                        )
+                    )
+                    continue
+                else:
+                    raise self._handle_error(e=e, provider_config=provider_config)
+            except Exception as e:
+                raise self._handle_error(e=e, provider_config=provider_config)
+            break
+
+        if response is None:
+            raise provider_config.get_error_class(
+                error_message="No response from the API",
+                status_code=422,  # don't retry on this error
+                headers={},
+            )
+
+        return response
+
     async def async_completion(
         self,
         custom_llm_provider: str,
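Stripped of litellm-specific plumbing, the behaviour these two helpers add is: POST the request, and if the provider rejects it with an HTTP error that the provider config considers retryable (e.g. a 422 "unprocessable entity" from Azure AI), let the config rewrite the request body and try again, up to the configured retry count. A minimal sketch of that control flow, using a hypothetical FakeConfig in place of a real provider config (the real hooks are max_retry_on_unprocessable_entity_error, should_retry_llm_api_inside_llm_translation_on_http_error, and transform_request_on_unprocessable_entity_error):

import httpx

class FakeConfig:
    # Hypothetical stand-in for a provider config; the retry count and the
    # dropped field are illustrative, not taken from this commit.
    max_retry_on_unprocessable_entity_error = 2

    def should_retry(self, e: httpx.HTTPStatusError) -> bool:
        # e.g. a provider rejecting an unsupported field with HTTP 422
        return e.response.status_code == 422

    def transform_request(self, request_data: dict) -> dict:
        # Drop the field the provider complained about and retry.
        request_data.pop("unsupported_param", None)
        return request_data

def post_with_422_retry(
    client: httpx.Client, url: str, data: dict, config: FakeConfig
) -> httpx.Response:
    response = None
    for attempt in range(max(config.max_retry_on_unprocessable_entity_error, 1)):
        try:
            response = client.post(url, json=data)
            response.raise_for_status()  # surface 4xx/5xx as HTTPStatusError
        except httpx.HTTPStatusError as e:
            hit_max_retry = attempt + 1 == config.max_retry_on_unprocessable_entity_error
            if config.should_retry(e) and not hit_max_retry:
                data = config.transform_request(data)
                continue
            raise
        break
    if response is None:
        # mirrors the "No response from the API" guard in the handler above
        raise RuntimeError("No response from the API")
    return response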
@@ -55,15 +163,16 @@ class BaseLLMHTTPHandler:
         else:
             async_httpx_client = client

-        try:
-            response = await async_httpx_client.post(
-                url=api_base,
-                headers=headers,
-                data=json.dumps(data),
-                timeout=timeout,
-            )
-        except Exception as e:
-            raise self._handle_error(e=e, provider_config=provider_config)
+        response = await self._make_common_async_call(
+            async_httpx_client=async_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+            stream=False,
+        )
         return provider_config.transform_response(
             model=model,
             raw_response=response,
@@ -93,7 +202,7 @@ class BaseLLMHTTPHandler:
         stream: Optional[bool] = False,
         fake_stream: bool = False,
         api_key: Optional[str] = None,
-        headers={},
+        headers: Optional[dict] = {},
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
         provider_config = ProviderConfigManager.get_provider_chat_config(
@@ -102,10 +211,11 @@ class BaseLLMHTTPHandler:
         # get config from model, custom llm provider
         headers = provider_config.validate_environment(
             api_key=api_key,
-            headers=headers,
+            headers=headers or {},
             model=model,
             messages=messages,
             optional_params=optional_params,
+            api_base=api_base,
         )

         api_base = provider_config.get_complete_url(
@@ -154,6 +264,7 @@ class BaseLLMHTTPHandler:
                        if client is not None and isinstance(client, AsyncHTTPHandler)
                        else None
                    ),
+                    litellm_params=litellm_params,
                )

            else:
@@ -186,7 +297,7 @@ class BaseLLMHTTPHandler:
                provider_config=provider_config,
                api_base=api_base,
                headers=headers,  # type: ignore
-                data=json.dumps(data),
+                data=data,
                model=model,
                messages=messages,
                logging_obj=logging_obj,
@@ -197,6 +308,7 @@ class BaseLLMHTTPHandler:
                    if client is not None and isinstance(client, HTTPHandler)
                    else None
                ),
+                litellm_params=litellm_params,
            )
            return CustomStreamWrapper(
                completion_stream=completion_stream,
@@ -210,19 +322,15 @@ class BaseLLMHTTPHandler:
         else:
             sync_httpx_client = client

-        try:
-            response = sync_httpx_client.post(
-                url=api_base,
-                headers=headers,
-                data=json.dumps(data),
-                timeout=timeout,
-            )
-        except Exception as e:
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )
-
+        response = self._make_common_sync_call(
+            sync_httpx_client=sync_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+        )
         return provider_config.transform_response(
             model=model,
             raw_response=response,
@@ -241,44 +349,33 @@ class BaseLLMHTTPHandler:
         provider_config: BaseConfig,
         api_base: str,
         headers: dict,
-        data: str,
+        data: dict,
         model: str,
         messages: list,
         logging_obj,
-        timeout: Optional[Union[float, httpx.Timeout]],
+        litellm_params: dict,
+        timeout: Union[float, httpx.Timeout],
         fake_stream: bool = False,
         client: Optional[HTTPHandler] = None,
-    ) -> Tuple[Any, httpx.Headers]:
+    ) -> Tuple[Any, dict]:
         if client is None or not isinstance(client, HTTPHandler):
             sync_httpx_client = _get_httpx_client()
         else:
             sync_httpx_client = client
-        try:
-            stream = True
-            if fake_stream is True:
-                stream = False
-            response = sync_httpx_client.post(
-                api_base, headers=headers, data=data, timeout=timeout, stream=stream
-            )
-        except httpx.HTTPStatusError as e:
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )
-        except Exception as e:
-            for exception in litellm.LITELLM_EXCEPTION_TYPES:
-                if isinstance(e, exception):
-                    raise e
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )
+        stream = True
+        if fake_stream is True:
+            stream = False

-        if response.status_code != 200:
-            raise BaseLLMException(
-                status_code=response.status_code,
-                message=str(response.read()),
-            )
+        response = self._make_common_sync_call(
+            sync_httpx_client=sync_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+            stream=stream,
+        )

         if fake_stream is True:
             completion_stream = provider_config.get_model_response_iterator(
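Note how fake_stream changes the request, not the interface: when fake_stream is True the common call above is made with stream=False, and the complete JSON body is handed to provider_config.get_model_response_iterator() so callers still iterate chunks. Purely as an illustration of that idea (the helper below is hypothetical, not litellm code), a finished response can be replayed as delta-style chunks like this:

from typing import Any, Dict, Iterator

def fake_stream_chunks(
    complete_response: Dict[str, Any], chunk_size: int = 10
) -> Iterator[Dict[str, Any]]:
    """Replay a finished completion as delta-style chunks (illustrative only)."""
    text = complete_response["choices"][0]["message"]["content"] or ""
    for start in range(0, len(text), chunk_size):
        yield {"choices": [{"delta": {"content": text[start : start + chunk_size]}}]}

# Example: each chunk carries a slice of the already-complete answer.
example = {"choices": [{"message": {"content": "The capital of France is Paris."}}]}
for chunk in fake_stream_chunks(example):
    print(chunk["choices"][0]["delta"]["content"], end="")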
@@ -297,7 +394,7 @@ class BaseLLMHTTPHandler:
            additional_args={"complete_input_dict": data},
        )

-        return completion_stream, response.headers
+        return completion_stream, dict(response.headers)

    async def acompletion_stream_function(
        self,
@@ -310,6 +407,7 @@ class BaseLLMHTTPHandler:
        timeout: Union[float, httpx.Timeout],
        logging_obj: LiteLLMLoggingObj,
        data: dict,
+        litellm_params: dict,
        fake_stream: bool = False,
        client: Optional[AsyncHTTPHandler] = None,
    ):
@@ -318,12 +416,13 @@ class BaseLLMHTTPHandler:
            provider_config=provider_config,
            api_base=api_base,
            headers=headers,
-            data=json.dumps(data),
+            data=data,
            messages=messages,
            logging_obj=logging_obj,
            timeout=timeout,
            fake_stream=fake_stream,
            client=client,
+            litellm_params=litellm_params,
        )
        streamwrapper = CustomStreamWrapper(
            completion_stream=completion_stream,
@@ -339,10 +438,11 @@ class BaseLLMHTTPHandler:
        provider_config: BaseConfig,
        api_base: str,
        headers: dict,
-        data: str,
+        data: dict,
        messages: list,
        logging_obj: LiteLLMLoggingObj,
-        timeout: Optional[Union[float, httpx.Timeout]],
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
        fake_stream: bool = False,
        client: Optional[AsyncHTTPHandler] = None,
    ) -> Tuple[Any, httpx.Headers]:
@@ -355,29 +455,18 @@ class BaseLLMHTTPHandler:
        stream = True
        if fake_stream is True:
            stream = False
-        try:
-            response = await async_httpx_client.post(
-                api_base, headers=headers, data=data, stream=stream, timeout=timeout
-            )
-        except httpx.HTTPStatusError as e:
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )
-        except Exception as e:
-            for exception in litellm.LITELLM_EXCEPTION_TYPES:
-                if isinstance(e, exception):
-                    raise e
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )

-        if response.status_code != 200:
-            raise BaseLLMException(
-                status_code=response.status_code,
-                message=str(response.read()),
-            )
+        response = await self._make_common_async_call(
+            async_httpx_client=async_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+            stream=stream,
+        )
+
        if fake_stream is True:
            completion_stream = provider_config.get_model_response_iterator(
                streaming_response=response.json(), sync_stream=False
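The async streaming path mirrors the sync one: stream is forced off for fake_stream, the request goes through _make_common_async_call (including the 422 retry/rewrite), and the result is wrapped so the caller sees an ordinary stream. A usage sketch for the azure_ai provider that this commit moves onto this handler; the deployment name, endpoint, and key below are placeholders, not values from the diff:

import asyncio
import litellm

async def main() -> None:
    response = await litellm.acompletion(
        model="azure_ai/my-deployment",  # hypothetical deployment name
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
        api_base="https://my-endpoint.services.ai.azure.com",  # placeholder
        api_key="...",  # placeholder
    )
    async for chunk in response:
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())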