mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Litellm dev 12 30 2024 p2 (#7495)
* test(azure_openai_o1.py): initial commit with testing for azure openai o1 preview model * fix(base_llm_unit_tests.py): handle azure o1 preview response format tests skip as o1 on azure doesn't support tool calling yet * fix: initial commit of azure o1 handler using openai caller simplifies calling + allows fake streaming logic alr. implemented for openai to just work * feat(azure/o1_handler.py): fake o1 streaming for azure o1 models azure does not currently support streaming for o1 * feat(o1_transformation.py): support overriding 'should_fake_stream' on azure/o1 via 'supports_native_streaming' param on model info enables user to toggle on when azure allows o1 streaming without needing to bump versions * style(router.py): remove 'give feedback/get help' messaging when router is used Prevents noisy messaging Closes https://github.com/BerriAI/litellm/issues/5942 * fix(types/utils.py): handle none logprobs Fixes https://github.com/BerriAI/litellm/issues/328 * fix(exception_mapping_utils.py): fix error str unbound error * refactor(azure_ai/): move to openai_like chat completion handler allows for easy swapping of api base url's (e.g. ai.services.com) Fixes https://github.com/BerriAI/litellm/issues/7275 * refactor(azure_ai/): move to base llm http handler * fix(azure_ai/): handle differing api endpoints * fix(azure_ai/): make sure all unit tests are passing * fix: fix linting errors * fix: fix linting errors * fix: fix linting error * fix: fix linting errors * fix(azure_ai/transformation.py): handle extra body param * fix(azure_ai/transformation.py): fix max retries param handling * fix: fix test * test(test_azure_o1.py): fix test * fix(llm_http_handler.py): support handling azure ai unprocessable entity error * fix(llm_http_handler.py): handle sync invalid param error for azure ai * fix(azure_ai/): streaming support with base_llm_http_handler * fix(llm_http_handler.py): working sync stream calls with unprocessable entity handling for azure ai * fix: fix linting errors * fix(llm_http_handler.py): fix linting error * fix(azure_ai/): handle cohere tool call invalid index param error
This commit is contained in:
parent
0f1b298fe0
commit
0120176541
42 changed files with 638 additions and 192 deletions
|
@ -1,10 +1,11 @@
|
|||
# What is this?
|
||||
## Helper utilities
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
@ -53,17 +54,18 @@ def map_finish_reason(
|
|||
return finish_reason
|
||||
|
||||
|
||||
def remove_index_from_tool_calls(messages, tool_calls):
|
||||
for tool_call in tool_calls:
|
||||
if "index" in tool_call:
|
||||
tool_call.pop("index")
|
||||
|
||||
def remove_index_from_tool_calls(
|
||||
messages: Optional[List[AllMessageValues]],
|
||||
):
|
||||
if messages is not None:
|
||||
for message in messages:
|
||||
if "tool_calls" in message:
|
||||
tool_calls = message["tool_calls"]
|
||||
for tool_call in tool_calls:
|
||||
if "index" in tool_call:
|
||||
tool_call.pop("index")
|
||||
_tool_calls = message.get("tool_calls")
|
||||
if _tool_calls is not None and isinstance(_tool_calls, list):
|
||||
for tool_call in _tool_calls:
|
||||
if (
|
||||
isinstance(tool_call, dict) and "index" in tool_call
|
||||
): # Type guard to ensure it's a dict
|
||||
tool_call.pop("index", None)
|
||||
|
||||
return
|
||||
|
||||
|
|
|
@ -148,11 +148,10 @@ def exception_type( # type: ignore # noqa: PLR0915
|
|||
original_exception=original_exception
|
||||
)
|
||||
try:
|
||||
error_str = str(original_exception)
|
||||
if model:
|
||||
if hasattr(original_exception, "message"):
|
||||
error_str = str(original_exception.message)
|
||||
else:
|
||||
error_str = str(original_exception)
|
||||
if isinstance(original_exception, BaseException):
|
||||
exception_type = type(original_exception).__name__
|
||||
else:
|
||||
|
|
|
@ -741,6 +741,7 @@ class AnthropicConfig(BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict:
|
||||
if api_key is None:
|
||||
raise litellm.AuthenticationError(
|
||||
|
|
|
@ -85,6 +85,7 @@ class AnthropicTextConfig(BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
|
|
|
@ -283,6 +283,7 @@ class AzureOpenAIConfig(BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
raise NotImplementedError(
|
||||
"Azure OpenAI has custom logic for validating environment, as it uses the OpenAI SDK."
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
from typing import List, Optional, Tuple
|
||||
from typing import Any, List, Optional, Tuple, cast
|
||||
|
||||
import httpx
|
||||
from httpx import Response
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
|
@ -6,13 +9,81 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
|||
_audio_or_image_in_message_content,
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
|
||||
from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error
|
||||
from litellm.llms.openai.openai import OpenAIConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ProviderField
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam
|
||||
from litellm.types.utils import ModelResponse, ProviderField
|
||||
from litellm.utils import _add_path_to_api_base
|
||||
|
||||
|
||||
class AzureAIStudioConfig(OpenAIConfig):
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
if api_base and "services.ai.azure.com" in api_base:
|
||||
headers["api-key"] = api_key
|
||||
else:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: str,
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Constructs a complete URL for the API request.
|
||||
|
||||
Args:
|
||||
- api_base: Base URL, e.g.,
|
||||
"https://litellm8397336933.services.ai.azure.com"
|
||||
OR
|
||||
"https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
|
||||
- model: Model name.
|
||||
- optional_params: Additional query parameters, including "api_version".
|
||||
- stream: If streaming is required (optional).
|
||||
|
||||
Returns:
|
||||
- A complete URL string, e.g.,
|
||||
"https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
|
||||
"""
|
||||
original_url = httpx.URL(api_base)
|
||||
|
||||
# Extract api_version or use default
|
||||
api_version = cast(Optional[str], optional_params.get("api_version"))
|
||||
|
||||
# Check if 'api-version' is already present
|
||||
if "api-version" not in original_url.params and api_version:
|
||||
# Add api_version to optional_params
|
||||
original_url.params["api-version"] = api_version
|
||||
|
||||
# Add the path to the base URL
|
||||
if "services.ai.azure.com" in api_base:
|
||||
new_url = _add_path_to_api_base(
|
||||
api_base=api_base, ending_path="/models/chat/completions"
|
||||
)
|
||||
else:
|
||||
new_url = _add_path_to_api_base(
|
||||
api_base=api_base, ending_path="/chat/completions"
|
||||
)
|
||||
|
||||
# Convert optional_params to query parameters
|
||||
query_params = original_url.params
|
||||
final_url = httpx.URL(new_url).copy_with(params=query_params)
|
||||
|
||||
return str(final_url)
|
||||
|
||||
def get_required_params(self) -> List[ProviderField]:
|
||||
"""For a given provider, return it's required fields with a description"""
|
||||
return [
|
||||
|
@ -62,8 +133,6 @@ class AzureAIStudioConfig(OpenAIConfig):
|
|||
):
|
||||
return True
|
||||
|
||||
if api_base and "services.ai.azure" in api_base:
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
return False
|
||||
|
@ -86,3 +155,81 @@ class AzureAIStudioConfig(OpenAIConfig):
|
|||
)
|
||||
custom_llm_provider = "azure"
|
||||
return api_base, dynamic_api_key, custom_llm_provider
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
extra_body = optional_params.pop("extra_body", {})
|
||||
if extra_body and isinstance(extra_body, dict):
|
||||
optional_params.update(extra_body)
|
||||
optional_params.pop("max_retries", None)
|
||||
return super().transform_request(
|
||||
model, messages, optional_params, litellm_params, headers
|
||||
)
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
model_response.model = f"azure_ai/{model}"
|
||||
return super().transform_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=request_data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
def should_retry_llm_api_inside_llm_translation_on_http_error(
|
||||
self, e: httpx.HTTPStatusError, litellm_params: dict
|
||||
) -> bool:
|
||||
should_drop_params = litellm_params.get("drop_params") or litellm.drop_params
|
||||
error_text = e.response.text
|
||||
if should_drop_params and "Extra inputs are not permitted" in error_text:
|
||||
return True
|
||||
elif (
|
||||
"unknown field: parameter index is not a valid field" in error_text
|
||||
): # remove index from tool calls
|
||||
return True
|
||||
return super().should_retry_llm_api_inside_llm_translation_on_http_error(
|
||||
e=e, litellm_params=litellm_params
|
||||
)
|
||||
|
||||
@property
|
||||
def max_retry_on_unprocessable_entity_error(self) -> int:
|
||||
return 2
|
||||
|
||||
def transform_request_on_unprocessable_entity_error(
|
||||
self, e: httpx.HTTPStatusError, request_data: dict
|
||||
) -> dict:
|
||||
_messages = cast(Optional[List[AllMessageValues]], request_data.get("messages"))
|
||||
if (
|
||||
"unknown field: parameter index is not a valid field" in e.response.text
|
||||
and _messages is not None
|
||||
):
|
||||
litellm.remove_index_from_tool_calls(
|
||||
messages=_messages,
|
||||
)
|
||||
data = drop_params_from_unprocessable_entity_error(e=e, data=request_data)
|
||||
return data
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import json
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from litellm.types.utils import GenericStreamingChunk
|
||||
from litellm.types.utils import GenericStreamingChunk, ModelResponseStream
|
||||
|
||||
|
||||
class BaseModelResponseIterator:
|
||||
|
@ -13,7 +13,9 @@ class BaseModelResponseIterator:
|
|||
self.response_iterator = self.streaming_response
|
||||
self.json_mode = json_mode
|
||||
|
||||
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||
def chunk_parser(
|
||||
self, chunk: dict
|
||||
) -> Union[GenericStreamingChunk, ModelResponseStream]:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
|
@ -27,7 +29,9 @@ class BaseModelResponseIterator:
|
|||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def _handle_string_chunk(self, str_line: str) -> GenericStreamingChunk:
|
||||
def _handle_string_chunk(
|
||||
self, str_line: str
|
||||
) -> Union[GenericStreamingChunk, ModelResponseStream]:
|
||||
# chunk is a str at this point
|
||||
if "[DONE]" in str_line:
|
||||
return GenericStreamingChunk(
|
||||
|
|
|
@ -82,6 +82,33 @@ class BaseConfig(ABC):
|
|||
"""
|
||||
return False
|
||||
|
||||
def should_retry_llm_api_inside_llm_translation_on_http_error(
|
||||
self, e: httpx.HTTPStatusError, litellm_params: dict
|
||||
) -> bool:
|
||||
"""
|
||||
Returns True if the model/provider should retry the LLM API on UnprocessableEntityError
|
||||
|
||||
Overriden by azure ai - where different models support different parameters
|
||||
"""
|
||||
return False
|
||||
|
||||
def transform_request_on_unprocessable_entity_error(
|
||||
self, e: httpx.HTTPStatusError, request_data: dict
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the request data on UnprocessableEntityError
|
||||
"""
|
||||
return request_data
|
||||
|
||||
@property
|
||||
def max_retry_on_unprocessable_entity_error(self) -> int:
|
||||
"""
|
||||
Returns the max retry count for UnprocessableEntityError
|
||||
|
||||
Used if `should_retry_llm_api_inside_llm_translation_on_http_error` is True
|
||||
"""
|
||||
return 0
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
pass
|
||||
|
@ -104,6 +131,7 @@ class BaseConfig(ABC):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
pass
|
||||
|
||||
|
|
|
@ -115,6 +115,7 @@ class AmazonInvokeMixin:
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
raise NotImplementedError(
|
||||
"validate_environment not implemented for config. Done in invoke_handler.py"
|
||||
|
|
|
@ -119,6 +119,7 @@ class ClarifaiConfig(BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
|
|
|
@ -60,6 +60,7 @@ class CloudflareChatConfig(BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
|
|
|
@ -116,6 +116,7 @@ class CohereChatConfig(BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
return cohere_validate_environment(
|
||||
headers=headers,
|
||||
|
|
|
@ -102,6 +102,7 @@ class CohereTextConfig(BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
return cohere_validate_environment(
|
||||
headers=headers,
|
||||
|
|
|
@ -8,7 +8,7 @@ import litellm
|
|||
import litellm.litellm_core_utils
|
||||
import litellm.types
|
||||
import litellm.types.utils
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
|
||||
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
|
@ -30,6 +30,114 @@ else:
|
|||
|
||||
|
||||
class BaseLLMHTTPHandler:
|
||||
|
||||
async def _make_common_async_call(
|
||||
self,
|
||||
async_httpx_client: AsyncHTTPHandler,
|
||||
provider_config: BaseConfig,
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: dict,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
litellm_params: dict,
|
||||
stream: bool = False,
|
||||
) -> httpx.Response:
|
||||
"""Common implementation across stream + non-stream calls. Meant to ensure consistent error-handling."""
|
||||
max_retry_on_unprocessable_entity_error = (
|
||||
provider_config.max_retry_on_unprocessable_entity_error
|
||||
)
|
||||
|
||||
response: Optional[httpx.Response] = None
|
||||
for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
|
||||
try:
|
||||
response = await async_httpx_client.post(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
timeout=timeout,
|
||||
stream=stream,
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
|
||||
should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
|
||||
e=e, litellm_params=litellm_params
|
||||
)
|
||||
if should_retry and not hit_max_retry:
|
||||
data = (
|
||||
provider_config.transform_request_on_unprocessable_entity_error(
|
||||
e=e, request_data=data
|
||||
)
|
||||
)
|
||||
continue
|
||||
else:
|
||||
raise self._handle_error(e=e, provider_config=provider_config)
|
||||
except Exception as e:
|
||||
raise self._handle_error(e=e, provider_config=provider_config)
|
||||
break
|
||||
|
||||
if response is None:
|
||||
raise provider_config.get_error_class(
|
||||
error_message="No response from the API",
|
||||
status_code=422, # don't retry on this error
|
||||
headers={},
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _make_common_sync_call(
|
||||
self,
|
||||
sync_httpx_client: HTTPHandler,
|
||||
provider_config: BaseConfig,
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: dict,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
litellm_params: dict,
|
||||
stream: bool = False,
|
||||
) -> httpx.Response:
|
||||
|
||||
max_retry_on_unprocessable_entity_error = (
|
||||
provider_config.max_retry_on_unprocessable_entity_error
|
||||
)
|
||||
|
||||
response: Optional[httpx.Response] = None
|
||||
|
||||
for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
|
||||
try:
|
||||
response = sync_httpx_client.post(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
timeout=timeout,
|
||||
stream=stream,
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
|
||||
should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
|
||||
e=e, litellm_params=litellm_params
|
||||
)
|
||||
if should_retry and not hit_max_retry:
|
||||
data = (
|
||||
provider_config.transform_request_on_unprocessable_entity_error(
|
||||
e=e, request_data=data
|
||||
)
|
||||
)
|
||||
continue
|
||||
else:
|
||||
raise self._handle_error(e=e, provider_config=provider_config)
|
||||
except Exception as e:
|
||||
raise self._handle_error(e=e, provider_config=provider_config)
|
||||
break
|
||||
|
||||
if response is None:
|
||||
raise provider_config.get_error_class(
|
||||
error_message="No response from the API",
|
||||
status_code=422, # don't retry on this error
|
||||
headers={},
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def async_completion(
|
||||
self,
|
||||
custom_llm_provider: str,
|
||||
|
@ -55,15 +163,16 @@ class BaseLLMHTTPHandler:
|
|||
else:
|
||||
async_httpx_client = client
|
||||
|
||||
try:
|
||||
response = await async_httpx_client.post(
|
||||
url=api_base,
|
||||
response = await self._make_common_async_call(
|
||||
async_httpx_client=async_httpx_client,
|
||||
provider_config=provider_config,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
data=data,
|
||||
timeout=timeout,
|
||||
litellm_params=litellm_params,
|
||||
stream=False,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._handle_error(e=e, provider_config=provider_config)
|
||||
return provider_config.transform_response(
|
||||
model=model,
|
||||
raw_response=response,
|
||||
|
@ -93,7 +202,7 @@ class BaseLLMHTTPHandler:
|
|||
stream: Optional[bool] = False,
|
||||
fake_stream: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
headers={},
|
||||
headers: Optional[dict] = {},
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
):
|
||||
provider_config = ProviderConfigManager.get_provider_chat_config(
|
||||
|
@ -102,10 +211,11 @@ class BaseLLMHTTPHandler:
|
|||
# get config from model, custom llm provider
|
||||
headers = provider_config.validate_environment(
|
||||
api_key=api_key,
|
||||
headers=headers,
|
||||
headers=headers or {},
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
api_base = provider_config.get_complete_url(
|
||||
|
@ -154,6 +264,7 @@ class BaseLLMHTTPHandler:
|
|||
if client is not None and isinstance(client, AsyncHTTPHandler)
|
||||
else None
|
||||
),
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
else:
|
||||
|
@ -186,7 +297,7 @@ class BaseLLMHTTPHandler:
|
|||
provider_config=provider_config,
|
||||
api_base=api_base,
|
||||
headers=headers, # type: ignore
|
||||
data=json.dumps(data),
|
||||
data=data,
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
|
@ -197,6 +308,7 @@ class BaseLLMHTTPHandler:
|
|||
if client is not None and isinstance(client, HTTPHandler)
|
||||
else None
|
||||
),
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
|
@ -210,19 +322,15 @@ class BaseLLMHTTPHandler:
|
|||
else:
|
||||
sync_httpx_client = client
|
||||
|
||||
try:
|
||||
response = sync_httpx_client.post(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
response = self._make_common_sync_call(
|
||||
sync_httpx_client=sync_httpx_client,
|
||||
provider_config=provider_config,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
data=data,
|
||||
timeout=timeout,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
return provider_config.transform_response(
|
||||
model=model,
|
||||
raw_response=response,
|
||||
|
@ -241,43 +349,32 @@ class BaseLLMHTTPHandler:
|
|||
provider_config: BaseConfig,
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
data: dict,
|
||||
model: str,
|
||||
messages: list,
|
||||
logging_obj,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
litellm_params: dict,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
fake_stream: bool = False,
|
||||
client: Optional[HTTPHandler] = None,
|
||||
) -> Tuple[Any, httpx.Headers]:
|
||||
) -> Tuple[Any, dict]:
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
sync_httpx_client = _get_httpx_client()
|
||||
else:
|
||||
sync_httpx_client = client
|
||||
try:
|
||||
stream = True
|
||||
if fake_stream is True:
|
||||
stream = False
|
||||
response = sync_httpx_client.post(
|
||||
api_base, headers=headers, data=data, timeout=timeout, stream=stream
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
provider_config=provider_config,
|
||||
)
|
||||
except Exception as e:
|
||||
for exception in litellm.LITELLM_EXCEPTION_TYPES:
|
||||
if isinstance(e, exception):
|
||||
raise e
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
provider_config=provider_config,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BaseLLMException(
|
||||
status_code=response.status_code,
|
||||
message=str(response.read()),
|
||||
response = self._make_common_sync_call(
|
||||
sync_httpx_client=sync_httpx_client,
|
||||
provider_config=provider_config,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
data=data,
|
||||
timeout=timeout,
|
||||
litellm_params=litellm_params,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if fake_stream is True:
|
||||
|
@ -297,7 +394,7 @@ class BaseLLMHTTPHandler:
|
|||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return completion_stream, response.headers
|
||||
return completion_stream, dict(response.headers)
|
||||
|
||||
async def acompletion_stream_function(
|
||||
self,
|
||||
|
@ -310,6 +407,7 @@ class BaseLLMHTTPHandler:
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
data: dict,
|
||||
litellm_params: dict,
|
||||
fake_stream: bool = False,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
):
|
||||
|
@ -318,12 +416,13 @@ class BaseLLMHTTPHandler:
|
|||
provider_config=provider_config,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
data=data,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
timeout=timeout,
|
||||
fake_stream=fake_stream,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
|
@ -339,10 +438,11 @@ class BaseLLMHTTPHandler:
|
|||
provider_config: BaseConfig,
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
data: dict,
|
||||
messages: list,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
litellm_params: dict,
|
||||
fake_stream: bool = False,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> Tuple[Any, httpx.Headers]:
|
||||
|
@ -355,29 +455,18 @@ class BaseLLMHTTPHandler:
|
|||
stream = True
|
||||
if fake_stream is True:
|
||||
stream = False
|
||||
try:
|
||||
response = await async_httpx_client.post(
|
||||
api_base, headers=headers, data=data, stream=stream, timeout=timeout
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
provider_config=provider_config,
|
||||
)
|
||||
except Exception as e:
|
||||
for exception in litellm.LITELLM_EXCEPTION_TYPES:
|
||||
if isinstance(e, exception):
|
||||
raise e
|
||||
raise self._handle_error(
|
||||
e=e,
|
||||
|
||||
response = await self._make_common_async_call(
|
||||
async_httpx_client=async_httpx_client,
|
||||
provider_config=provider_config,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
data=data,
|
||||
timeout=timeout,
|
||||
litellm_params=litellm_params,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BaseLLMException(
|
||||
status_code=response.status_code,
|
||||
message=str(response.read()),
|
||||
)
|
||||
if fake_stream is True:
|
||||
completion_stream = provider_config.get_model_response_iterator(
|
||||
streaming_response=response.json(), sync_stream=False
|
||||
|
|
|
@ -118,6 +118,7 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
api_key = api_key or get_secret_str("DEEPGRAM_API_KEY")
|
||||
return {
|
||||
|
|
|
@ -42,6 +42,7 @@ class FireworksAIMixin:
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
api_key = self._get_api_key(api_key)
|
||||
if api_key is None:
|
||||
|
|
|
@ -724,12 +724,14 @@ class Huggingface(BaseLLM):
|
|||
token_logprob = token["logprob"]
|
||||
|
||||
# Add the token information to the 'token_info' list
|
||||
_logprob.tokens.append(token_text)
|
||||
_logprob.token_logprobs.append(token_logprob)
|
||||
cast(List[str], _logprob.tokens).append(token_text)
|
||||
cast(List[float], _logprob.token_logprobs).append(token_logprob)
|
||||
|
||||
# stub this to work with llm eval harness
|
||||
top_alt_tokens = {"": -1.0, "": -2.0, "": -3.0} # noqa: F601
|
||||
_logprob.top_logprobs.append(top_alt_tokens)
|
||||
cast(List[Dict[str, float]], _logprob.top_logprobs).append(
|
||||
top_alt_tokens
|
||||
)
|
||||
|
||||
# For each element in the 'tokens' list, extract the relevant information
|
||||
for i, token in enumerate(response_details["tokens"]):
|
||||
|
@ -751,13 +753,15 @@ class Huggingface(BaseLLM):
|
|||
top_alt_tokens[text] = logprob
|
||||
|
||||
# Add the token information to the 'token_info' list
|
||||
_logprob.tokens.append(token_text)
|
||||
_logprob.token_logprobs.append(token_logprob)
|
||||
_logprob.top_logprobs.append(top_alt_tokens)
|
||||
cast(List[str], _logprob.tokens).append(token_text)
|
||||
cast(List[float], _logprob.token_logprobs).append(token_logprob)
|
||||
cast(List[Dict[str, float]], _logprob.top_logprobs).append(
|
||||
top_alt_tokens
|
||||
)
|
||||
|
||||
# Add the text offset of the token
|
||||
# This is computed as the sum of the lengths of all previous tokens
|
||||
_logprob.text_offset.append(
|
||||
cast(List[int], _logprob.text_offset).append(
|
||||
sum(len(t["text"]) for t in response_details["tokens"][:i])
|
||||
)
|
||||
|
||||
|
|
|
@ -356,6 +356,7 @@ class HuggingfaceChatConfig(BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict:
|
||||
default_headers = {
|
||||
"content-type": "application/json",
|
||||
|
|
|
@ -94,6 +94,7 @@ class NLPCloudConfig(BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
|
|
|
@ -347,6 +347,7 @@ class OllamaConfig(BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
return headers
|
||||
|
||||
|
|
|
@ -89,6 +89,7 @@ class OobaboogaConfig(OpenAIGPTConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
|
|
|
@ -181,6 +181,7 @@ class OpenAIGPTConfig(BaseConfig):
|
|||
Returns:
|
||||
dict: The transformed request. Sent as the body of the API call.
|
||||
"""
|
||||
messages = self._transform_messages(messages=messages, model=model)
|
||||
return {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
|
@ -225,5 +226,6 @@ class OpenAIGPTConfig(BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -45,7 +45,8 @@ class OpenAIError(BaseLLMException):
|
|||
####### Error Handling Utils for OpenAI API #######################
|
||||
###################################################################
|
||||
def drop_params_from_unprocessable_entity_error(
|
||||
e: openai.UnprocessableEntityError, data: Dict[str, Any]
|
||||
e: Union[openai.UnprocessableEntityError, httpx.HTTPStatusError],
|
||||
data: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Helper function to read OpenAI UnprocessableEntityError and drop the params that raised an error from the error message.
|
||||
|
@ -58,14 +59,25 @@ def drop_params_from_unprocessable_entity_error(
|
|||
Dict[str, Any]: A new dictionary with invalid parameters removed
|
||||
"""
|
||||
invalid_params: List[str] = []
|
||||
if e.body is not None and isinstance(e.body, dict) and e.body.get("message"):
|
||||
message = e.body.get("message", {})
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
error_json = e.response.json()
|
||||
error_message = error_json.get("error", {})
|
||||
error_body = error_message
|
||||
else:
|
||||
error_body = e.body
|
||||
if (
|
||||
error_body is not None
|
||||
and isinstance(error_body, dict)
|
||||
and error_body.get("message")
|
||||
):
|
||||
message = error_body.get("message", {})
|
||||
if isinstance(message, str):
|
||||
try:
|
||||
message = json.loads(message)
|
||||
except json.JSONDecodeError:
|
||||
message = {"detail": message}
|
||||
detail = message.get("detail")
|
||||
|
||||
if isinstance(detail, List) and len(detail) > 0 and isinstance(detail[0], dict):
|
||||
for error_dict in detail:
|
||||
if (
|
||||
|
@ -76,4 +88,5 @@ def drop_params_from_unprocessable_entity_error(
|
|||
invalid_params.append(error_dict["loc"][1])
|
||||
|
||||
new_data = {k: v for k, v in data.items() if k not in invalid_params}
|
||||
|
||||
return new_data
|
||||
|
|
|
@ -2,9 +2,11 @@ import hashlib
|
|||
import types
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
|
@ -24,10 +26,16 @@ import litellm
|
|||
from litellm import LlmProviders
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
|
||||
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
|
||||
from litellm.types.utils import EmbeddingResponse, ImageResponse, ModelResponse
|
||||
from litellm.types.utils import (
|
||||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
)
|
||||
from litellm.utils import (
|
||||
CustomStreamWrapper,
|
||||
ProviderConfigManager,
|
||||
|
@ -36,7 +44,6 @@ from litellm.utils import (
|
|||
|
||||
from ...types.llms.openai import *
|
||||
from ..base import BaseLLM
|
||||
from .chat.gpt_transformation import OpenAIGPTConfig
|
||||
from .common_utils import OpenAIError, drop_params_from_unprocessable_entity_error
|
||||
|
||||
|
||||
|
@ -232,6 +239,7 @@ class OpenAIConfig(BaseConfig):
|
|||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
messages = self._transform_messages(messages=messages, model=model)
|
||||
return {"model": model, "messages": messages, **optional_params}
|
||||
|
||||
def transform_response(
|
||||
|
@ -248,10 +256,21 @@ class OpenAIConfig(BaseConfig):
|
|||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
raise NotImplementedError(
|
||||
"OpenAI handler does this transformation as it uses the OpenAI SDK."
|
||||
|
||||
logging_obj.post_call(original_response=raw_response.text)
|
||||
logging_obj.model_call_details["response_headers"] = raw_response.headers
|
||||
final_response_obj = cast(
|
||||
ModelResponse,
|
||||
convert_to_model_response_object(
|
||||
response_object=raw_response.json(),
|
||||
model_response_object=model_response,
|
||||
hidden_params={"headers": raw_response.headers},
|
||||
_response_headers=dict(raw_response.headers),
|
||||
),
|
||||
)
|
||||
|
||||
return final_response_obj
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
|
@ -259,12 +278,37 @@ class OpenAIConfig(BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
raise NotImplementedError(
|
||||
"OpenAI handler does this validation as it uses the OpenAI SDK."
|
||||
return {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
**headers,
|
||||
}
|
||||
|
||||
def get_model_response_iterator(
|
||||
self,
|
||||
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
|
||||
sync_stream: bool,
|
||||
json_mode: Optional[bool] = False,
|
||||
) -> Any:
|
||||
return OpenAIChatCompletionResponseIterator(
|
||||
streaming_response=streaming_response,
|
||||
sync_stream=sync_stream,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIChatCompletionResponseIterator(BaseModelResponseIterator):
|
||||
def chunk_parser(self, chunk: dict) -> ModelResponseStream:
|
||||
"""
|
||||
{'choices': [{'delta': {'content': '', 'role': 'assistant'}, 'finish_reason': None, 'index': 0, 'logprobs': None}], 'created': 1735763082, 'id': 'a83a2b0fbfaf4aab9c2c93cb8ba346d7', 'model': 'mistral-large', 'object': 'chat.completion.chunk'}
|
||||
"""
|
||||
try:
|
||||
return ModelResponseStream(**chunk)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
class OpenAIChatCompletion(BaseLLM):
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
@ -473,14 +517,6 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
if custom_llm_provider is not None and custom_llm_provider != "openai":
|
||||
model_response.model = f"{custom_llm_provider}/{model}"
|
||||
|
||||
if messages is not None and provider_config is not None:
|
||||
if isinstance(provider_config, OpenAIGPTConfig) or isinstance(
|
||||
provider_config, OpenAIConfig
|
||||
): # [TODO]: remove. no longer needed as .transform_request can just handle this.
|
||||
messages = provider_config._transform_messages(
|
||||
messages=messages, model=model
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
2
|
||||
): # if call fails due to alternating messages, retry with reformatted message
|
||||
|
@ -647,12 +683,10 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
new_messages = messages
|
||||
new_messages.append({"role": "user", "content": ""})
|
||||
messages = new_messages
|
||||
elif (
|
||||
"unknown field: parameter index is not a valid field" in str(e)
|
||||
) and "tools" in data:
|
||||
litellm.remove_index_from_tool_calls(
|
||||
tool_calls=data["tools"], messages=messages
|
||||
)
|
||||
elif "unknown field: parameter index is not a valid field" in str(
|
||||
e
|
||||
):
|
||||
litellm.remove_index_from_tool_calls(messages=messages)
|
||||
else:
|
||||
raise e
|
||||
except OpenAIError as e:
|
||||
|
|
|
@ -132,5 +132,6 @@ class PetalsConfig(BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
return {}
|
||||
|
|
|
@ -164,6 +164,7 @@ class PredibaseConfig(BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
|
|
|
@ -309,6 +309,7 @@ class ReplicateConfig(BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
headers = {
|
||||
"Authorization": f"Token {api_key}",
|
||||
|
|
|
@ -260,6 +260,7 @@ class SagemakerConfig(BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
|
|
|
@ -48,6 +48,7 @@ class TritonConfig(BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict:
|
||||
return {"Content-Type": "application/json"}
|
||||
|
||||
|
|
|
@ -43,6 +43,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
return {}
|
||||
|
||||
|
|
|
@ -808,6 +808,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict:
|
||||
default_headers = {
|
||||
"Content-Type": "application/json",
|
||||
|
|
|
@ -82,6 +82,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
if api_key is None:
|
||||
api_key = (
|
||||
|
|
|
@ -166,6 +166,7 @@ class IBMWatsonXMixin:
|
|||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict:
|
||||
default_headers = {
|
||||
"Content-Type": "application/json",
|
||||
|
|
|
@ -1122,6 +1122,7 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
custom_prompt_dict=custom_prompt_dict,
|
||||
litellm_metadata=kwargs.get("litellm_metadata"),
|
||||
disable_add_transform_inline_image_block=disable_add_transform_inline_image_block,
|
||||
drop_params=kwargs.get("drop_params"),
|
||||
)
|
||||
logging.update_environment_variables(
|
||||
model=model,
|
||||
|
@ -1347,39 +1348,28 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
if extra_headers is not None:
|
||||
optional_params["extra_headers"] = extra_headers
|
||||
|
||||
## LOAD CONFIG - if set
|
||||
config = litellm.AzureAIStudioConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
## FOR COHERE
|
||||
if "command-r" in model: # make sure tool call in messages are str
|
||||
messages = stringify_json_tool_call_content(messages=messages)
|
||||
|
||||
## COMPLETION CALL
|
||||
try:
|
||||
response = openai_chat_completions.completion(
|
||||
response = base_llm_http_handler.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
headers=headers,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
acompletion=acompletion,
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
timeout=timeout, # type: ignore
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
client=client, # pass AsyncOpenAI, OpenAI client
|
||||
organization=organization,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
drop_params=non_default_params.get("drop_params"),
|
||||
encoding=encoding,
|
||||
stream=stream,
|
||||
)
|
||||
except Exception as e:
|
||||
## LOGGING - log the original exception returned
|
||||
|
|
|
@ -1074,10 +1074,10 @@ class EmbeddingResponse(OpenAIObject):
|
|||
|
||||
|
||||
class Logprobs(OpenAIObject):
|
||||
text_offset: List[int]
|
||||
token_logprobs: List[Union[float, None]]
|
||||
tokens: List[str]
|
||||
top_logprobs: List[Union[Dict[str, float], None]]
|
||||
text_offset: Optional[List[int]]
|
||||
token_logprobs: Optional[List[Union[float, None]]]
|
||||
tokens: Optional[List[str]]
|
||||
top_logprobs: Optional[List[Union[Dict[str, float], None]]]
|
||||
|
||||
|
||||
class TextChoices(OpenAIObject):
|
||||
|
|
|
@ -2002,6 +2002,7 @@ def get_litellm_params(
|
|||
custom_prompt_dict: Optional[dict] = None,
|
||||
litellm_metadata: Optional[dict] = None,
|
||||
disable_add_transform_inline_image_block: Optional[bool] = None,
|
||||
drop_params: Optional[bool] = None,
|
||||
):
|
||||
litellm_params = {
|
||||
"acompletion": acompletion,
|
||||
|
@ -2035,6 +2036,7 @@ def get_litellm_params(
|
|||
"custom_prompt_dict": custom_prompt_dict,
|
||||
"litellm_metadata": litellm_metadata,
|
||||
"disable_add_transform_inline_image_block": disable_add_transform_inline_image_block,
|
||||
"drop_params": drop_params,
|
||||
}
|
||||
return litellm_params
|
||||
|
||||
|
@ -6345,3 +6347,44 @@ def extract_duration_from_srt_or_vtt(srt_or_vtt_content: str) -> Optional[float]
|
|||
durations.append(total_seconds)
|
||||
|
||||
return max(durations) if durations else None
|
||||
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
def _add_path_to_api_base(api_base: str, ending_path: str) -> str:
|
||||
"""
|
||||
Adds an ending path to an API base URL while preventing duplicate path segments.
|
||||
|
||||
Args:
|
||||
api_base: Base URL string
|
||||
ending_path: Path to append to the base URL
|
||||
|
||||
Returns:
|
||||
Modified URL string with proper path handling
|
||||
"""
|
||||
original_url = httpx.URL(api_base)
|
||||
base_url = original_url.copy_with(params={}) # Removes query params
|
||||
base_path = original_url.path.rstrip("/")
|
||||
end_path = ending_path.lstrip("/")
|
||||
|
||||
# Split paths into segments
|
||||
base_segments = [s for s in base_path.split("/") if s]
|
||||
end_segments = [s for s in end_path.split("/") if s]
|
||||
|
||||
# Find overlapping segments from the end of base_path and start of ending_path
|
||||
final_segments = []
|
||||
for i in range(len(base_segments)):
|
||||
if base_segments[i:] == end_segments[: len(base_segments) - i]:
|
||||
final_segments = base_segments[:i] + end_segments
|
||||
break
|
||||
else:
|
||||
# No overlap found, just combine all segments
|
||||
final_segments = base_segments + end_segments
|
||||
|
||||
# Construct the new path
|
||||
modified_path = "/" + "/".join(final_segments)
|
||||
modified_url = base_url.copy_with(path=modified_path)
|
||||
|
||||
# Re-add the original query parameters
|
||||
return str(modified_url.copy_with(params=original_url.params))
|
||||
|
|
|
@ -28,6 +28,7 @@ from unittest.mock import MagicMock, patch
|
|||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm import completion
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -51,18 +52,13 @@ async def test_azure_ai_with_image_url():
|
|||
|
||||
Test that Azure AI studio can handle image_url passed when content is a list containing both text and image_url
|
||||
"""
|
||||
from openai import AsyncOpenAI
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
client = AsyncOpenAI(
|
||||
api_key="fake-api-key",
|
||||
base_url="https://Phi-3-5-vision-instruct-dcvov.eastus2.models.ai.azure.com",
|
||||
)
|
||||
client = AsyncHTTPHandler()
|
||||
|
||||
with patch.object(
|
||||
client.chat.completions.with_raw_response, "create"
|
||||
) as mock_client:
|
||||
with patch.object(client, "post") as mock_client:
|
||||
try:
|
||||
await litellm.acompletion(
|
||||
model="azure_ai/Phi-3-5-vision-instruct-dcvov",
|
||||
|
@ -94,8 +90,9 @@ async def test_azure_ai_with_image_url():
|
|||
# Verify the request was made
|
||||
mock_client.assert_called_once()
|
||||
|
||||
print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
|
||||
# Check the request body
|
||||
request_body = mock_client.call_args.kwargs
|
||||
request_body = json.loads(mock_client.call_args.kwargs["data"])
|
||||
assert request_body["model"] == "Phi-3-5-vision-instruct-dcvov"
|
||||
assert request_body["messages"] == [
|
||||
{
|
||||
|
@ -111,3 +108,79 @@ async def test_azure_ai_with_image_url():
|
|||
],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_base, expected_url",
|
||||
[
|
||||
(
|
||||
"https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
|
||||
"https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
|
||||
),
|
||||
(
|
||||
"https://litellm8397336933.services.ai.azure.com/models/chat/completions",
|
||||
"https://litellm8397336933.services.ai.azure.com/models/chat/completions",
|
||||
),
|
||||
(
|
||||
"https://litellm8397336933.services.ai.azure.com/models",
|
||||
"https://litellm8397336933.services.ai.azure.com/models/chat/completions",
|
||||
),
|
||||
(
|
||||
"https://litellm8397336933.services.ai.azure.com",
|
||||
"https://litellm8397336933.services.ai.azure.com/models/chat/completions",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_azure_ai_services_handler(api_base, expected_url):
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
client = HTTPHandler()
|
||||
|
||||
with patch.object(client, "post") as mock_client:
|
||||
try:
|
||||
response = litellm.completion(
|
||||
model="azure_ai/Meta-Llama-3.1-70B-Instruct",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
api_key="my-fake-api-key",
|
||||
api_base=api_base,
|
||||
client=client,
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
mock_client.assert_called_once()
|
||||
assert mock_client.call_args.kwargs["headers"]["api-key"] == "my-fake-api-key"
|
||||
assert mock_client.call_args.kwargs["url"] == expected_url
|
||||
|
||||
|
||||
def test_completion_azure_ai_command_r():
|
||||
try:
|
||||
import os
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
|
||||
os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")
|
||||
|
||||
response = completion(
|
||||
model="azure_ai/command-r-plus",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is the meaning of life?"}
|
||||
],
|
||||
}
|
||||
],
|
||||
) # type: ignore
|
||||
|
||||
assert "azure_ai" in response.model
|
||||
except litellm.Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
|
|
@ -132,34 +132,6 @@ def test_null_role_response():
|
|||
assert response.choices[0].message.role == "assistant"
|
||||
|
||||
|
||||
def test_completion_azure_ai_command_r():
|
||||
try:
|
||||
import os
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
os.environ["AZURE_AI_API_BASE"] = os.getenv("AZURE_COHERE_API_BASE", "")
|
||||
os.environ["AZURE_AI_API_KEY"] = os.getenv("AZURE_COHERE_API_KEY", "")
|
||||
|
||||
response = completion(
|
||||
model="azure_ai/command-r-plus",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is the meaning of life?"}
|
||||
],
|
||||
}
|
||||
],
|
||||
) # type: ignore
|
||||
|
||||
assert "azure_ai" in response.model
|
||||
except litellm.Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_azure_ai_mistral_invalid_params(sync_mode):
|
||||
|
|
|
@ -199,4 +199,4 @@ def test_azure_global_standard_get_llm_provider():
|
|||
api_base="https://my-deployment-francecentral.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview",
|
||||
api_key="fake-api-key",
|
||||
)
|
||||
assert custom_llm_provider == "azure"
|
||||
assert custom_llm_provider == "azure_ai"
|
||||
|
|
|
@ -2954,6 +2954,7 @@ def test_azure_streaming_and_function_calling():
|
|||
async def test_completion_azure_ai_mistral_invalid_params(sync_mode):
|
||||
try:
|
||||
import os
|
||||
from litellm import stream_chunk_builder
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
|
@ -2968,15 +2969,21 @@ async def test_completion_azure_ai_mistral_invalid_params(sync_mode):
|
|||
"drop_params": True,
|
||||
"stream": True,
|
||||
}
|
||||
chunks = []
|
||||
if sync_mode:
|
||||
response: litellm.ModelResponse = completion(**data) # type: ignore
|
||||
response = completion(**data) # type: ignore
|
||||
for chunk in response:
|
||||
print(chunk)
|
||||
chunks.append(chunk)
|
||||
else:
|
||||
response: litellm.ModelResponse = await litellm.acompletion(**data) # type: ignore
|
||||
response = await litellm.acompletion(**data) # type: ignore
|
||||
|
||||
async for chunk in response:
|
||||
print(chunk)
|
||||
chunks.append(chunk)
|
||||
print(f"chunks: {chunks}")
|
||||
response = stream_chunk_builder(chunks=chunks)
|
||||
assert response.choices[0].message.content is not None
|
||||
except litellm.Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
|
|
@ -1252,3 +1252,19 @@ def test_fireworks_ai_document_inlining():
|
|||
|
||||
assert supports_pdf_input("fireworks_ai/llama-3.1-8b-instruct") is True
|
||||
assert supports_vision("fireworks_ai/llama-3.1-8b-instruct") is True
|
||||
|
||||
|
||||
def test_logprobs_type():
|
||||
from litellm.types.utils import Logprobs
|
||||
|
||||
logprobs = {
|
||||
"text_offset": None,
|
||||
"token_logprobs": None,
|
||||
"tokens": None,
|
||||
"top_logprobs": None,
|
||||
}
|
||||
logprobs = Logprobs(**logprobs)
|
||||
assert logprobs.text_offset is None
|
||||
assert logprobs.token_logprobs is None
|
||||
assert logprobs.tokens is None
|
||||
assert logprobs.top_logprobs is None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue