mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
* test(azure_openai_o1.py): initial commit with testing for azure openai o1 preview model * fix(base_llm_unit_tests.py): handle azure o1 preview response format tests skip as o1 on azure doesn't support tool calling yet * fix: initial commit of azure o1 handler using openai caller simplifies calling + allows fake streaming logic alr. implemented for openai to just work * feat(azure/o1_handler.py): fake o1 streaming for azure o1 models azure does not currently support streaming for o1 * feat(o1_transformation.py): support overriding 'should_fake_stream' on azure/o1 via 'supports_native_streaming' param on model info enables user to toggle on when azure allows o1 streaming without needing to bump versions * style(router.py): remove 'give feedback/get help' messaging when router is used Prevents noisy messaging Closes https://github.com/BerriAI/litellm/issues/5942 * fix(types/utils.py): handle none logprobs Fixes https://github.com/BerriAI/litellm/issues/328 * fix(exception_mapping_utils.py): fix error str unbound error * refactor(azure_ai/): move to openai_like chat completion handler allows for easy swapping of api base url's (e.g. 
ai.services.com) Fixes https://github.com/BerriAI/litellm/issues/7275 * refactor(azure_ai/): move to base llm http handler * fix(azure_ai/): handle differing api endpoints * fix(azure_ai/): make sure all unit tests are passing * fix: fix linting errors * fix: fix linting errors * fix: fix linting error * fix: fix linting errors * fix(azure_ai/transformation.py): handle extra body param * fix(azure_ai/transformation.py): fix max retries param handling * fix: fix test * test(test_azure_o1.py): fix test * fix(llm_http_handler.py): support handling azure ai unprocessable entity error * fix(llm_http_handler.py): handle sync invalid param error for azure ai * fix(azure_ai/): streaming support with base_llm_http_handler * fix(llm_http_handler.py): working sync stream calls with unprocessable entity handling for azure ai * fix: fix linting errors * fix(llm_http_handler.py): fix linting error * fix(azure_ai/): handle cohere tool call invalid index param error
231 lines
7.7 KiB
Python
231 lines
7.7 KiB
Python
"""
|
|
Support for gpt model family
|
|
"""
|
|
|
|
from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
|
|
|
|
import httpx
|
|
|
|
import litellm
|
|
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
|
from litellm.types.llms.openai import AllMessageValues
|
|
from litellm.types.utils import ModelResponse
|
|
|
|
from ..common_utils import OpenAIError
|
|
|
|
if TYPE_CHECKING:
|
|
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
|
|
|
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
|
else:
|
|
LiteLLMLoggingObj = Any
|
|
|
|
|
|
class OpenAIGPTConfig(BaseConfig):
    """
    Reference: https://platform.openai.com/docs/api-reference/chat/create

    The class `OpenAIGPTConfig` provides configuration for OpenAI's Chat API interface. Below are the parameters:

    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.

    - `function_call` (string or object): This optional parameter controls how the model calls functions.

    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.

    - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.

    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    # The "number" parameters are fractional (e.g. -2.0 .. 2.0), so they are
    # annotated as float; plain ints remain acceptable to type checkers.
    frequency_penalty: Optional[float] = None
    function_call: Optional[Union[str, dict]] = None
    functions: Optional[list] = None
    logit_bias: Optional[dict] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[float] = None
    stop: Optional[Union[str, list]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    response_format: Optional[dict] = None

    def __init__(
        self,
        frequency_penalty: Optional[float] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        response_format: Optional[dict] = None,
    ) -> None:
        """
        Record every explicitly provided (non-None) argument as a default.

        NOTE: the values are written onto the *class* (`self.__class__`),
        not onto the instance, so they are shared by every instance of that
        class and remain set after this instance is gone.
        """
        # Must stay the first statement: it snapshots exactly the call
        # arguments before any other local name is introduced.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return the configuration produced by the parent class's `get_config`."""
        return super().get_config()

    def get_supported_openai_params(self, model: str) -> list:
        """
        Return the names of the OpenAI parameters accepted for `model`.

        Args:
            model (str): Model name used to decide on model-specific params.

        Returns:
            list: The common parameters, followed by `response_format`
            (for every model except `gpt-3.5-turbo-16k` and `gpt-4`) and
            `user` (only when the model is listed in
            `litellm.open_ai_chat_completion_models` or
            `litellm.open_ai_text_completion_models`).
        """
        base_params = [
            "frequency_penalty",
            "logit_bias",
            "logprobs",
            "top_logprobs",
            "max_tokens",
            "max_completion_tokens",
            "modalities",
            "prediction",
            "n",
            "presence_penalty",
            "seed",
            "stop",
            "stream",
            "stream_options",
            "temperature",
            "top_p",
            "tools",
            "tool_choice",
            "function_call",
            "functions",
            "max_retries",
            "extra_headers",
            "parallel_tool_calls",
        ]  # works across all models

        model_specific_params: List[str] = []

        # gpt-4 does not support 'response_format'
        if model not in ("gpt-3.5-turbo-16k", "gpt-4"):
            model_specific_params.append("response_format")

        # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
        if (
            model in litellm.open_ai_chat_completion_models
            or model in litellm.open_ai_text_completion_models
        ):
            model_specific_params.append("user")

        return base_params + model_specific_params

    def _map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        If any supported_openai_params are in non_default_params, add them to optional_params, so they are used in the API call

        Args:
            non_default_params (dict): Non-default parameters to filter.
            optional_params (dict): Optional parameters to update (mutated in place).
            model (str): Model name for parameter support check.
            drop_params (bool): Accepted for interface parity; not read by this implementation.

        Returns:
            dict: Updated optional_params with supported non-default parameters.
        """
        supported_openai_params = self.get_supported_openai_params(model)
        for param, value in non_default_params.items():
            if param in supported_openai_params:
                optional_params[param] = value
        return optional_params

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Copy the supported entries of `non_default_params` into `optional_params`.

        Thin public wrapper that forwards all arguments unchanged to
        `_map_openai_params` and returns its result.
        """
        return self._map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )

    def _transform_messages(
        self, messages: List[AllMessageValues], model: str
    ) -> List[AllMessageValues]:
        """Return `messages` unchanged; `model` is not consulted here."""
        return messages

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform the overall request to be sent to the API.

        The messages are first passed through `_transform_messages`.
        `optional_params` is expanded last, so a "model" or "messages" key
        inside it overrides the corresponding argument. `litellm_params`
        and `headers` are not read by this implementation.

        Returns:
            dict: The transformed request. Sent as the body of the API call.
        """
        messages = self._transform_messages(messages=messages, model=model)
        return {
            "model": model,
            "messages": messages,
            **optional_params,
        }

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform the response from the API.

        Returns:
            ModelResponse: The transformed response.

        Raises:
            NotImplementedError: Always; this class provides no implementation.
        """
        raise NotImplementedError

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """
        Build the `OpenAIError` describing a failed call.

        Args:
            error_message (str): Passed through as the error's `message`.
            status_code (int): Passed through as the error's `status_code`.
            headers (Union[dict, httpx.Headers]): Passed through as the
                error's `headers`. The `cast` only informs the type checker;
                no conversion is performed at runtime.

        Returns:
            BaseLLMException: The constructed `OpenAIError`.
        """
        return OpenAIError(
            status_code=status_code,
            message=error_message,
            headers=cast(httpx.Headers, headers),
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate the environment for a request.

        Raises:
            NotImplementedError: Always; this class provides no implementation.
        """
        raise NotImplementedError
|