Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 03:04:13 +00:00
Litellm dev 12 30 2024 p2 (#7495)
* test(azure_openai_o1.py): initial commit with testing for azure openai o1 preview model
* fix(base_llm_unit_tests.py): handle azure o1 preview response format tests
  skip as o1 on azure doesn't support tool calling yet
* fix: initial commit of azure o1 handler using openai caller
  simplifies calling + allows fake streaming logic alr. implemented for openai to just work
* feat(azure/o1_handler.py): fake o1 streaming for azure o1 models
  azure does not currently support streaming for o1
* feat(o1_transformation.py): support overriding 'should_fake_stream' on azure/o1 via 'supports_native_streaming' param on model info
  enables user to toggle on when azure allows o1 streaming without needing to bump versions
* style(router.py): remove 'give feedback/get help' messaging when router is used
  Prevents noisy messaging. Closes https://github.com/BerriAI/litellm/issues/5942
* fix(types/utils.py): handle none logprobs
  Fixes https://github.com/BerriAI/litellm/issues/328
* fix(exception_mapping_utils.py): fix error str unbound error
* refactor(azure_ai/): move to openai_like chat completion handler
  allows for easy swapping of api base url's (e.g. ai.services.com)
  Fixes https://github.com/BerriAI/litellm/issues/7275
* refactor(azure_ai/): move to base llm http handler
* fix(azure_ai/): handle differing api endpoints
* fix(azure_ai/): make sure all unit tests are passing
* fix: fix linting errors
* fix: fix linting errors
* fix: fix linting error
* fix: fix linting errors
* fix(azure_ai/transformation.py): handle extra body param
* fix(azure_ai/transformation.py): fix max retries param handling
* fix: fix test
* test(test_azure_o1.py): fix test
* fix(llm_http_handler.py): support handling azure ai unprocessable entity error
* fix(llm_http_handler.py): handle sync invalid param error for azure ai
* fix(azure_ai/): streaming support with base_llm_http_handler
* fix(llm_http_handler.py): working sync stream calls with unprocessable entity handling for azure ai
* fix: fix linting errors
* fix(llm_http_handler.py): fix linting error
* fix(azure_ai/): handle cohere tool call invalid index param error
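For callers, the net effect of the azure_ai refactor is that these chat models now route through litellm's base HTTP handler, with a bounded retry that drops unsupported parameters on 422 responses when drop_params is set. A minimal usage sketch; the deployment URL, model name, and key below are placeholders, not values from this commit:

import litellm

# Placeholder Azure AI endpoint, model, and key - purely illustrative.
response = litellm.completion(
    model="azure_ai/mistral-large",
    messages=[{"role": "user", "content": "Say hello"}],
    api_base="https://my-deployment.services.ai.azure.com",
    api_key="my-azure-ai-key",
    drop_params=True,  # lets the handler retry once after dropping params the model rejects
)
print(response.choices[0].message.content)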
This commit is contained in: parent 0f1b298fe0, commit 0120176541

42 changed files with 638 additions and 192 deletions
@@ -1,10 +1,11 @@
 # What is this?
 ## Helper utilities
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Union

 import httpx

 from litellm._logging import verbose_logger
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam

 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -53,17 +54,18 @@ def map_finish_reason(
     return finish_reason


-def remove_index_from_tool_calls(messages, tool_calls):
-    for tool_call in tool_calls:
-        if "index" in tool_call:
-            tool_call.pop("index")
-
-    for message in messages:
-        if "tool_calls" in message:
-            tool_calls = message["tool_calls"]
-            for tool_call in tool_calls:
-                if "index" in tool_call:
-                    tool_call.pop("index")
+def remove_index_from_tool_calls(
+    messages: Optional[List[AllMessageValues]],
+):
+    if messages is not None:
+        for message in messages:
+            _tool_calls = message.get("tool_calls")
+            if _tool_calls is not None and isinstance(_tool_calls, list):
+                for tool_call in _tool_calls:
+                    if (
+                        isinstance(tool_call, dict) and "index" in tool_call
+                    ):  # Type guard to ensure it's a dict
+                        tool_call.pop("index", None)

     return
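The reworked helper now takes only the message list and strips the OpenAI-specific "index" field from tool calls in place. A small usage sketch; the message payload is illustrative, not taken from the diff:

import litellm

# Illustrative assistant message whose tool call carries an "index" field,
# which some providers (e.g. Cohere via Azure AI) reject as an unknown field.
messages = [
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "index": 0,
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }
        ],
    }
]

litellm.remove_index_from_tool_calls(messages=messages)
assert "index" not in messages[0]["tool_calls"][0]  # popped in place, rest untouched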
@@ -148,11 +148,10 @@ def exception_type(  # type: ignore  # noqa: PLR0915
            original_exception=original_exception
        )
    try:
+        error_str = str(original_exception)
        if model:
            if hasattr(original_exception, "message"):
                error_str = str(original_exception.message)
-            else:
-                error_str = str(original_exception)
            if isinstance(original_exception, BaseException):
                exception_type = type(original_exception).__name__
            else:
@@ -741,6 +741,7 @@ class AnthropicConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> Dict:
         if api_key is None:
             raise litellm.AuthenticationError(
@@ -85,6 +85,7 @@ class AnthropicTextConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         if api_key is None:
             raise ValueError(
@@ -283,6 +283,7 @@ class AzureOpenAIConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         raise NotImplementedError(
             "Azure OpenAI has custom logic for validating environment, as it uses the OpenAI SDK."
@@ -1,4 +1,7 @@
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple, cast

+import httpx
+from httpx import Response
+
 import litellm
 from litellm._logging import verbose_logger
@@ -6,13 +9,81 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
     _audio_or_image_in_message_content,
     convert_content_list_to_str,
 )
+from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
+from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error
 from litellm.llms.openai.openai import OpenAIConfig
 from litellm.secret_managers.main import get_secret_str
-from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import ProviderField
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam
+from litellm.types.utils import ModelResponse, ProviderField
+from litellm.utils import _add_path_to_api_base


 class AzureAIStudioConfig(OpenAIConfig):
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        if api_base and "services.ai.azure.com" in api_base:
+            headers["api-key"] = api_key
+        else:
+            headers["Authorization"] = f"Bearer {api_key}"
+
+        return headers
+
+    def get_complete_url(
+        self,
+        api_base: str,
+        model: str,
+        optional_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        """
+        Constructs a complete URL for the API request.
+
+        Args:
+        - api_base: Base URL, e.g.,
+            "https://litellm8397336933.services.ai.azure.com"
+            OR
+            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
+        - model: Model name.
+        - optional_params: Additional query parameters, including "api_version".
+        - stream: If streaming is required (optional).
+
+        Returns:
+        - A complete URL string, e.g.,
+        "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
+        """
+        original_url = httpx.URL(api_base)
+
+        # Extract api_version or use default
+        api_version = cast(Optional[str], optional_params.get("api_version"))
+
+        # Check if 'api-version' is already present
+        if "api-version" not in original_url.params and api_version:
+            # Add api_version to optional_params
+            original_url.params["api-version"] = api_version
+
+        # Add the path to the base URL
+        if "services.ai.azure.com" in api_base:
+            new_url = _add_path_to_api_base(
+                api_base=api_base, ending_path="/models/chat/completions"
+            )
+        else:
+            new_url = _add_path_to_api_base(
+                api_base=api_base, ending_path="/chat/completions"
+            )
+
+        # Convert optional_params to query parameters
+        query_params = original_url.params
+        final_url = httpx.URL(new_url).copy_with(params=query_params)
+
+        return str(final_url)
+
     def get_required_params(self) -> List[ProviderField]:
         """For a given provider, return it's required fields with a description"""
         return [
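To illustrate the header selection added above, a short sketch assuming the config class is importable from litellm.llms.azure_ai.chat.transformation (an assumed path; this diff view does not show file names):

from litellm.llms.azure_ai.chat.transformation import AzureAIStudioConfig  # assumed path

config = AzureAIStudioConfig()
headers = config.validate_environment(
    headers={},
    model="mistral-large",
    messages=[{"role": "user", "content": "hi"}],
    optional_params={},
    api_key="my-key",
    api_base="https://litellm8397336933.services.ai.azure.com",
)
# "services.ai.azure.com" endpoints authenticate with an api-key header;
# other Azure AI bases fall back to a Bearer Authorization header.
assert headers == {"api-key": "my-key"}

get_complete_url then appends /models/chat/completions for these endpoints (and /chat/completions otherwise), carrying the api-version query string through.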
@@ -62,8 +133,6 @@ class AzureAIStudioConfig(OpenAIConfig):
             ):
                 return True

-            if api_base and "services.ai.azure" in api_base:
-                return True
         except Exception:
             return False
         return False
@@ -86,3 +155,81 @@
             )
             custom_llm_provider = "azure"
         return api_base, dynamic_api_key, custom_llm_provider
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        extra_body = optional_params.pop("extra_body", {})
+        if extra_body and isinstance(extra_body, dict):
+            optional_params.update(extra_body)
+        optional_params.pop("max_retries", None)
+        return super().transform_request(
+            model, messages, optional_params, litellm_params, headers
+        )
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        model_response.model = f"azure_ai/{model}"
+        return super().transform_response(
+            model=model,
+            raw_response=raw_response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            request_data=request_data,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=encoding,
+            api_key=api_key,
+            json_mode=json_mode,
+        )
+
+    def should_retry_llm_api_inside_llm_translation_on_http_error(
+        self, e: httpx.HTTPStatusError, litellm_params: dict
+    ) -> bool:
+        should_drop_params = litellm_params.get("drop_params") or litellm.drop_params
+        error_text = e.response.text
+        if should_drop_params and "Extra inputs are not permitted" in error_text:
+            return True
+        elif (
+            "unknown field: parameter index is not a valid field" in error_text
+        ):  # remove index from tool calls
+            return True
+        return super().should_retry_llm_api_inside_llm_translation_on_http_error(
+            e=e, litellm_params=litellm_params
+        )
+
+    @property
+    def max_retry_on_unprocessable_entity_error(self) -> int:
+        return 2
+
+    def transform_request_on_unprocessable_entity_error(
+        self, e: httpx.HTTPStatusError, request_data: dict
+    ) -> dict:
+        _messages = cast(Optional[List[AllMessageValues]], request_data.get("messages"))
+        if (
+            "unknown field: parameter index is not a valid field" in e.response.text
+            and _messages is not None
+        ):
+            litellm.remove_index_from_tool_calls(
+                messages=_messages,
+            )
+        data = drop_params_from_unprocessable_entity_error(e=e, data=request_data)
+        return data
@@ -1,8 +1,8 @@
 import json
 from abc import abstractmethod
-from typing import Optional
+from typing import Optional, Union

-from litellm.types.utils import GenericStreamingChunk
+from litellm.types.utils import GenericStreamingChunk, ModelResponseStream


 class BaseModelResponseIterator:
@@ -13,7 +13,9 @@ class BaseModelResponseIterator:
         self.response_iterator = self.streaming_response
         self.json_mode = json_mode

-    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+    def chunk_parser(
+        self, chunk: dict
+    ) -> Union[GenericStreamingChunk, ModelResponseStream]:
         return GenericStreamingChunk(
             text="",
             is_finished=False,
@@ -27,7 +29,9 @@ class BaseModelResponseIterator:
     def __iter__(self):
         return self

-    def _handle_string_chunk(self, str_line: str) -> GenericStreamingChunk:
+    def _handle_string_chunk(
+        self, str_line: str
+    ) -> Union[GenericStreamingChunk, ModelResponseStream]:
         # chunk is a str at this point
         if "[DONE]" in str_line:
             return GenericStreamingChunk(
@@ -82,6 +82,33 @@ class BaseConfig(ABC):
         """
         return False

+    def should_retry_llm_api_inside_llm_translation_on_http_error(
+        self, e: httpx.HTTPStatusError, litellm_params: dict
+    ) -> bool:
+        """
+        Returns True if the model/provider should retry the LLM API on UnprocessableEntityError
+
+        Overriden by azure ai - where different models support different parameters
+        """
+        return False
+
+    def transform_request_on_unprocessable_entity_error(
+        self, e: httpx.HTTPStatusError, request_data: dict
+    ) -> dict:
+        """
+        Transform the request data on UnprocessableEntityError
+        """
+        return request_data
+
+    @property
+    def max_retry_on_unprocessable_entity_error(self) -> int:
+        """
+        Returns the max retry count for UnprocessableEntityError
+
+        Used if `should_retry_llm_api_inside_llm_translation_on_http_error` is True
+        """
+        return 0
+
     @abstractmethod
     def get_supported_openai_params(self, model: str) -> list:
         pass
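The three hooks above define the contract the new HTTP handler relies on. A minimal, hypothetical provider config sketch showing how they fit together; the class name, the dropped parameter, and the error text are illustrative only:

import httpx

from litellm.llms.base_llm.chat.transformation import BaseConfig


class MyProviderConfig(BaseConfig):  # hypothetical; other abstract methods omitted, never instantiated here
    @property
    def max_retry_on_unprocessable_entity_error(self) -> int:
        return 1  # allow one transformed retry inside the HTTP handler

    def should_retry_llm_api_inside_llm_translation_on_http_error(
        self, e: httpx.HTTPStatusError, litellm_params: dict
    ) -> bool:
        # retry only when the provider rejected an extra input
        return "Extra inputs are not permitted" in e.response.text

    def transform_request_on_unprocessable_entity_error(
        self, e: httpx.HTTPStatusError, request_data: dict
    ) -> dict:
        # drop a parameter this provider does not accept, then let the handler retry
        request_data.pop("parallel_tool_calls", None)  # illustrative parameter
        return request_data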
@@ -104,6 +131,7 @@
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         pass
@@ -115,6 +115,7 @@ class AmazonInvokeMixin:
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         raise NotImplementedError(
             "validate_environment not implemented for config. Done in invoke_handler.py"
@@ -119,6 +119,7 @@ class ClarifaiConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         headers = {
             "accept": "application/json",
@@ -60,6 +60,7 @@ class CloudflareChatConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         if api_key is None:
             raise ValueError(
@@ -116,6 +116,7 @@ class CohereChatConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         return cohere_validate_environment(
             headers=headers,
@@ -102,6 +102,7 @@ class CohereTextConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         return cohere_validate_environment(
             headers=headers,
@@ -8,7 +8,7 @@ import litellm
 import litellm.litellm_core_utils
 import litellm.types
 import litellm.types.utils
-from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
 from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.llms.custom_httpx.http_handler import (
@@ -30,6 +30,114 @@ else:


 class BaseLLMHTTPHandler:
+
+    async def _make_common_async_call(
+        self,
+        async_httpx_client: AsyncHTTPHandler,
+        provider_config: BaseConfig,
+        api_base: str,
+        headers: dict,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
+        stream: bool = False,
+    ) -> httpx.Response:
+        """Common implementation across stream + non-stream calls. Meant to ensure consistent error-handling."""
+        max_retry_on_unprocessable_entity_error = (
+            provider_config.max_retry_on_unprocessable_entity_error
+        )
+
+        response: Optional[httpx.Response] = None
+        for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+            try:
+                response = await async_httpx_client.post(
+                    url=api_base,
+                    headers=headers,
+                    data=json.dumps(data),
+                    timeout=timeout,
+                    stream=stream,
+                )
+            except httpx.HTTPStatusError as e:
+                hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
+                should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
+                    e=e, litellm_params=litellm_params
+                )
+                if should_retry and not hit_max_retry:
+                    data = (
+                        provider_config.transform_request_on_unprocessable_entity_error(
+                            e=e, request_data=data
+                        )
+                    )
+                    continue
+                else:
+                    raise self._handle_error(e=e, provider_config=provider_config)
+            except Exception as e:
+                raise self._handle_error(e=e, provider_config=provider_config)
+            break
+
+        if response is None:
+            raise provider_config.get_error_class(
+                error_message="No response from the API",
+                status_code=422,  # don't retry on this error
+                headers={},
+            )
+
+        return response
+
+    def _make_common_sync_call(
+        self,
+        sync_httpx_client: HTTPHandler,
+        provider_config: BaseConfig,
+        api_base: str,
+        headers: dict,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
+        stream: bool = False,
+    ) -> httpx.Response:
+
+        max_retry_on_unprocessable_entity_error = (
+            provider_config.max_retry_on_unprocessable_entity_error
+        )
+
+        response: Optional[httpx.Response] = None
+
+        for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+            try:
+                response = sync_httpx_client.post(
+                    url=api_base,
+                    headers=headers,
+                    data=json.dumps(data),
+                    timeout=timeout,
+                    stream=stream,
+                )
+            except httpx.HTTPStatusError as e:
+                hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
+                should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
+                    e=e, litellm_params=litellm_params
+                )
+                if should_retry and not hit_max_retry:
+                    data = (
+                        provider_config.transform_request_on_unprocessable_entity_error(
+                            e=e, request_data=data
+                        )
+                    )
+                    continue
+                else:
+                    raise self._handle_error(e=e, provider_config=provider_config)
+            except Exception as e:
+                raise self._handle_error(e=e, provider_config=provider_config)
+            break
+
+        if response is None:
+            raise provider_config.get_error_class(
+                error_message="No response from the API",
+                status_code=422,  # don't retry on this error
+                headers={},
+            )
+
+        return response
+
     async def async_completion(
         self,
         custom_llm_provider: str,
@@ -55,15 +163,16 @@ class BaseLLMHTTPHandler:
         else:
             async_httpx_client = client

-        try:
-            response = await async_httpx_client.post(
-                url=api_base,
-                headers=headers,
-                data=json.dumps(data),
-                timeout=timeout,
-            )
-        except Exception as e:
-            raise self._handle_error(e=e, provider_config=provider_config)
+        response = await self._make_common_async_call(
+            async_httpx_client=async_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+            stream=False,
+        )
         return provider_config.transform_response(
             model=model,
             raw_response=response,
@@ -93,7 +202,7 @@ class BaseLLMHTTPHandler:
         stream: Optional[bool] = False,
         fake_stream: bool = False,
         api_key: Optional[str] = None,
-        headers={},
+        headers: Optional[dict] = {},
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
         provider_config = ProviderConfigManager.get_provider_chat_config(
@@ -102,10 +211,11 @@ class BaseLLMHTTPHandler:
         # get config from model, custom llm provider
         headers = provider_config.validate_environment(
             api_key=api_key,
-            headers=headers,
+            headers=headers or {},
             model=model,
             messages=messages,
             optional_params=optional_params,
+            api_base=api_base,
         )

         api_base = provider_config.get_complete_url(
@@ -154,6 +264,7 @@ class BaseLLMHTTPHandler:
                     if client is not None and isinstance(client, AsyncHTTPHandler)
                     else None
                 ),
+                litellm_params=litellm_params,
            )

        else:
@@ -186,7 +297,7 @@ class BaseLLMHTTPHandler:
                 provider_config=provider_config,
                 api_base=api_base,
                 headers=headers,  # type: ignore
-                data=json.dumps(data),
+                data=data,
                 model=model,
                 messages=messages,
                 logging_obj=logging_obj,
@@ -197,6 +308,7 @@ class BaseLLMHTTPHandler:
                     if client is not None and isinstance(client, HTTPHandler)
                     else None
                 ),
+                litellm_params=litellm_params,
            )
            return CustomStreamWrapper(
                completion_stream=completion_stream,
@@ -210,19 +322,15 @@ class BaseLLMHTTPHandler:
         else:
             sync_httpx_client = client

-        try:
-            response = sync_httpx_client.post(
-                url=api_base,
-                headers=headers,
-                data=json.dumps(data),
-                timeout=timeout,
-            )
-        except Exception as e:
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )
+        response = self._make_common_sync_call(
+            sync_httpx_client=sync_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+        )

         return provider_config.transform_response(
             model=model,
             raw_response=response,
@@ -241,43 +349,32 @@ class BaseLLMHTTPHandler:
         provider_config: BaseConfig,
         api_base: str,
         headers: dict,
-        data: str,
+        data: dict,
         model: str,
         messages: list,
         logging_obj,
-        timeout: Optional[Union[float, httpx.Timeout]],
+        litellm_params: dict,
+        timeout: Union[float, httpx.Timeout],
         fake_stream: bool = False,
         client: Optional[HTTPHandler] = None,
-    ) -> Tuple[Any, httpx.Headers]:
+    ) -> Tuple[Any, dict]:
         if client is None or not isinstance(client, HTTPHandler):
             sync_httpx_client = _get_httpx_client()
         else:
             sync_httpx_client = client
-        try:
         stream = True
         if fake_stream is True:
             stream = False
-            response = sync_httpx_client.post(
-                api_base, headers=headers, data=data, timeout=timeout, stream=stream
-            )
-        except httpx.HTTPStatusError as e:
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )
-        except Exception as e:
-            for exception in litellm.LITELLM_EXCEPTION_TYPES:
-                if isinstance(e, exception):
-                    raise e
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )

-        if response.status_code != 200:
-            raise BaseLLMException(
-                status_code=response.status_code,
-                message=str(response.read()),
-            )
+        response = self._make_common_sync_call(
+            sync_httpx_client=sync_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+            stream=stream,
+        )

         if fake_stream is True:
@@ -297,7 +394,7 @@ class BaseLLMHTTPHandler:
             additional_args={"complete_input_dict": data},
         )

-        return completion_stream, response.headers
+        return completion_stream, dict(response.headers)

    async def acompletion_stream_function(
        self,
@@ -310,6 +407,7 @@ class BaseLLMHTTPHandler:
         timeout: Union[float, httpx.Timeout],
         logging_obj: LiteLLMLoggingObj,
         data: dict,
+        litellm_params: dict,
         fake_stream: bool = False,
         client: Optional[AsyncHTTPHandler] = None,
     ):
@@ -318,12 +416,13 @@ class BaseLLMHTTPHandler:
             provider_config=provider_config,
             api_base=api_base,
             headers=headers,
-            data=json.dumps(data),
+            data=data,
             messages=messages,
             logging_obj=logging_obj,
             timeout=timeout,
             fake_stream=fake_stream,
             client=client,
+            litellm_params=litellm_params,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,
@@ -339,10 +438,11 @@ class BaseLLMHTTPHandler:
         provider_config: BaseConfig,
         api_base: str,
         headers: dict,
-        data: str,
+        data: dict,
         messages: list,
         logging_obj: LiteLLMLoggingObj,
-        timeout: Optional[Union[float, httpx.Timeout]],
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
         fake_stream: bool = False,
         client: Optional[AsyncHTTPHandler] = None,
     ) -> Tuple[Any, httpx.Headers]:
@@ -355,29 +455,18 @@ class BaseLLMHTTPHandler:
         stream = True
         if fake_stream is True:
             stream = False
-        try:
-            response = await async_httpx_client.post(
-                api_base, headers=headers, data=data, stream=stream, timeout=timeout
-            )
-        except httpx.HTTPStatusError as e:
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )
-        except Exception as e:
-            for exception in litellm.LITELLM_EXCEPTION_TYPES:
-                if isinstance(e, exception):
-                    raise e
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )

-        if response.status_code != 200:
-            raise BaseLLMException(
-                status_code=response.status_code,
-                message=str(response.read()),
-            )
+        response = await self._make_common_async_call(
+            async_httpx_client=async_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+            stream=stream,
+        )
         if fake_stream is True:
             completion_stream = provider_config.get_model_response_iterator(
                 streaming_response=response.json(), sync_stream=False
@@ -118,6 +118,7 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         api_key = api_key or get_secret_str("DEEPGRAM_API_KEY")
         return {
@@ -42,6 +42,7 @@ class FireworksAIMixin:
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         api_key = self._get_api_key(api_key)
         if api_key is None:
@@ -724,12 +724,14 @@ class Huggingface(BaseLLM):
                 token_logprob = token["logprob"]

                 # Add the token information to the 'token_info' list
-                _logprob.tokens.append(token_text)
-                _logprob.token_logprobs.append(token_logprob)
+                cast(List[str], _logprob.tokens).append(token_text)
+                cast(List[float], _logprob.token_logprobs).append(token_logprob)

                 # stub this to work with llm eval harness
                 top_alt_tokens = {"": -1.0, "": -2.0, "": -3.0}  # noqa: F601
-                _logprob.top_logprobs.append(top_alt_tokens)
+                cast(List[Dict[str, float]], _logprob.top_logprobs).append(
+                    top_alt_tokens
+                )

             # For each element in the 'tokens' list, extract the relevant information
             for i, token in enumerate(response_details["tokens"]):
@@ -751,13 +753,15 @@ class Huggingface(BaseLLM):
                 top_alt_tokens[text] = logprob

             # Add the token information to the 'token_info' list
-            _logprob.tokens.append(token_text)
-            _logprob.token_logprobs.append(token_logprob)
-            _logprob.top_logprobs.append(top_alt_tokens)
+            cast(List[str], _logprob.tokens).append(token_text)
+            cast(List[float], _logprob.token_logprobs).append(token_logprob)
+            cast(List[Dict[str, float]], _logprob.top_logprobs).append(
+                top_alt_tokens
+            )

             # Add the text offset of the token
             # This is computed as the sum of the lengths of all previous tokens
-            _logprob.text_offset.append(
+            cast(List[int], _logprob.text_offset).append(
                 sum(len(t["text"]) for t in response_details["tokens"][:i])
             )
@@ -356,6 +356,7 @@ class HuggingfaceChatConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: Dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> Dict:
         default_headers = {
             "content-type": "application/json",
@@ -94,6 +94,7 @@ class NLPCloudConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         headers = {
             "accept": "application/json",
@@ -347,6 +347,7 @@ class OllamaConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         return headers
@@ -89,6 +89,7 @@ class OobaboogaConfig(OpenAIGPTConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         headers = {
             "accept": "application/json",
@@ -181,6 +181,7 @@ class OpenAIGPTConfig(BaseConfig):
        Returns:
            dict: The transformed request. Sent as the body of the API call.
        """
+        messages = self._transform_messages(messages=messages, model=model)
        return {
            "model": model,
            "messages": messages,
@@ -225,5 +226,6 @@ class OpenAIGPTConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         raise NotImplementedError
@@ -45,7 +45,8 @@ class OpenAIError(BaseLLMException):
 ####### Error Handling Utils for OpenAI API #######################
 ###################################################################
 def drop_params_from_unprocessable_entity_error(
-    e: openai.UnprocessableEntityError, data: Dict[str, Any]
+    e: Union[openai.UnprocessableEntityError, httpx.HTTPStatusError],
+    data: Dict[str, Any],
 ) -> Dict[str, Any]:
     """
     Helper function to read OpenAI UnprocessableEntityError and drop the params that raised an error from the error message.
@@ -58,14 +59,25 @@ def drop_params_from_unprocessable_entity_error(
        Dict[str, Any]: A new dictionary with invalid parameters removed
    """
    invalid_params: List[str] = []
-    if e.body is not None and isinstance(e.body, dict) and e.body.get("message"):
-        message = e.body.get("message", {})
+    if isinstance(e, httpx.HTTPStatusError):
+        error_json = e.response.json()
+        error_message = error_json.get("error", {})
+        error_body = error_message
+    else:
+        error_body = e.body
+    if (
+        error_body is not None
+        and isinstance(error_body, dict)
+        and error_body.get("message")
+    ):
+        message = error_body.get("message", {})
        if isinstance(message, str):
            try:
                message = json.loads(message)
            except json.JSONDecodeError:
                message = {"detail": message}
        detail = message.get("detail")

        if isinstance(detail, List) and len(detail) > 0 and isinstance(detail[0], dict):
            for error_dict in detail:
                if (
@@ -76,4 +88,5 @@ def drop_params_from_unprocessable_entity_error(
                    invalid_params.append(error_dict["loc"][1])

    new_data = {k: v for k, v in data.items() if k not in invalid_params}
+
    return new_data
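For intuition, the helper expects the 422 payload to carry a pydantic-style "detail" list whose loc entries name the offending request parameters. A standalone illustration of the drop step, with a made-up error body:

# Made-up 422 error body in the shape the helper parses.
message = {
    "detail": [
        {"loc": ["body", "parallel_tool_calls"], "msg": "Extra inputs are not permitted"}
    ]
}

data = {"model": "my-model", "messages": [], "parallel_tool_calls": True}
invalid_params = [d["loc"][1] for d in message["detail"]]
new_data = {k: v for k, v in data.items() if k not in invalid_params}
assert "parallel_tool_calls" not in new_data  # the rejected param is dropped, the rest kept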
@@ -2,9 +2,11 @@ import hashlib
 import types
 from typing import (
     Any,
+    AsyncIterator,
     Callable,
     Coroutine,
     Iterable,
+    Iterator,
     List,
     Literal,
     Optional,
@@ -24,10 +26,16 @@ import litellm
 from litellm import LlmProviders
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
 from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
 from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
-from litellm.types.utils import EmbeddingResponse, ImageResponse, ModelResponse
+from litellm.types.utils import (
+    EmbeddingResponse,
+    ImageResponse,
+    ModelResponse,
+    ModelResponseStream,
+)
 from litellm.utils import (
     CustomStreamWrapper,
     ProviderConfigManager,
@@ -36,7 +44,6 @@ from litellm.utils import (

 from ...types.llms.openai import *
 from ..base import BaseLLM
-from .chat.gpt_transformation import OpenAIGPTConfig
 from .common_utils import OpenAIError, drop_params_from_unprocessable_entity_error
|
@ -232,6 +239,7 @@ class OpenAIConfig(BaseConfig):
|
||||||
litellm_params: dict,
|
litellm_params: dict,
|
||||||
headers: dict,
|
headers: dict,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
messages = self._transform_messages(messages=messages, model=model)
|
||||||
return {"model": model, "messages": messages, **optional_params}
|
return {"model": model, "messages": messages, **optional_params}
|
||||||
|
|
||||||
def transform_response(
|
def transform_response(
|
||||||
|
@@ -248,10 +256,21 @@ class OpenAIConfig(BaseConfig):
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
     ) -> ModelResponse:
-        raise NotImplementedError(
-            "OpenAI handler does this transformation as it uses the OpenAI SDK."
+        logging_obj.post_call(original_response=raw_response.text)
+        logging_obj.model_call_details["response_headers"] = raw_response.headers
+        final_response_obj = cast(
+            ModelResponse,
+            convert_to_model_response_object(
+                response_object=raw_response.json(),
+                model_response_object=model_response,
+                hidden_params={"headers": raw_response.headers},
+                _response_headers=dict(raw_response.headers),
+            ),
         )

+        return final_response_obj
+
     def validate_environment(
         self,
         headers: dict,
|
@ -259,12 +278,37 @@ class OpenAIConfig(BaseConfig):
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
api_base: Optional[str] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
raise NotImplementedError(
|
return {
|
||||||
"OpenAI handler does this validation as it uses the OpenAI SDK."
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
**headers,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_model_response_iterator(
|
||||||
|
self,
|
||||||
|
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
|
||||||
|
sync_stream: bool,
|
||||||
|
json_mode: Optional[bool] = False,
|
||||||
|
) -> Any:
|
||||||
|
return OpenAIChatCompletionResponseIterator(
|
||||||
|
streaming_response=streaming_response,
|
||||||
|
sync_stream=sync_stream,
|
||||||
|
json_mode=json_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIChatCompletionResponseIterator(BaseModelResponseIterator):
|
||||||
|
def chunk_parser(self, chunk: dict) -> ModelResponseStream:
|
||||||
|
"""
|
||||||
|
{'choices': [{'delta': {'content': '', 'role': 'assistant'}, 'finish_reason': None, 'index': 0, 'logprobs': None}], 'created': 1735763082, 'id': 'a83a2b0fbfaf4aab9c2c93cb8ba346d7', 'model': 'mistral-large', 'object': 'chat.completion.chunk'}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return ModelResponseStream(**chunk)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
class OpenAIChatCompletion(BaseLLM):
|
class OpenAIChatCompletion(BaseLLM):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
@@ -473,14 +517,6 @@ class OpenAIChatCompletion(BaseLLM):
         if custom_llm_provider is not None and custom_llm_provider != "openai":
             model_response.model = f"{custom_llm_provider}/{model}"

-        if messages is not None and provider_config is not None:
-            if isinstance(provider_config, OpenAIGPTConfig) or isinstance(
-                provider_config, OpenAIConfig
-            ):  # [TODO]: remove. no longer needed as .transform_request can just handle this.
-                messages = provider_config._transform_messages(
-                    messages=messages, model=model
-                )
-
         for _ in range(
             2
         ):  # if call fails due to alternating messages, retry with reformatted message
@@ -647,12 +683,10 @@ class OpenAIChatCompletion(BaseLLM):
                    new_messages = messages
                    new_messages.append({"role": "user", "content": ""})
                    messages = new_messages
-                elif (
-                    "unknown field: parameter index is not a valid field" in str(e)
-                ) and "tools" in data:
-                    litellm.remove_index_from_tool_calls(
-                        tool_calls=data["tools"], messages=messages
-                    )
+                elif "unknown field: parameter index is not a valid field" in str(
+                    e
+                ):
+                    litellm.remove_index_from_tool_calls(messages=messages)
                else:
                    raise e
            except OpenAIError as e:
@@ -132,5 +132,6 @@ class PetalsConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         return {}
@@ -164,6 +164,7 @@ class PredibaseConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         if api_key is None:
             raise ValueError(
@@ -309,6 +309,7 @@ class ReplicateConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         headers = {
             "Authorization": f"Token {api_key}",
@@ -260,6 +260,7 @@ class SagemakerConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         headers = {"Content-Type": "application/json"}
@@ -48,6 +48,7 @@ class TritonConfig(BaseConfig):
         messages: List[AllMessageValues],
         optional_params: Dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> Dict:
         return {"Content-Type": "application/json"}
@@ -43,6 +43,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         return {}
@@ -808,6 +808,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
         messages: List[AllMessageValues],
         optional_params: Dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> Dict:
         default_headers = {
             "Content-Type": "application/json",
@@ -82,6 +82,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
         messages: List[AllMessageValues],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> dict:
         if api_key is None:
             api_key = (
@@ -166,6 +166,7 @@ class IBMWatsonXMixin:
         messages: List[AllMessageValues],
         optional_params: Dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ) -> Dict:
         default_headers = {
             "Content-Type": "application/json",
@@ -1122,6 +1122,7 @@ def completion(  # type: ignore  # noqa: PLR0915
         custom_prompt_dict=custom_prompt_dict,
         litellm_metadata=kwargs.get("litellm_metadata"),
         disable_add_transform_inline_image_block=disable_add_transform_inline_image_block,
+        drop_params=kwargs.get("drop_params"),
     )
     logging.update_environment_variables(
         model=model,
@@ -1347,39 +1348,28 @@
        if extra_headers is not None:
            optional_params["extra_headers"] = extra_headers

-        ## LOAD CONFIG - if set
-        config = litellm.AzureAIStudioConfig.get_config()
-        for k, v in config.items():
-            if (
-                k not in optional_params
-            ):  # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
-                optional_params[k] = v
-
        ## FOR COHERE
        if "command-r" in model:  # make sure tool call in messages are str
            messages = stringify_json_tool_call_content(messages=messages)

        ## COMPLETION CALL
        try:
-            response = openai_chat_completions.completion(
+            response = base_llm_http_handler.completion(
                model=model,
                messages=messages,
                headers=headers,
                model_response=model_response,
-                print_verbose=print_verbose,
                api_key=api_key,
                api_base=api_base,
                acompletion=acompletion,
                logging_obj=logging,
                optional_params=optional_params,
                litellm_params=litellm_params,
-                logger_fn=logger_fn,
                timeout=timeout,  # type: ignore
-                custom_prompt_dict=custom_prompt_dict,
                client=client,  # pass AsyncOpenAI, OpenAI client
-                organization=organization,
                custom_llm_provider=custom_llm_provider,
-                drop_params=non_default_params.get("drop_params"),
+                encoding=encoding,
+                stream=stream,
            )
        except Exception as e:
            ## LOGGING - log the original exception returned
@@ -1074,10 +1074,10 @@ class EmbeddingResponse(OpenAIObject):


 class Logprobs(OpenAIObject):
-    text_offset: List[int]
-    token_logprobs: List[Union[float, None]]
-    tokens: List[str]
-    top_logprobs: List[Union[Dict[str, float], None]]
+    text_offset: Optional[List[int]]
+    token_logprobs: Optional[List[Union[float, None]]]
+    tokens: Optional[List[str]]
+    top_logprobs: Optional[List[Union[Dict[str, float], None]]]


 class TextChoices(OpenAIObject):
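With these fields now Optional, a logprobs payload that omits the detail lists validates instead of raising; the new test_logprobs_type test at the bottom of this diff covers the all-None case. A minimal sketch:

    from litellm.types.utils import Logprobs

    # A text-completion provider may return logprobs with no detail; these fields
    # can now be explicitly None without a validation error.
    sparse = Logprobs(
        text_offset=None,
        token_logprobs=None,
        tokens=None,
        top_logprobs=None,
    )
    assert sparse.tokens is None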
@@ -2002,6 +2002,7 @@ def get_litellm_params(
     custom_prompt_dict: Optional[dict] = None,
     litellm_metadata: Optional[dict] = None,
     disable_add_transform_inline_image_block: Optional[bool] = None,
+    drop_params: Optional[bool] = None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -2035,6 +2036,7 @@ def get_litellm_params(
         "custom_prompt_dict": custom_prompt_dict,
         "litellm_metadata": litellm_metadata,
         "disable_add_transform_inline_image_block": disable_add_transform_inline_image_block,
+        "drop_params": drop_params,
     }
     return litellm_params

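Together with the completion() hunk above, a caller-supplied drop_params flag now travels through get_litellm_params to downstream handlers. A hedged usage sketch (the deployment name and the rejected parameter below are placeholders, not taken from this diff):

    import litellm

    # With drop_params=True, parameters the provider does not support are dropped
    # from the request instead of surfacing as an invalid-parameter error.
    response = litellm.completion(
        model="azure_ai/mistral-large-latest",  # placeholder deployment name
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        frequency_penalty=0.6,  # example of a param a provider might reject
        drop_params=True,
    )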
@@ -6345,3 +6347,44 @@ def extract_duration_from_srt_or_vtt(srt_or_vtt_content: str) -> Optional[float]
             durations.append(total_seconds)

     return max(durations) if durations else None
+
+
+import httpx
+
+
+def _add_path_to_api_base(api_base: str, ending_path: str) -> str:
+    """
+    Adds an ending path to an API base URL while preventing duplicate path segments.
+
+    Args:
+        api_base: Base URL string
+        ending_path: Path to append to the base URL
+
+    Returns:
+        Modified URL string with proper path handling
+    """
+    original_url = httpx.URL(api_base)
+    base_url = original_url.copy_with(params={})  # Removes query params
+    base_path = original_url.path.rstrip("/")
+    end_path = ending_path.lstrip("/")
+
+    # Split paths into segments
+    base_segments = [s for s in base_path.split("/") if s]
+    end_segments = [s for s in end_path.split("/") if s]
+
+    # Find overlapping segments from the end of base_path and start of ending_path
+    final_segments = []
+    for i in range(len(base_segments)):
+        if base_segments[i:] == end_segments[: len(base_segments) - i]:
+            final_segments = base_segments[:i] + end_segments
+            break
+    else:
+        # No overlap found, just combine all segments
+        final_segments = base_segments + end_segments
+
+    # Construct the new path
+    modified_path = "/" + "/".join(final_segments)
+    modified_url = base_url.copy_with(path=modified_path)
+
+    # Re-add the original query parameters
+    return str(modified_url.copy_with(params=original_url.params))
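The test parametrization added below exercises this helper; as a quick illustration of the intended behavior (expected outputs taken from those test cases; the helper is assumed to be in scope, since the target module is not named in this hunk):

    # Overlapping path segments are collapsed rather than duplicated.
    base = "https://litellm8397336933.services.ai.azure.com/models"
    print(_add_path_to_api_base(base, "/models/chat/completions"))
    # https://litellm8397336933.services.ai.azure.com/models/chat/completions

    # A base URL that already ends in the full path (query string included)
    # comes back unchanged.
    full = "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
    print(_add_path_to_api_base(full, "/models/chat/completions"))
    # https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview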
@@ -28,6 +28,7 @@ from unittest.mock import MagicMock, patch
 import pytest

 import litellm
+from litellm import completion


 @pytest.mark.parametrize(
@@ -51,18 +52,13 @@ async def test_azure_ai_with_image_url():

     Test that Azure AI studio can handle image_url passed when content is a list containing both text and image_url
     """
-    from openai import AsyncOpenAI
+    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

     litellm.set_verbose = True

-    client = AsyncOpenAI(
-        api_key="fake-api-key",
-        base_url="https://Phi-3-5-vision-instruct-dcvov.eastus2.models.ai.azure.com",
-    )
+    client = AsyncHTTPHandler()

-    with patch.object(
-        client.chat.completions.with_raw_response, "create"
-    ) as mock_client:
+    with patch.object(client, "post") as mock_client:
         try:
             await litellm.acompletion(
                 model="azure_ai/Phi-3-5-vision-instruct-dcvov",
@@ -94,8 +90,9 @@ async def test_azure_ai_with_image_url():
         # Verify the request was made
         mock_client.assert_called_once()

+        print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
         # Check the request body
-        request_body = mock_client.call_args.kwargs
+        request_body = json.loads(mock_client.call_args.kwargs["data"])
         assert request_body["model"] == "Phi-3-5-vision-instruct-dcvov"
         assert request_body["messages"] == [
             {
@@ -111,3 +108,79 @@ async def test_azure_ai_with_image_url():
                 ],
             }
         ]
+
+
+@pytest.mark.parametrize(
+    "api_base, expected_url",
+    [
+        (
+            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
+            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
+        ),
+        (
+            "https://litellm8397336933.services.ai.azure.com/models/chat/completions",
+            "https://litellm8397336933.services.ai.azure.com/models/chat/completions",
+        ),
+        (
+            "https://litellm8397336933.services.ai.azure.com/models",
+            "https://litellm8397336933.services.ai.azure.com/models/chat/completions",
+        ),
+        (
+            "https://litellm8397336933.services.ai.azure.com",
+            "https://litellm8397336933.services.ai.azure.com/models/chat/completions",
+        ),
+    ],
+)
+def test_azure_ai_services_handler(api_base, expected_url):
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    litellm.set_verbose = True
+
+    client = HTTPHandler()
+
+    with patch.object(client, "post") as mock_client:
+        try:
+            response = litellm.completion(
+                model="azure_ai/Meta-Llama-3.1-70B-Instruct",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                api_key="my-fake-api-key",
+                api_base=api_base,
+                client=client,
+            )
+
+            print(response)
+
+        except Exception as e:
+            print(f"Error: {e}")
+
+        mock_client.assert_called_once()
+        assert mock_client.call_args.kwargs["headers"]["api-key"] == "my-fake-api-key"
+        assert mock_client.call_args.kwargs["url"] == expected_url
+
+
+def test_completion_azure_ai_command_r():
+    try:
+        import os
+
+        litellm.set_verbose = True
+
+        os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
+        os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")
+
+        response = completion(
+            model="azure_ai/command-r-plus",
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is the meaning of life?"}
+                    ],
+                }
+            ],
+        )  # type: ignore
+
+        assert "azure_ai" in response.model
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
@@ -132,34 +132,6 @@ def test_null_role_response():
     assert response.choices[0].message.role == "assistant"


-def test_completion_azure_ai_command_r():
-    try:
-        import os
-
-        litellm.set_verbose = True
-
-        os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
-        os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")
-
-        response = completion(
-            model="azure_ai/command-r-plus",
-            messages=[
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": "What is the meaning of life?"}
-                    ],
-                }
-            ],
-        )  # type: ignore
-
-        assert "azure_ai" in response.model
-    except litellm.Timeout as e:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
 async def test_completion_azure_ai_mistral_invalid_params(sync_mode):
@@ -199,4 +199,4 @@ def test_azure_global_standard_get_llm_provider():
         api_base="https://my-deployment-francecentral.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
         api_key="fake-api-key",
     )
-    assert custom_llm_provider == "azure"
+    assert custom_llm_provider == "azure_ai"
@@ -2954,6 +2954,7 @@ def test_azure_streaming_and_function_calling():
 async def test_completion_azure_ai_mistral_invalid_params(sync_mode):
     try:
         import os
+        from litellm import stream_chunk_builder

         litellm.set_verbose = True

@@ -2968,15 +2969,21 @@ async def test_completion_azure_ai_mistral_invalid_params(sync_mode):
             "drop_params": True,
             "stream": True,
         }
+        chunks = []
         if sync_mode:
-            response: litellm.ModelResponse = completion(**data)  # type: ignore
+            response = completion(**data)  # type: ignore
             for chunk in response:
                 print(chunk)
+                chunks.append(chunk)
         else:
-            response: litellm.ModelResponse = await litellm.acompletion(**data)  # type: ignore
+            response = await litellm.acompletion(**data)  # type: ignore

             async for chunk in response:
                 print(chunk)
+                chunks.append(chunk)
+        print(f"chunks: {chunks}")
+        response = stream_chunk_builder(chunks=chunks)
+        assert response.choices[0].message.content is not None
     except litellm.Timeout as e:
         pass
     except Exception as e:
@@ -1252,3 +1252,19 @@ def test_fireworks_ai_document_inlining():

     assert supports_pdf_input("fireworks_ai/llama-3.1-8b-instruct") is True
     assert supports_vision("fireworks_ai/llama-3.1-8b-instruct") is True
+
+
+def test_logprobs_type():
+    from litellm.types.utils import Logprobs
+
+    logprobs = {
+        "text_offset": None,
+        "token_logprobs": None,
+        "tokens": None,
+        "top_logprobs": None,
+    }
+    logprobs = Logprobs(**logprobs)
+    assert logprobs.text_offset is None
+    assert logprobs.token_logprobs is None
+    assert logprobs.tokens is None
+    assert logprobs.top_logprobs is None