Litellm dev 12 30 2024 p2 (#7495)

* test(azure_openai_o1.py): initial commit with testing for azure openai o1 preview model

* fix(base_llm_unit_tests.py): handle azure o1 preview response format tests

skip as o1 on azure doesn't support tool calling yet

* fix: initial commit of azure o1 handler using openai caller

simplifies calling + allows fake streaming logic alr. implemented for openai to just work

* feat(azure/o1_handler.py): fake o1 streaming for azure o1 models

azure does not currently support streaming for o1

* feat(o1_transformation.py): support overriding 'should_fake_stream' on azure/o1 via 'supports_native_streaming' param on model info

enables user to toggle on when azure allows o1 streaming without needing to bump versions

* style(router.py): remove 'give feedback/get help' messaging when router is used

Prevents noisy messaging

Closes https://github.com/BerriAI/litellm/issues/5942

* fix(types/utils.py): handle none logprobs

Fixes https://github.com/BerriAI/litellm/issues/328

* fix(exception_mapping_utils.py): fix error str unbound error

* refactor(azure_ai/): move to openai_like chat completion handler

allows for easy swapping of api base url's (e.g. ai.services.com)

Fixes https://github.com/BerriAI/litellm/issues/7275

* refactor(azure_ai/): move to base llm http handler

* fix(azure_ai/): handle differing api endpoints

* fix(azure_ai/): make sure all unit tests are passing

* fix: fix linting errors

* fix: fix linting errors

* fix: fix linting error

* fix: fix linting errors

* fix(azure_ai/transformation.py): handle extra body param

* fix(azure_ai/transformation.py): fix max retries param handling

* fix: fix test

* test(test_azure_o1.py): fix test

* fix(llm_http_handler.py): support handling azure ai unprocessable entity error

* fix(llm_http_handler.py): handle sync invalid param error for azure ai

* fix(azure_ai/): streaming support with base_llm_http_handler

* fix(llm_http_handler.py): working sync stream calls with unprocessable entity handling for azure ai

* fix: fix linting errors

* fix(llm_http_handler.py): fix linting error

* fix(azure_ai/): handle cohere tool call invalid index param error
This commit is contained in:
Krish Dholakia 2025-01-01 18:57:29 -08:00 committed by GitHub
parent b5e14ef52a
commit b0f570ee16
42 changed files with 638 additions and 192 deletions

View file

@ -1,10 +1,11 @@
# What is this?
## Helper utilities
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
from litellm._logging import verbose_logger
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
@ -53,17 +54,18 @@ def map_finish_reason(
return finish_reason
def remove_index_from_tool_calls(messages, tool_calls):
for tool_call in tool_calls:
if "index" in tool_call:
tool_call.pop("index")
def remove_index_from_tool_calls(
messages: Optional[List[AllMessageValues]],
):
if messages is not None:
for message in messages:
if "tool_calls" in message:
tool_calls = message["tool_calls"]
for tool_call in tool_calls:
if "index" in tool_call:
tool_call.pop("index")
_tool_calls = message.get("tool_calls")
if _tool_calls is not None and isinstance(_tool_calls, list):
for tool_call in _tool_calls:
if (
isinstance(tool_call, dict) and "index" in tool_call
): # Type guard to ensure it's a dict
tool_call.pop("index", None)
return

View file

@ -148,11 +148,10 @@ def exception_type( # type: ignore # noqa: PLR0915
original_exception=original_exception
)
try:
error_str = str(original_exception)
if model:
if hasattr(original_exception, "message"):
error_str = str(original_exception.message)
else:
error_str = str(original_exception)
if isinstance(original_exception, BaseException):
exception_type = type(original_exception).__name__
else:

View file

@ -741,6 +741,7 @@ class AnthropicConfig(BaseConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
if api_key is None:
raise litellm.AuthenticationError(

View file

@ -85,6 +85,7 @@ class AnthropicTextConfig(BaseConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
if api_key is None:
raise ValueError(

View file

@ -283,6 +283,7 @@ class AzureOpenAIConfig(BaseConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
raise NotImplementedError(
"Azure OpenAI has custom logic for validating environment, as it uses the OpenAI SDK."

View file

@ -1,4 +1,7 @@
from typing import List, Optional, Tuple
from typing import Any, List, Optional, Tuple, cast
import httpx
from httpx import Response
import litellm
from litellm._logging import verbose_logger
@ -6,13 +9,81 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
_audio_or_image_in_message_content,
convert_content_list_to_str,
)
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error
from litellm.llms.openai.openai import OpenAIConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ProviderField
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam
from litellm.types.utils import ModelResponse, ProviderField
from litellm.utils import _add_path_to_api_base
class AzureAIStudioConfig(OpenAIConfig):
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
if api_base and "services.ai.azure.com" in api_base:
headers["api-key"] = api_key
else:
headers["Authorization"] = f"Bearer {api_key}"
return headers
def get_complete_url(
self,
api_base: str,
model: str,
optional_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
Constructs a complete URL for the API request.
Args:
- api_base: Base URL, e.g.,
"https://litellm8397336933.services.ai.azure.com"
OR
"https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
- model: Model name.
- optional_params: Additional query parameters, including "api_version".
- stream: If streaming is required (optional).
Returns:
- A complete URL string, e.g.,
"https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
"""
original_url = httpx.URL(api_base)
# Extract api_version or use default
api_version = cast(Optional[str], optional_params.get("api_version"))
# Check if 'api-version' is already present
if "api-version" not in original_url.params and api_version:
# Add api_version to optional_params
original_url.params["api-version"] = api_version
# Add the path to the base URL
if "services.ai.azure.com" in api_base:
new_url = _add_path_to_api_base(
api_base=api_base, ending_path="/models/chat/completions"
)
else:
new_url = _add_path_to_api_base(
api_base=api_base, ending_path="/chat/completions"
)
# Convert optional_params to query parameters
query_params = original_url.params
final_url = httpx.URL(new_url).copy_with(params=query_params)
return str(final_url)
def get_required_params(self) -> List[ProviderField]:
"""For a given provider, return it's required fields with a description"""
return [
@ -62,8 +133,6 @@ class AzureAIStudioConfig(OpenAIConfig):
):
return True
if api_base and "services.ai.azure" in api_base:
return True
except Exception:
return False
return False
@ -86,3 +155,81 @@ class AzureAIStudioConfig(OpenAIConfig):
)
custom_llm_provider = "azure"
return api_base, dynamic_api_key, custom_llm_provider
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
extra_body = optional_params.pop("extra_body", {})
if extra_body and isinstance(extra_body, dict):
optional_params.update(extra_body)
optional_params.pop("max_retries", None)
return super().transform_request(
model, messages, optional_params, litellm_params, headers
)
def transform_response(
self,
model: str,
raw_response: Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
model_response.model = f"azure_ai/{model}"
return super().transform_response(
model=model,
raw_response=raw_response,
model_response=model_response,
logging_obj=logging_obj,
request_data=request_data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
api_key=api_key,
json_mode=json_mode,
)
def should_retry_llm_api_inside_llm_translation_on_http_error(
self, e: httpx.HTTPStatusError, litellm_params: dict
) -> bool:
should_drop_params = litellm_params.get("drop_params") or litellm.drop_params
error_text = e.response.text
if should_drop_params and "Extra inputs are not permitted" in error_text:
return True
elif (
"unknown field: parameter index is not a valid field" in error_text
): # remove index from tool calls
return True
return super().should_retry_llm_api_inside_llm_translation_on_http_error(
e=e, litellm_params=litellm_params
)
@property
def max_retry_on_unprocessable_entity_error(self) -> int:
return 2
def transform_request_on_unprocessable_entity_error(
self, e: httpx.HTTPStatusError, request_data: dict
) -> dict:
_messages = cast(Optional[List[AllMessageValues]], request_data.get("messages"))
if (
"unknown field: parameter index is not a valid field" in e.response.text
and _messages is not None
):
litellm.remove_index_from_tool_calls(
messages=_messages,
)
data = drop_params_from_unprocessable_entity_error(e=e, data=request_data)
return data

View file

@ -1,8 +1,8 @@
import json
from abc import abstractmethod
from typing import Optional
from typing import Optional, Union
from litellm.types.utils import GenericStreamingChunk
from litellm.types.utils import GenericStreamingChunk, ModelResponseStream
class BaseModelResponseIterator:
@ -13,7 +13,9 @@ class BaseModelResponseIterator:
self.response_iterator = self.streaming_response
self.json_mode = json_mode
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
def chunk_parser(
self, chunk: dict
) -> Union[GenericStreamingChunk, ModelResponseStream]:
return GenericStreamingChunk(
text="",
is_finished=False,
@ -27,7 +29,9 @@ class BaseModelResponseIterator:
def __iter__(self):
return self
def _handle_string_chunk(self, str_line: str) -> GenericStreamingChunk:
def _handle_string_chunk(
self, str_line: str
) -> Union[GenericStreamingChunk, ModelResponseStream]:
# chunk is a str at this point
if "[DONE]" in str_line:
return GenericStreamingChunk(

View file

@ -82,6 +82,33 @@ class BaseConfig(ABC):
"""
return False
def should_retry_llm_api_inside_llm_translation_on_http_error(
self, e: httpx.HTTPStatusError, litellm_params: dict
) -> bool:
"""
Returns True if the model/provider should retry the LLM API on UnprocessableEntityError
Overriden by azure ai - where different models support different parameters
"""
return False
def transform_request_on_unprocessable_entity_error(
self, e: httpx.HTTPStatusError, request_data: dict
) -> dict:
"""
Transform the request data on UnprocessableEntityError
"""
return request_data
@property
def max_retry_on_unprocessable_entity_error(self) -> int:
"""
Returns the max retry count for UnprocessableEntityError
Used if `should_retry_llm_api_inside_llm_translation_on_http_error` is True
"""
return 0
@abstractmethod
def get_supported_openai_params(self, model: str) -> list:
pass
@ -104,6 +131,7 @@ class BaseConfig(ABC):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
pass

View file

@ -115,6 +115,7 @@ class AmazonInvokeMixin:
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
raise NotImplementedError(
"validate_environment not implemented for config. Done in invoke_handler.py"

View file

@ -119,6 +119,7 @@ class ClarifaiConfig(BaseConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
headers = {
"accept": "application/json",

View file

@ -60,6 +60,7 @@ class CloudflareChatConfig(BaseConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
if api_key is None:
raise ValueError(

View file

@ -116,6 +116,7 @@ class CohereChatConfig(BaseConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
return cohere_validate_environment(
headers=headers,

View file

@ -102,6 +102,7 @@ class CohereTextConfig(BaseConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
return cohere_validate_environment(
headers=headers,

View file

@ -8,7 +8,7 @@ import litellm
import litellm.litellm_core_utils
import litellm.types
import litellm.types.utils
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.llms.custom_httpx.http_handler import (
@ -30,6 +30,114 @@ else:
class BaseLLMHTTPHandler:
async def _make_common_async_call(
self,
async_httpx_client: AsyncHTTPHandler,
provider_config: BaseConfig,
api_base: str,
headers: dict,
data: dict,
timeout: Union[float, httpx.Timeout],
litellm_params: dict,
stream: bool = False,
) -> httpx.Response:
"""Common implementation across stream + non-stream calls. Meant to ensure consistent error-handling."""
max_retry_on_unprocessable_entity_error = (
provider_config.max_retry_on_unprocessable_entity_error
)
response: Optional[httpx.Response] = None
for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
try:
response = await async_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout,
stream=stream,
)
except httpx.HTTPStatusError as e:
hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
e=e, litellm_params=litellm_params
)
if should_retry and not hit_max_retry:
data = (
provider_config.transform_request_on_unprocessable_entity_error(
e=e, request_data=data
)
)
continue
else:
raise self._handle_error(e=e, provider_config=provider_config)
except Exception as e:
raise self._handle_error(e=e, provider_config=provider_config)
break
if response is None:
raise provider_config.get_error_class(
error_message="No response from the API",
status_code=422, # don't retry on this error
headers={},
)
return response
def _make_common_sync_call(
self,
sync_httpx_client: HTTPHandler,
provider_config: BaseConfig,
api_base: str,
headers: dict,
data: dict,
timeout: Union[float, httpx.Timeout],
litellm_params: dict,
stream: bool = False,
) -> httpx.Response:
max_retry_on_unprocessable_entity_error = (
provider_config.max_retry_on_unprocessable_entity_error
)
response: Optional[httpx.Response] = None
for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
try:
response = sync_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout,
stream=stream,
)
except httpx.HTTPStatusError as e:
hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
e=e, litellm_params=litellm_params
)
if should_retry and not hit_max_retry:
data = (
provider_config.transform_request_on_unprocessable_entity_error(
e=e, request_data=data
)
)
continue
else:
raise self._handle_error(e=e, provider_config=provider_config)
except Exception as e:
raise self._handle_error(e=e, provider_config=provider_config)
break
if response is None:
raise provider_config.get_error_class(
error_message="No response from the API",
status_code=422, # don't retry on this error
headers={},
)
return response
async def async_completion(
self,
custom_llm_provider: str,
@ -55,15 +163,16 @@ class BaseLLMHTTPHandler:
else:
async_httpx_client = client
try:
response = await async_httpx_client.post(
url=api_base,
response = await self._make_common_async_call(
async_httpx_client=async_httpx_client,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=json.dumps(data),
data=data,
timeout=timeout,
litellm_params=litellm_params,
stream=False,
)
except Exception as e:
raise self._handle_error(e=e, provider_config=provider_config)
return provider_config.transform_response(
model=model,
raw_response=response,
@ -93,7 +202,7 @@ class BaseLLMHTTPHandler:
stream: Optional[bool] = False,
fake_stream: bool = False,
api_key: Optional[str] = None,
headers={},
headers: Optional[dict] = {},
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
provider_config = ProviderConfigManager.get_provider_chat_config(
@ -102,10 +211,11 @@ class BaseLLMHTTPHandler:
# get config from model, custom llm provider
headers = provider_config.validate_environment(
api_key=api_key,
headers=headers,
headers=headers or {},
model=model,
messages=messages,
optional_params=optional_params,
api_base=api_base,
)
api_base = provider_config.get_complete_url(
@ -154,6 +264,7 @@ class BaseLLMHTTPHandler:
if client is not None and isinstance(client, AsyncHTTPHandler)
else None
),
litellm_params=litellm_params,
)
else:
@ -186,7 +297,7 @@ class BaseLLMHTTPHandler:
provider_config=provider_config,
api_base=api_base,
headers=headers, # type: ignore
data=json.dumps(data),
data=data,
model=model,
messages=messages,
logging_obj=logging_obj,
@ -197,6 +308,7 @@ class BaseLLMHTTPHandler:
if client is not None and isinstance(client, HTTPHandler)
else None
),
litellm_params=litellm_params,
)
return CustomStreamWrapper(
completion_stream=completion_stream,
@ -210,19 +322,15 @@ class BaseLLMHTTPHandler:
else:
sync_httpx_client = client
try:
response = sync_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout,
)
except Exception as e:
raise self._handle_error(
e=e,
response = self._make_common_sync_call(
sync_httpx_client=sync_httpx_client,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=data,
timeout=timeout,
litellm_params=litellm_params,
)
return provider_config.transform_response(
model=model,
raw_response=response,
@ -241,43 +349,32 @@ class BaseLLMHTTPHandler:
provider_config: BaseConfig,
api_base: str,
headers: dict,
data: str,
data: dict,
model: str,
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
litellm_params: dict,
timeout: Union[float, httpx.Timeout],
fake_stream: bool = False,
client: Optional[HTTPHandler] = None,
) -> Tuple[Any, httpx.Headers]:
) -> Tuple[Any, dict]:
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client()
else:
sync_httpx_client = client
try:
stream = True
if fake_stream is True:
stream = False
response = sync_httpx_client.post(
api_base, headers=headers, data=data, timeout=timeout, stream=stream
)
except httpx.HTTPStatusError as e:
raise self._handle_error(
e=e,
provider_config=provider_config,
)
except Exception as e:
for exception in litellm.LITELLM_EXCEPTION_TYPES:
if isinstance(e, exception):
raise e
raise self._handle_error(
e=e,
provider_config=provider_config,
)
if response.status_code != 200:
raise BaseLLMException(
status_code=response.status_code,
message=str(response.read()),
response = self._make_common_sync_call(
sync_httpx_client=sync_httpx_client,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=data,
timeout=timeout,
litellm_params=litellm_params,
stream=stream,
)
if fake_stream is True:
@ -297,7 +394,7 @@ class BaseLLMHTTPHandler:
additional_args={"complete_input_dict": data},
)
return completion_stream, response.headers
return completion_stream, dict(response.headers)
async def acompletion_stream_function(
self,
@ -310,6 +407,7 @@ class BaseLLMHTTPHandler:
timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj,
data: dict,
litellm_params: dict,
fake_stream: bool = False,
client: Optional[AsyncHTTPHandler] = None,
):
@ -318,12 +416,13 @@ class BaseLLMHTTPHandler:
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=json.dumps(data),
data=data,
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
fake_stream=fake_stream,
client=client,
litellm_params=litellm_params,
)
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
@ -339,10 +438,11 @@ class BaseLLMHTTPHandler:
provider_config: BaseConfig,
api_base: str,
headers: dict,
data: str,
data: dict,
messages: list,
logging_obj: LiteLLMLoggingObj,
timeout: Optional[Union[float, httpx.Timeout]],
timeout: Union[float, httpx.Timeout],
litellm_params: dict,
fake_stream: bool = False,
client: Optional[AsyncHTTPHandler] = None,
) -> Tuple[Any, httpx.Headers]:
@ -355,29 +455,18 @@ class BaseLLMHTTPHandler:
stream = True
if fake_stream is True:
stream = False
try:
response = await async_httpx_client.post(
api_base, headers=headers, data=data, stream=stream, timeout=timeout
)
except httpx.HTTPStatusError as e:
raise self._handle_error(
e=e,
provider_config=provider_config,
)
except Exception as e:
for exception in litellm.LITELLM_EXCEPTION_TYPES:
if isinstance(e, exception):
raise e
raise self._handle_error(
e=e,
response = await self._make_common_async_call(
async_httpx_client=async_httpx_client,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=data,
timeout=timeout,
litellm_params=litellm_params,
stream=stream,
)
if response.status_code != 200:
raise BaseLLMException(
status_code=response.status_code,
message=str(response.read()),
)
if fake_stream is True:
completion_stream = provider_config.get_model_response_iterator(
streaming_response=response.json(), sync_stream=False

View file

@ -118,6 +118,7 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
api_key = api_key or get_secret_str("DEEPGRAM_API_KEY")
return {

View file

@ -42,6 +42,7 @@ class FireworksAIMixin:
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
api_key = self._get_api_key(api_key)
if api_key is None:

View file

@ -724,12 +724,14 @@ class Huggingface(BaseLLM):
token_logprob = token["logprob"]
# Add the token information to the 'token_info' list
_logprob.tokens.append(token_text)
_logprob.token_logprobs.append(token_logprob)
cast(List[str], _logprob.tokens).append(token_text)
cast(List[float], _logprob.token_logprobs).append(token_logprob)
# stub this to work with llm eval harness
top_alt_tokens = {"": -1.0, "": -2.0, "": -3.0} # noqa: F601
_logprob.top_logprobs.append(top_alt_tokens)
cast(List[Dict[str, float]], _logprob.top_logprobs).append(
top_alt_tokens
)
# For each element in the 'tokens' list, extract the relevant information
for i, token in enumerate(response_details["tokens"]):
@ -751,13 +753,15 @@ class Huggingface(BaseLLM):
top_alt_tokens[text] = logprob
# Add the token information to the 'token_info' list
_logprob.tokens.append(token_text)
_logprob.token_logprobs.append(token_logprob)
_logprob.top_logprobs.append(top_alt_tokens)
cast(List[str], _logprob.tokens).append(token_text)
cast(List[float], _logprob.token_logprobs).append(token_logprob)
cast(List[Dict[str, float]], _logprob.top_logprobs).append(
top_alt_tokens
)
# Add the text offset of the token
# This is computed as the sum of the lengths of all previous tokens
_logprob.text_offset.append(
cast(List[int], _logprob.text_offset).append(
sum(len(t["text"]) for t in response_details["tokens"][:i])
)

View file

@ -356,6 +356,7 @@ class HuggingfaceChatConfig(BaseConfig):
messages: List[AllMessageValues],
optional_params: Dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
default_headers = {
"content-type": "application/json",

View file

@ -94,6 +94,7 @@ class NLPCloudConfig(BaseConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
headers = {
"accept": "application/json",

View file

@ -347,6 +347,7 @@ class OllamaConfig(BaseConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
return headers

View file

@ -89,6 +89,7 @@ class OobaboogaConfig(OpenAIGPTConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
headers = {
"accept": "application/json",

View file

@ -181,6 +181,7 @@ class OpenAIGPTConfig(BaseConfig):
Returns:
dict: The transformed request. Sent as the body of the API call.
"""
messages = self._transform_messages(messages=messages, model=model)
return {
"model": model,
"messages": messages,
@ -225,5 +226,6 @@ class OpenAIGPTConfig(BaseConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
raise NotImplementedError

View file

@ -45,7 +45,8 @@ class OpenAIError(BaseLLMException):
####### Error Handling Utils for OpenAI API #######################
###################################################################
def drop_params_from_unprocessable_entity_error(
e: openai.UnprocessableEntityError, data: Dict[str, Any]
e: Union[openai.UnprocessableEntityError, httpx.HTTPStatusError],
data: Dict[str, Any],
) -> Dict[str, Any]:
"""
Helper function to read OpenAI UnprocessableEntityError and drop the params that raised an error from the error message.
@ -58,14 +59,25 @@ def drop_params_from_unprocessable_entity_error(
Dict[str, Any]: A new dictionary with invalid parameters removed
"""
invalid_params: List[str] = []
if e.body is not None and isinstance(e.body, dict) and e.body.get("message"):
message = e.body.get("message", {})
if isinstance(e, httpx.HTTPStatusError):
error_json = e.response.json()
error_message = error_json.get("error", {})
error_body = error_message
else:
error_body = e.body
if (
error_body is not None
and isinstance(error_body, dict)
and error_body.get("message")
):
message = error_body.get("message", {})
if isinstance(message, str):
try:
message = json.loads(message)
except json.JSONDecodeError:
message = {"detail": message}
detail = message.get("detail")
if isinstance(detail, List) and len(detail) > 0 and isinstance(detail[0], dict):
for error_dict in detail:
if (
@ -76,4 +88,5 @@ def drop_params_from_unprocessable_entity_error(
invalid_params.append(error_dict["loc"][1])
new_data = {k: v for k, v in data.items() if k not in invalid_params}
return new_data

View file

@ -2,9 +2,11 @@ import hashlib
import types
from typing import (
Any,
AsyncIterator,
Callable,
Coroutine,
Iterable,
Iterator,
List,
Literal,
Optional,
@ -24,10 +26,16 @@ import litellm
from litellm import LlmProviders
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
from litellm.types.utils import EmbeddingResponse, ImageResponse, ModelResponse
from litellm.types.utils import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
ModelResponseStream,
)
from litellm.utils import (
CustomStreamWrapper,
ProviderConfigManager,
@ -36,7 +44,6 @@ from litellm.utils import (
from ...types.llms.openai import *
from ..base import BaseLLM
from .chat.gpt_transformation import OpenAIGPTConfig
from .common_utils import OpenAIError, drop_params_from_unprocessable_entity_error
@ -232,6 +239,7 @@ class OpenAIConfig(BaseConfig):
litellm_params: dict,
headers: dict,
) -> dict:
messages = self._transform_messages(messages=messages, model=model)
return {"model": model, "messages": messages, **optional_params}
def transform_response(
@ -248,10 +256,21 @@ class OpenAIConfig(BaseConfig):
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
raise NotImplementedError(
"OpenAI handler does this transformation as it uses the OpenAI SDK."
logging_obj.post_call(original_response=raw_response.text)
logging_obj.model_call_details["response_headers"] = raw_response.headers
final_response_obj = cast(
ModelResponse,
convert_to_model_response_object(
response_object=raw_response.json(),
model_response_object=model_response,
hidden_params={"headers": raw_response.headers},
_response_headers=dict(raw_response.headers),
),
)
return final_response_obj
def validate_environment(
self,
headers: dict,
@ -259,12 +278,37 @@ class OpenAIConfig(BaseConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
raise NotImplementedError(
"OpenAI handler does this validation as it uses the OpenAI SDK."
return {
"Authorization": f"Bearer {api_key}",
**headers,
}
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
) -> Any:
return OpenAIChatCompletionResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)
class OpenAIChatCompletionResponseIterator(BaseModelResponseIterator):
def chunk_parser(self, chunk: dict) -> ModelResponseStream:
"""
{'choices': [{'delta': {'content': '', 'role': 'assistant'}, 'finish_reason': None, 'index': 0, 'logprobs': None}], 'created': 1735763082, 'id': 'a83a2b0fbfaf4aab9c2c93cb8ba346d7', 'model': 'mistral-large', 'object': 'chat.completion.chunk'}
"""
try:
return ModelResponseStream(**chunk)
except Exception as e:
raise e
class OpenAIChatCompletion(BaseLLM):
def __init__(self) -> None:
@ -473,14 +517,6 @@ class OpenAIChatCompletion(BaseLLM):
if custom_llm_provider is not None and custom_llm_provider != "openai":
model_response.model = f"{custom_llm_provider}/{model}"
if messages is not None and provider_config is not None:
if isinstance(provider_config, OpenAIGPTConfig) or isinstance(
provider_config, OpenAIConfig
): # [TODO]: remove. no longer needed as .transform_request can just handle this.
messages = provider_config._transform_messages(
messages=messages, model=model
)
for _ in range(
2
): # if call fails due to alternating messages, retry with reformatted message
@ -647,12 +683,10 @@ class OpenAIChatCompletion(BaseLLM):
new_messages = messages
new_messages.append({"role": "user", "content": ""})
messages = new_messages
elif (
"unknown field: parameter index is not a valid field" in str(e)
) and "tools" in data:
litellm.remove_index_from_tool_calls(
tool_calls=data["tools"], messages=messages
)
elif "unknown field: parameter index is not a valid field" in str(
e
):
litellm.remove_index_from_tool_calls(messages=messages)
else:
raise e
except OpenAIError as e:

View file

@ -132,5 +132,6 @@ class PetalsConfig(BaseConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
return {}

View file

@ -164,6 +164,7 @@ class PredibaseConfig(BaseConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
if api_key is None:
raise ValueError(

View file

@ -309,6 +309,7 @@ class ReplicateConfig(BaseConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
headers = {
"Authorization": f"Token {api_key}",

View file

@ -260,6 +260,7 @@ class SagemakerConfig(BaseConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
headers = {"Content-Type": "application/json"}

View file

@ -48,6 +48,7 @@ class TritonConfig(BaseConfig):
messages: List[AllMessageValues],
optional_params: Dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
return {"Content-Type": "application/json"}

View file

@ -43,6 +43,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
return {}

View file

@ -808,6 +808,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
messages: List[AllMessageValues],
optional_params: Dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
default_headers = {
"Content-Type": "application/json",

View file

@ -82,6 +82,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
if api_key is None:
api_key = (

View file

@ -166,6 +166,7 @@ class IBMWatsonXMixin:
messages: List[AllMessageValues],
optional_params: Dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:
default_headers = {
"Content-Type": "application/json",

View file

@ -1122,6 +1122,7 @@ def completion( # type: ignore # noqa: PLR0915
custom_prompt_dict=custom_prompt_dict,
litellm_metadata=kwargs.get("litellm_metadata"),
disable_add_transform_inline_image_block=disable_add_transform_inline_image_block,
drop_params=kwargs.get("drop_params"),
)
logging.update_environment_variables(
model=model,
@ -1347,39 +1348,28 @@ def completion( # type: ignore # noqa: PLR0915
if extra_headers is not None:
optional_params["extra_headers"] = extra_headers
## LOAD CONFIG - if set
config = litellm.AzureAIStudioConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
## FOR COHERE
if "command-r" in model: # make sure tool call in messages are str
messages = stringify_json_tool_call_content(messages=messages)
## COMPLETION CALL
try:
response = openai_chat_completions.completion(
response = base_llm_http_handler.completion(
model=model,
messages=messages,
headers=headers,
model_response=model_response,
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout, # type: ignore
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
organization=organization,
custom_llm_provider=custom_llm_provider,
drop_params=non_default_params.get("drop_params"),
encoding=encoding,
stream=stream,
)
except Exception as e:
## LOGGING - log the original exception returned

View file

@ -1074,10 +1074,10 @@ class EmbeddingResponse(OpenAIObject):
class Logprobs(OpenAIObject):
text_offset: List[int]
token_logprobs: List[Union[float, None]]
tokens: List[str]
top_logprobs: List[Union[Dict[str, float], None]]
text_offset: Optional[List[int]]
token_logprobs: Optional[List[Union[float, None]]]
tokens: Optional[List[str]]
top_logprobs: Optional[List[Union[Dict[str, float], None]]]
class TextChoices(OpenAIObject):

View file

@ -2002,6 +2002,7 @@ def get_litellm_params(
custom_prompt_dict: Optional[dict] = None,
litellm_metadata: Optional[dict] = None,
disable_add_transform_inline_image_block: Optional[bool] = None,
drop_params: Optional[bool] = None,
):
litellm_params = {
"acompletion": acompletion,
@ -2035,6 +2036,7 @@ def get_litellm_params(
"custom_prompt_dict": custom_prompt_dict,
"litellm_metadata": litellm_metadata,
"disable_add_transform_inline_image_block": disable_add_transform_inline_image_block,
"drop_params": drop_params,
}
return litellm_params
@ -6345,3 +6347,44 @@ def extract_duration_from_srt_or_vtt(srt_or_vtt_content: str) -> Optional[float]
durations.append(total_seconds)
return max(durations) if durations else None
import httpx
def _add_path_to_api_base(api_base: str, ending_path: str) -> str:
"""
Adds an ending path to an API base URL while preventing duplicate path segments.
Args:
api_base: Base URL string
ending_path: Path to append to the base URL
Returns:
Modified URL string with proper path handling
"""
original_url = httpx.URL(api_base)
base_url = original_url.copy_with(params={}) # Removes query params
base_path = original_url.path.rstrip("/")
end_path = ending_path.lstrip("/")
# Split paths into segments
base_segments = [s for s in base_path.split("/") if s]
end_segments = [s for s in end_path.split("/") if s]
# Find overlapping segments from the end of base_path and start of ending_path
final_segments = []
for i in range(len(base_segments)):
if base_segments[i:] == end_segments[: len(base_segments) - i]:
final_segments = base_segments[:i] + end_segments
break
else:
# No overlap found, just combine all segments
final_segments = base_segments + end_segments
# Construct the new path
modified_path = "/" + "/".join(final_segments)
modified_url = base_url.copy_with(path=modified_path)
# Re-add the original query parameters
return str(modified_url.copy_with(params=original_url.params))

View file

@ -28,6 +28,7 @@ from unittest.mock import MagicMock, patch
import pytest
import litellm
from litellm import completion
@pytest.mark.parametrize(
@ -51,18 +52,13 @@ async def test_azure_ai_with_image_url():
Test that Azure AI studio can handle image_url passed when content is a list containing both text and image_url
"""
from openai import AsyncOpenAI
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
litellm.set_verbose = True
client = AsyncOpenAI(
api_key="fake-api-key",
base_url="https://Phi-3-5-vision-instruct-dcvov.eastus2.models.ai.azure.com",
)
client = AsyncHTTPHandler()
with patch.object(
client.chat.completions.with_raw_response, "create"
) as mock_client:
with patch.object(client, "post") as mock_client:
try:
await litellm.acompletion(
model="azure_ai/Phi-3-5-vision-instruct-dcvov",
@ -94,8 +90,9 @@ async def test_azure_ai_with_image_url():
# Verify the request was made
mock_client.assert_called_once()
print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
# Check the request body
request_body = mock_client.call_args.kwargs
request_body = json.loads(mock_client.call_args.kwargs["data"])
assert request_body["model"] == "Phi-3-5-vision-instruct-dcvov"
assert request_body["messages"] == [
{
@ -111,3 +108,79 @@ async def test_azure_ai_with_image_url():
],
}
]
@pytest.mark.parametrize(
"api_base, expected_url",
[
(
"https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
"https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
),
(
"https://litellm8397336933.services.ai.azure.com/models/chat/completions",
"https://litellm8397336933.services.ai.azure.com/models/chat/completions",
),
(
"https://litellm8397336933.services.ai.azure.com/models",
"https://litellm8397336933.services.ai.azure.com/models/chat/completions",
),
(
"https://litellm8397336933.services.ai.azure.com",
"https://litellm8397336933.services.ai.azure.com/models/chat/completions",
),
],
)
def test_azure_ai_services_handler(api_base, expected_url):
from litellm.llms.custom_httpx.http_handler import HTTPHandler
litellm.set_verbose = True
client = HTTPHandler()
with patch.object(client, "post") as mock_client:
try:
response = litellm.completion(
model="azure_ai/Meta-Llama-3.1-70B-Instruct",
messages=[{"role": "user", "content": "Hello, how are you?"}],
api_key="my-fake-api-key",
api_base=api_base,
client=client,
)
print(response)
except Exception as e:
print(f"Error: {e}")
mock_client.assert_called_once()
assert mock_client.call_args.kwargs["headers"]["api-key"] == "my-fake-api-key"
assert mock_client.call_args.kwargs["url"] == expected_url
def test_completion_azure_ai_command_r():
try:
import os
litellm.set_verbose = True
os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")
response = completion(
model="azure_ai/command-r-plus",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What is the meaning of life?"}
],
}
],
) # type: ignore
assert "azure_ai" in response.model
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@ -132,34 +132,6 @@ def test_null_role_response():
assert response.choices[0].message.role == "assistant"
def test_completion_azure_ai_command_r():
try:
import os
litellm.set_verbose = True
os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")
response = completion(
model="azure_ai/command-r-plus",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What is the meaning of life?"}
],
}
],
) # type: ignore
assert "azure_ai" in response.model
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_completion_azure_ai_mistral_invalid_params(sync_mode):

View file

@ -199,4 +199,4 @@ def test_azure_global_standard_get_llm_provider():
api_base="https://my-deployment-francecentral.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
api_key="fake-api-key",
)
assert custom_llm_provider == "azure"
assert custom_llm_provider == "azure_ai"

View file

@ -2954,6 +2954,7 @@ def test_azure_streaming_and_function_calling():
async def test_completion_azure_ai_mistral_invalid_params(sync_mode):
try:
import os
from litellm import stream_chunk_builder
litellm.set_verbose = True
@ -2968,15 +2969,21 @@ async def test_completion_azure_ai_mistral_invalid_params(sync_mode):
"drop_params": True,
"stream": True,
}
chunks = []
if sync_mode:
response: litellm.ModelResponse = completion(**data) # type: ignore
response = completion(**data) # type: ignore
for chunk in response:
print(chunk)
chunks.append(chunk)
else:
response: litellm.ModelResponse = await litellm.acompletion(**data) # type: ignore
response = await litellm.acompletion(**data) # type: ignore
async for chunk in response:
print(chunk)
chunks.append(chunk)
print(f"chunks: {chunks}")
response = stream_chunk_builder(chunks=chunks)
assert response.choices[0].message.content is not None
except litellm.Timeout as e:
pass
except Exception as e:

View file

@ -1252,3 +1252,19 @@ def test_fireworks_ai_document_inlining():
assert supports_pdf_input("fireworks_ai/llama-3.1-8b-instruct") is True
assert supports_vision("fireworks_ai/llama-3.1-8b-instruct") is True
def test_logprobs_type():
from litellm.types.utils import Logprobs
logprobs = {
"text_offset": None,
"token_logprobs": None,
"tokens": None,
"top_logprobs": None,
}
logprobs = Logprobs(**logprobs)
assert logprobs.text_offset is None
assert logprobs.token_logprobs is None
assert logprobs.tokens is None
assert logprobs.top_logprobs is None