Litellm dev 12 30 2024 p2 (#7495)

* test(azure_openai_o1.py): initial commit with tests for the azure openai o1-preview model

* fix(base_llm_unit_tests.py): handle azure o1 preview response format tests

skip as o1 on azure doesn't support tool calling yet

* fix: initial commit of azure o1 handler using openai caller

simplifies calling and lets the fake-streaming logic already implemented for openai just work

* feat(azure/o1_handler.py): fake o1 streaming for azure o1 models

azure does not currently support streaming for o1
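
The fake-streaming approach, roughly: make the normal non-streaming call, then replay the finished response as a one-chunk stream. A minimal sketch of the idea (the function name and chunk shapes here are illustrative, not litellm's actual internals):

    from typing import Iterator

    def fake_stream(full_response: dict) -> Iterator[dict]:
        # azure o1 can't stream yet, so take the complete completion and
        # yield it back as a single chat.completion.chunk-shaped event
        choice = full_response["choices"][0]
        yield {
            "id": full_response["id"],
            "object": "chat.completion.chunk",
            "created": full_response["created"],
            "model": full_response["model"],
            "choices": [
                {
                    "index": 0,
                    "delta": {
                        "role": "assistant",
                        "content": choice["message"]["content"],
                    },
                    "finish_reason": choice["finish_reason"],
                }
            ],
        }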

* feat(o1_transformation.py): support overriding 'should_fake_stream' on azure/o1 via 'supports_native_streaming' param on model info

lets users toggle native streaming on once azure supports it for o1, without needing to bump litellm versions
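
A sketch of how the toggle could be consulted; litellm.get_model_info is a real litellm API, but the exact lookup and key handling below are assumptions:

    import litellm

    def should_fake_stream(model: str) -> bool:
        # fake the stream unless model info explicitly says azure now
        # streams o1 natively ('supports_native_streaming' set by the user)
        try:
            model_info = litellm.get_model_info(model=model)
        except Exception:
            return True  # no metadata available -> keep faking
        return model_info.get("supports_native_streaming") is not True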

* style(router.py): remove 'give feedback/get help' messaging when router is used

Prevents noisy console output

Closes https://github.com/BerriAI/litellm/issues/5942

* fix(types/utils.py): handle none logprobs

Fixes https://github.com/BerriAI/litellm/issues/328
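
The shape of the fix, sketched outside litellm's actual types: the response model must tolerate providers that send "logprobs": null instead of omitting the field. A minimal pydantic illustration (class names are stand-ins):

    from typing import Optional
    from pydantic import BaseModel

    class ChoiceLogprobs(BaseModel):
        content: Optional[list] = None

    class Choice(BaseModel):
        # Optional with a None default, so {"logprobs": null} validates;
        # a bare 'logprobs: ChoiceLogprobs' field would raise instead
        logprobs: Optional[ChoiceLogprobs] = None

    Choice(logprobs=None)  # accepted rather than raising a validation error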

* fix(exception_mapping_utils.py): fix 'error_str' unbound variable error

* refactor(azure_ai/): move to openai_like chat completion handler

allows easy swapping of API base URLs (e.g. ai.services.com)

Fixes https://github.com/BerriAI/litellm/issues/7275
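
For example, the same handler can then serve any OpenAI-compatible endpoint just by pointing api_base elsewhere; the endpoint and deployment names below are hypothetical:

    import litellm

    response = litellm.completion(
        model="azure_ai/Mistral-large",  # hypothetical deployment name
        api_base="https://my-endpoint.eastus2.ai.services.com",  # hypothetical
        api_key="my-azure-ai-key",
        messages=[{"role": "user", "content": "hello"}],
    )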

* refactor(azure_ai/): move to base llm http handler

* fix(azure_ai/): handle differing api endpoints

* fix(azure_ai/): make sure all unit tests are passing

* fix: fix linting errors

* fix: fix linting errors

* fix: fix linting error

* fix: fix linting errors

* fix(azure_ai/transformation.py): handle extra body param
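
extra_body is the passthrough for provider-specific parameters litellm doesn't model itself, so the transformation has to merge it into the request payload rather than drop it. A hedged usage sketch (the extra_body key and endpoint are just examples):

    import litellm

    response = litellm.completion(
        model="azure_ai/Mistral-large",  # hypothetical deployment name
        api_base="https://my-endpoint.eastus2.ai.services.com",  # hypothetical
        api_key="my-azure-ai-key",
        messages=[{"role": "user", "content": "hello"}],
        extra_body={"safe_prompt": True},  # forwarded verbatim in the request body
    )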

* fix(azure_ai/transformation.py): fix max retries param handling

* fix: fix test

* test(test_azure_o1.py): fix test

* fix(llm_http_handler.py): support handling azure ai unprocessable entity error

* fix(llm_http_handler.py): handle sync invalid param error for azure ai
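
The recovery idea in both the sync and async paths: when azure ai returns 422 (unprocessable entity), strip the rejected parameters and retry once instead of surfacing the error. A self-contained sketch with httpx; the helper below is hypothetical, not litellm's actual drop_params_from_unprocessable_entity_error:

    import httpx

    def drop_rejected_params(data: dict, error_text: str) -> dict:
        # hypothetical: remove any top-level param the provider names in
        # its 422 body (litellm's real helper parses the error differently)
        return {k: v for k, v in data.items() if k not in error_text}

    def post_with_retry(client: httpx.Client, url: str, data: dict) -> httpx.Response:
        # on 422, slim down the payload and retry once
        response = client.post(url, json=data)
        if response.status_code == 422:
            response = client.post(url, json=drop_rejected_params(data, response.text))
        return response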

* fix(azure_ai/): streaming support with base_llm_http_handler

* fix(llm_http_handler.py): working sync stream calls with unprocessable entity handling for azure ai

* fix: fix linting errors

* fix(llm_http_handler.py): fix linting error

* fix(azure_ai/): handle cohere tool call invalid index param error
Krish Dholakia, 2025-01-01 18:57:29 -08:00, committed by GitHub
parent 0f1b298fe0
commit 0120176541
42 changed files with 638 additions and 192 deletions

litellm/llms/openai/openai.py

@@ -2,9 +2,11 @@ import hashlib
 import types
 from typing import (
     Any,
+    AsyncIterator,
     Callable,
     Coroutine,
     Iterable,
+    Iterator,
     List,
     Literal,
     Optional,
@@ -24,10 +26,16 @@ import litellm
 from litellm import LlmProviders
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
 from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
 from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
-from litellm.types.utils import EmbeddingResponse, ImageResponse, ModelResponse
+from litellm.types.utils import (
+    EmbeddingResponse,
+    ImageResponse,
+    ModelResponse,
+    ModelResponseStream,
+)
 from litellm.utils import (
     CustomStreamWrapper,
     ProviderConfigManager,
@@ -36,7 +44,6 @@ from litellm.utils import (
 from ...types.llms.openai import *
 from ..base import BaseLLM
 from .chat.gpt_transformation import OpenAIGPTConfig
 from .common_utils import OpenAIError, drop_params_from_unprocessable_entity_error
@@ -232,6 +239,7 @@ class OpenAIConfig(BaseConfig):
         litellm_params: dict,
         headers: dict,
     ) -> dict:
+        messages = self._transform_messages(messages=messages, model=model)
         return {"model": model, "messages": messages, **optional_params}

     def transform_response(
@@ -248,10 +256,21 @@ class OpenAIConfig(BaseConfig):
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
     ) -> ModelResponse:
-        raise NotImplementedError(
-            "OpenAI handler does this transformation as it uses the OpenAI SDK."
-        )
+        logging_obj.post_call(original_response=raw_response.text)
+        logging_obj.model_call_details["response_headers"] = raw_response.headers
+        final_response_obj = cast(
+            ModelResponse,
+            convert_to_model_response_object(
+                response_object=raw_response.json(),
+                model_response_object=model_response,
+                hidden_params={"headers": raw_response.headers},
+                _response_headers=dict(raw_response.headers),
+            ),
+        )
+
+        return final_response_obj

     def validate_environment(
         self,
         headers: dict,
@@ -259,12 +278,37 @@ class OpenAIConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
-        raise NotImplementedError(
-            "OpenAI handler does this validation as it uses the OpenAI SDK."
-        )
+        return {
+            "Authorization": f"Bearer {api_key}",
+            **headers,
+        }
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ) -> Any:
+        return OpenAIChatCompletionResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
+
+
+class OpenAIChatCompletionResponseIterator(BaseModelResponseIterator):
+    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
+        """
+        {'choices': [{'delta': {'content': '', 'role': 'assistant'}, 'finish_reason': None, 'index': 0, 'logprobs': None}], 'created': 1735763082, 'id': 'a83a2b0fbfaf4aab9c2c93cb8ba346d7', 'model': 'mistral-large', 'object': 'chat.completion.chunk'}
+        """
+        try:
+            return ModelResponseStream(**chunk)
+        except Exception as e:
+            raise e


 class OpenAIChatCompletion(BaseLLM):
     def __init__(self) -> None:
@@ -473,14 +517,6 @@ class OpenAIChatCompletion(BaseLLM):
         if custom_llm_provider is not None and custom_llm_provider != "openai":
             model_response.model = f"{custom_llm_provider}/{model}"

-        if messages is not None and provider_config is not None:
-            if isinstance(provider_config, OpenAIGPTConfig) or isinstance(
-                provider_config, OpenAIConfig
-            ):  # [TODO]: remove. no longer needed as .transform_request can just handle this.
-                messages = provider_config._transform_messages(
-                    messages=messages, model=model
-                )
-
         for _ in range(
             2
         ):  # if call fails due to alternating messages, retry with reformatted message
@@ -647,12 +683,10 @@ class OpenAIChatCompletion(BaseLLM):
                     new_messages = messages
                     new_messages.append({"role": "user", "content": ""})
                     messages = new_messages
-                elif (
-                    "unknown field: parameter index is not a valid field" in str(e)
-                ) and "tools" in data:
-                    litellm.remove_index_from_tool_calls(
-                        tool_calls=data["tools"], messages=messages
-                    )
+                elif "unknown field: parameter index is not a valid field" in str(
+                    e
+                ):
+                    litellm.remove_index_from_tool_calls(messages=messages)
                 else:
                     raise e
             except OpenAIError as e: