Litellm merge pr (#7161)

* build: merge branch

* test: fix openai naming

* fix(main.py): fix openai renaming

* style: ignore function length for config factory

* fix(sagemaker/): fix routing logic

* fix: fix imports

* fix: fix override
This commit is contained in:
Krish Dholakia 2024-12-10 22:49:26 -08:00 committed by GitHub
parent d5aae81c6d
commit 350cfc36f7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
88 changed files with 3617 additions and 4421 deletions

View file

@ -4,59 +4,42 @@ import time
import traceback
import types
from enum import Enum
from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, Union
import requests # type: ignore
from httpx._models import Headers
import litellm
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.utils import Choices, Message, ModelResponse, Usage
class MaritalkError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class MaritalkError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Optional[Union[dict, Headers]] = None,
):
super().__init__(status_code=status_code, message=message, headers=headers)
class MaritTalkConfig:
"""
The class `MaritTalkConfig` provides configuration for the MaritTalk's API interface. Here are the parameters:
- `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default is 1.
- `model` (string): The model used for conversation. Default is 'maritalk'.
- `do_sample` (boolean): If set to True, the API will generate a response using sampling. Default is True.
- `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.7.
- `top_p` (number): Selection threshold for token inclusion based on cumulative probability. Default is 0.95.
- `repetition_penalty` (number): Penalty for repetition in the generated conversation. Default is 1.
- `stopping_tokens` (list of string): List of tokens where the conversation can be stopped/stopped.
"""
max_tokens: Optional[int] = None
model: Optional[str] = None
do_sample: Optional[bool] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
repetition_penalty: Optional[float] = None
stopping_tokens: Optional[List[str]] = None
class MaritalkConfig(OpenAIGPTConfig):
def __init__(
self,
max_tokens: Optional[int] = None,
model: Optional[str] = None,
do_sample: Optional[bool] = None,
temperature: Optional[float] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
stopping_tokens: Optional[List[str]] = None,
top_k: Optional[int] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
stop: Optional[List[str]] = None,
stream: Optional[bool] = None,
stream_options: Optional[dict] = None,
tools: Optional[List[dict]] = None,
tool_choice: Optional[Union[str, dict]] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
@ -65,129 +48,27 @@ class MaritTalkConfig:
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
return super().get_config()
def get_supported_openai_params(self, model: str) -> List:
return [
"frequency_penalty",
"presence_penalty",
"top_p",
"top_k",
"temperature",
"max_tokens",
"n",
"stop",
"stream",
"stream_options",
"tools",
"tool_choice",
]
def validate_environment(api_key):
headers = {
"accept": "application/json",
"content-type": "application/json",
}
if api_key:
headers["Authorization"] = f"Key {api_key}"
return headers
def completion(
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
litellm_params=None,
logger_fn=None,
):
headers = validate_environment(api_key)
completion_url = api_base
model = model
## Load Config
config = litellm.MaritTalkConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
"messages": messages,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
response = requests.post(
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines()
else:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
return MaritalkError(
status_code=status_code, message=error_message, headers=headers
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
completion_response = response.json()
if "error" in completion_response:
raise MaritalkError(
message=completion_response["error"],
status_code=response.status_code,
)
else:
try:
if len(completion_response["answer"]) > 0:
model_response.choices[0].message.content = completion_response[ # type: ignore
"answer"
]
except Exception:
raise MaritalkError(
message=response.text, status_code=response.status_code
)
## CALCULATING USAGE
prompt = "".join(m["content"] for m in messages)
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def embedding(
model: str,
input: list,
api_key: Optional[str],
logging_obj: Any,
model_response=None,
encoding=None,
):
pass