""" Support for gpt model family """ import types from typing import TYPE_CHECKING, Any, List, Optional, Union, cast import httpx import litellm from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage from litellm.types.utils import ModelResponse from ..common_utils import OpenAIError if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj LiteLLMLoggingObj = _LiteLLMLoggingObj else: LiteLLMLoggingObj = Any class OpenAIGPTConfig(BaseConfig): """ Reference: https://platform.openai.com/docs/api-reference/chat/create The class `OpenAIConfig` provides configuration for the OpenAI's Chat API interface. Below are the parameters: - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition. - `function_call` (string or object): This optional parameter controls how the model calls functions. - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs. - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion. - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message. - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics. - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens. - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. """ frequency_penalty: Optional[int] = None function_call: Optional[Union[str, dict]] = None functions: Optional[list] = None logit_bias: Optional[dict] = None max_tokens: Optional[int] = None n: Optional[int] = None presence_penalty: Optional[int] = None stop: Optional[Union[str, list]] = None temperature: Optional[int] = None top_p: Optional[int] = None response_format: Optional[dict] = None def __init__( self, frequency_penalty: Optional[int] = None, function_call: Optional[Union[str, dict]] = None, functions: Optional[list] = None, logit_bias: Optional[dict] = None, max_tokens: Optional[int] = None, n: Optional[int] = None, presence_penalty: Optional[int] = None, stop: Optional[Union[str, list]] = None, temperature: Optional[int] = None, top_p: Optional[int] = None, response_format: Optional[dict] = None, ) -> None: locals_ = locals().copy() for key, value in locals_.items(): if key != "self" and value is not None: setattr(self.__class__, key, value) @classmethod def get_config(cls): return super().get_config() def get_supported_openai_params(self, model: str) -> list: base_params = [ "frequency_penalty", "logit_bias", "logprobs", "top_logprobs", "max_tokens", "max_completion_tokens", "modalities", "prediction", "n", "presence_penalty", "seed", "stop", "stream", "stream_options", "temperature", "top_p", "tools", "tool_choice", "function_call", "functions", "max_retries", "extra_headers", "parallel_tool_calls", ] # works across all models model_specific_params = [] if ( model != "gpt-3.5-turbo-16k" and model != "gpt-4" ): # gpt-4 does not support 'response_format' model_specific_params.append("response_format") if ( model in litellm.open_ai_chat_completion_models ) or model in litellm.open_ai_text_completion_models: model_specific_params.append( "user" ) # user is not a param supported by all openai-compatible endpoints - e.g. azure ai return base_params + model_specific_params def _map_openai_params( self, non_default_params: dict, optional_params: dict, model: str, drop_params: bool, ) -> dict: """ If any supported_openai_params are in non_default_params, add them to optional_params, so they are use in API call Args: non_default_params (dict): Non-default parameters to filter. optional_params (dict): Optional parameters to update. model (str): Model name for parameter support check. Returns: dict: Updated optional_params with supported non-default parameters. """ supported_openai_params = self.get_supported_openai_params(model) for param, value in non_default_params.items(): if param in supported_openai_params: optional_params[param] = value return optional_params def map_openai_params( self, non_default_params: dict, optional_params: dict, model: str, drop_params: bool, ) -> dict: return self._map_openai_params( non_default_params=non_default_params, optional_params=optional_params, model=model, drop_params=drop_params, ) def _transform_messages( self, messages: List[AllMessageValues] ) -> List[AllMessageValues]: return messages def transform_request( self, model: str, messages: List[AllMessageValues], optional_params: dict, litellm_params: dict, headers: dict, ) -> dict: """ Transform the overall request to be sent to the API. Returns: dict: The transformed request. Sent as the body of the API call. """ raise NotImplementedError def transform_response( self, model: str, raw_response: httpx.Response, model_response: ModelResponse, logging_obj: LiteLLMLoggingObj, request_data: dict, messages: List[AllMessageValues], optional_params: dict, encoding: str, api_key: Optional[str] = None, json_mode: Optional[bool] = None, ) -> ModelResponse: """ Transform the response from the API. Returns: dict: The transformed response. """ raise NotImplementedError def get_error_class( self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] ) -> BaseLLMException: return OpenAIError( status_code=status_code, message=error_message, headers=cast(httpx.Headers, headers), ) def validate_environment( self, headers: dict, model: str, messages: List[AllMessageValues], optional_params: dict, api_key: Optional[str] = None, ) -> dict: raise NotImplementedError