mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
rename llms/OpenAI/
-> llms/openai/
(#7154)
* rename OpenAI -> openai * fix file rename * fix rename changes * fix organization of openai/transcription * fix import OA fine tuning API * fix openai ft handler * fix handler import
This commit is contained in:
parent
e903fe6038
commit
bfb6891eb7
48 changed files with 53 additions and 59 deletions
229
litellm/llms/openai/chat/gpt_transformation.py
Normal file
229
litellm/llms/openai/chat/gpt_transformation.py
Normal file
|
@ -0,0 +1,229 @@
|
|||
"""
|
||||
Support for gpt model family
|
||||
"""
|
||||
|
||||
import types
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from ..common_utils import OpenAIError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class OpenAIGPTConfig(BaseConfig):
|
||||
"""
|
||||
Reference: https://platform.openai.com/docs/api-reference/chat/create
|
||||
|
||||
The class `OpenAIConfig` provides configuration for the OpenAI's Chat API interface. Below are the parameters:
|
||||
|
||||
- `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
|
||||
|
||||
- `function_call` (string or object): This optional parameter controls how the model calls functions.
|
||||
|
||||
- `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
|
||||
|
||||
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
|
||||
|
||||
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
|
||||
|
||||
- `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
|
||||
|
||||
- `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
|
||||
|
||||
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
|
||||
|
||||
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
|
||||
|
||||
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
|
||||
"""
|
||||
|
||||
frequency_penalty: Optional[int] = None
|
||||
function_call: Optional[Union[str, dict]] = None
|
||||
functions: Optional[list] = None
|
||||
logit_bias: Optional[dict] = None
|
||||
max_tokens: Optional[int] = None
|
||||
n: Optional[int] = None
|
||||
presence_penalty: Optional[int] = None
|
||||
stop: Optional[Union[str, list]] = None
|
||||
temperature: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
response_format: Optional[dict] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
function_call: Optional[Union[str, dict]] = None,
|
||||
functions: Optional[list] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
stop: Optional[Union[str, list]] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
response_format: Optional[dict] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return super().get_config()
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
base_params = [
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"modalities",
|
||||
"prediction",
|
||||
"n",
|
||||
"presence_penalty",
|
||||
"seed",
|
||||
"stop",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"function_call",
|
||||
"functions",
|
||||
"max_retries",
|
||||
"extra_headers",
|
||||
"parallel_tool_calls",
|
||||
] # works across all models
|
||||
|
||||
model_specific_params = []
|
||||
if (
|
||||
model != "gpt-3.5-turbo-16k" and model != "gpt-4"
|
||||
): # gpt-4 does not support 'response_format'
|
||||
model_specific_params.append("response_format")
|
||||
|
||||
if (
|
||||
model in litellm.open_ai_chat_completion_models
|
||||
) or model in litellm.open_ai_text_completion_models:
|
||||
model_specific_params.append(
|
||||
"user"
|
||||
) # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
|
||||
return base_params + model_specific_params
|
||||
|
||||
def _map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
If any supported_openai_params are in non_default_params, add them to optional_params, so they are use in API call
|
||||
|
||||
Args:
|
||||
non_default_params (dict): Non-default parameters to filter.
|
||||
optional_params (dict): Optional parameters to update.
|
||||
model (str): Model name for parameter support check.
|
||||
|
||||
Returns:
|
||||
dict: Updated optional_params with supported non-default parameters.
|
||||
"""
|
||||
supported_openai_params = self.get_supported_openai_params(model)
|
||||
for param, value in non_default_params.items():
|
||||
if param in supported_openai_params:
|
||||
optional_params[param] = value
|
||||
return optional_params
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
return self._map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
drop_params=drop_params,
|
||||
)
|
||||
|
||||
def _transform_messages(
|
||||
self, messages: List[AllMessageValues]
|
||||
) -> List[AllMessageValues]:
|
||||
return messages
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the overall request to be sent to the API.
|
||||
|
||||
Returns:
|
||||
dict: The transformed request. Sent as the body of the API call.
|
||||
"""
|
||||
return {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
encoding: str,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
"""
|
||||
Transform the response from the API.
|
||||
|
||||
Returns:
|
||||
dict: The transformed response.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return OpenAIError(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
headers=cast(httpx.Headers, headers),
|
||||
)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
raise NotImplementedError
|
Loading…
Add table
Add a link
Reference in a new issue