Merge branch 'main' into main

Vincelwt 2024-03-19 12:50:04 +09:00, committed by GitHub
commit 1cbfd312fe
133 changed files with 5662 additions and 1062 deletions

@@ -65,6 +65,7 @@ from .integrations.langsmith import LangsmithLogger
from .integrations.weights_biases import WeightsBiasesLogger
from .integrations.custom_logger import CustomLogger
from .integrations.langfuse import LangFuseLogger
from .integrations.datadog import DataDogLogger
from .integrations.dynamodb import DyanmoDBLogger
from .integrations.s3 import S3Logger
from .integrations.clickhouse import ClickhouseLogger
@@ -72,7 +73,7 @@ from .integrations.litedebugger import LiteDebugger
from .proxy._types import KeyManagementSystem
from openai import OpenAIError as OriginalError
from openai._models import BaseModel as OpenAIObject
from .caching import S3Cache, RedisSemanticCache
from .caching import S3Cache, RedisSemanticCache, RedisCache
from .exceptions import (
AuthenticationError,
BadRequestError,
@@ -121,6 +122,7 @@ langsmithLogger = None
weightsBiasesLogger = None
customLogger = None
langFuseLogger = None
dataDogLogger = None
dynamoLogger = None
s3Logger = None
genericAPILogger = None
@@ -480,12 +482,12 @@ class ModelResponse(OpenAIObject):
object=None,
system_fingerprint=None,
usage=None,
stream=False,
stream=None,
response_ms=None,
hidden_params=None,
**params,
):
if stream:
if stream is not None and stream == True:
object = "chat.completion.chunk"
choices = [StreamingChoices()]
else:
@@ -1483,6 +1485,33 @@ class Logging:
user_id=kwargs.get("user", None),
print_verbose=print_verbose,
)
if callback == "datadog":
global dataDogLogger
verbose_logger.debug("reaches datadog for success logging!")
kwargs = {}
for k, v in self.model_call_details.items():
if (
k != "original_response"
): # copy.deepcopy raises errors as this could be a coroutine
kwargs[k] = v
# this only logs streaming once, complete_streaming_response exists i.e when stream ends
if self.stream:
verbose_logger.debug(
f"datadog: is complete_streaming_response in kwargs: {kwargs.get('complete_streaming_response', None)}"
)
if complete_streaming_response is None:
continue
else:
print_verbose("reaches datadog for streaming logging!")
result = kwargs["complete_streaming_response"]
dataDogLogger.log_event(
kwargs=kwargs,
response_obj=result,
start_time=start_time,
end_time=end_time,
user_id=kwargs.get("user", None),
print_verbose=print_verbose,
)
if callback == "generic":
global genericAPILogger
verbose_logger.debug("reaches langfuse for success logging!")
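The new datadog success-callback branch mirrors the neighbouring handlers: it copies model_call_details (skipping original_response), logs streamed calls only once complete_streaming_response is available, and hands everything to DataDogLogger.log_event. A minimal usage sketch; the DD_* environment variables are an assumption about how the integration is configured, not something shown in this diff:

import litellm

# Hedged sketch: enable the new DataDog success callback.
# os.environ["DD_API_KEY"] = "..."; os.environ["DD_SITE"] = "datadoghq.com"  # assumed credentials
litellm.success_callback = ["datadog"]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
)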
@@ -1805,7 +1834,12 @@ class Logging:
)
result = kwargs["async_complete_streaming_response"]
# only add to cache once we have a complete streaming response
litellm.cache.add_cache(result, **kwargs)
if litellm.cache is not None and not isinstance(
litellm.cache.cache, S3Cache
):
await litellm.cache.async_add_cache(result, **kwargs)
else:
litellm.cache.add_cache(result, **kwargs)
if isinstance(callback, CustomLogger): # custom logger class
print_verbose(
f"Running Async success callback: {callback}; self.stream: {self.stream}; async_complete_streaming_response: {self.model_call_details.get('async_complete_streaming_response', None)} result={result}"
@@ -2601,7 +2635,7 @@ def client(original_function):
if (
kwargs.get("max_tokens", None) is not None
and model is not None
and litellm.drop_params
and litellm.modify_params
== True # user is okay with params being modified
and (
call_type == CallTypes.acompletion.value
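This gate previously keyed off litellm.drop_params; it now requires the separate litellm.modify_params opt-in before touching a user-supplied max_tokens. A one-line sketch:

import litellm

litellm.modify_params = True  # opt in: litellm may adjust supplied params (e.g. max_tokens) per the branch above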
@@ -2818,7 +2852,9 @@ def client(original_function):
):
if len(cached_result) == 1 and cached_result[0] is None:
cached_result = None
elif isinstance(litellm.cache.cache, RedisSemanticCache):
elif isinstance(
litellm.cache.cache, RedisSemanticCache
) or isinstance(litellm.cache.cache, RedisCache):
preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
kwargs["preset_cache_key"] = (
preset_cache_key # for streaming calls, we need to pass the preset_cache_key
@@ -3853,7 +3889,9 @@ def completion_cost(
* n
)
else:
raise Exception(f"Model={model} not found in completion cost model map")
raise Exception(
f"Model={image_gen_model_name} not found in completion cost model map"
)
# Calculate cost based on prompt_tokens, completion_tokens
if (
"togethercomputer" in model
@@ -4281,6 +4319,7 @@ def get_optional_params(
and custom_llm_provider != "together_ai"
and custom_llm_provider != "mistral"
and custom_llm_provider != "anthropic"
and custom_llm_provider != "cohere_chat"
and custom_llm_provider != "bedrock"
and custom_llm_provider != "ollama_chat"
):
@@ -4412,6 +4451,31 @@ def get_optional_params(
optional_params["presence_penalty"] = presence_penalty
if stop is not None:
optional_params["stop_sequences"] = stop
elif custom_llm_provider == "cohere_chat":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
# handle cohere params
if stream:
optional_params["stream"] = stream
if temperature is not None:
optional_params["temperature"] = temperature
if max_tokens is not None:
optional_params["max_tokens"] = max_tokens
if n is not None:
optional_params["num_generations"] = n
if top_p is not None:
optional_params["p"] = top_p
if frequency_penalty is not None:
optional_params["frequency_penalty"] = frequency_penalty
if presence_penalty is not None:
optional_params["presence_penalty"] = presence_penalty
if stop is not None:
optional_params["stop_sequences"] = stop
if tools is not None:
optional_params["tools"] = tools
elif custom_llm_provider == "maritalk":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
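The new cohere_chat branch remaps the OpenAI-style arguments shown above (n becomes num_generations, top_p becomes p) and passes tools through untouched. A hedged usage sketch; command-r as a cohere_chat model and COHERE_API_KEY as the credential are assumptions, not shown in this diff:

import litellm

response = litellm.completion(
    model="command-r",  # assumed to be listed in litellm.cohere_chat_models
    messages=[{"role": "user", "content": "Name three SQL window functions."}],
    n=2,          # forwarded as num_generations
    top_p=0.9,    # forwarded as p
    max_tokens=128,
)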
@@ -5095,6 +5159,19 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"stop",
"n",
]
elif custom_llm_provider == "cohere_chat":
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"frequency_penalty",
"presence_penalty",
"stop",
"n",
"tools",
"tool_choice",
]
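The supported-parameter list for the new provider can be inspected directly; a small sketch (the model name is illustrative):

from litellm.utils import get_supported_openai_params

print(get_supported_openai_params(model="command-r", custom_llm_provider="cohere_chat"))
# ['stream', 'temperature', 'max_tokens', 'top_p', 'frequency_penalty', 'presence_penalty', 'stop', 'n', 'tools', 'tool_choice']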
elif custom_llm_provider == "maritalk":
return [
"stream",
@@ -5155,6 +5232,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"max_tokens",
"tools",
"tool_choice",
"response_format",
]
elif custom_llm_provider == "replicate":
return [
@@ -5271,6 +5349,40 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
]
def get_formatted_prompt(
data: dict,
call_type: Literal[
"completion",
"embedding",
"image_generation",
"audio_transcription",
"moderation",
],
) -> str:
"""
Extracts the prompt from the input data based on the call type.
Returns a string.
"""
prompt = ""
if call_type == "completion":
for m in data["messages"]:
if "content" in m and isinstance(m["content"], str):
prompt += m["content"]
elif call_type == "embedding" or call_type == "moderation":
if isinstance(data["input"], str):
prompt = data["input"]
elif isinstance(data["input"], list):
for m in data["input"]:
prompt += m
elif call_type == "image_generation":
prompt = data["prompt"]
elif call_type == "audio_transcription":
if "prompt" in data:
prompt = data["prompt"]
return prompt
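A short usage sketch of the new helper, with made-up inputs:

from litellm.utils import get_formatted_prompt

data = {"messages": [{"role": "user", "content": "Summarize the following text."}]}
print(get_formatted_prompt(data=data, call_type="completion"))
# -> "Summarize the following text."

print(get_formatted_prompt(data={"input": ["a", "b"]}, call_type="embedding"))
# -> "ab"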
def get_llm_provider(
model: str,
custom_llm_provider: Optional[str] = None,
@@ -5311,6 +5423,17 @@ def get_llm_provider(
# groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
api_base = "https://api.groq.com/openai/v1"
dynamic_api_key = get_secret("GROQ_API_KEY")
elif custom_llm_provider == "fireworks_ai":
# fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
if not model.startswith("accounts/fireworks/models"):
model = f"accounts/fireworks/models/{model}"
api_base = "https://api.fireworks.ai/inference/v1"
dynamic_api_key = (
get_secret("FIREWORKS_API_KEY")
or get_secret("FIREWORKS_AI_API_KEY")
or get_secret("FIREWORKSAI_API_KEY")
or get_secret("FIREWORKS_AI_TOKEN")
)
elif custom_llm_provider == "mistral":
# mistral is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.mistral.ai
api_base = (
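Fireworks AI is treated as another OpenAI-compatible provider: short model names are expanded to accounts/fireworks/models/<name>, the api_base is pinned to https://api.fireworks.ai/inference/v1, and several env-var spellings are accepted for the key. A hedged sketch; the model id below is illustrative:

import os
import litellm

os.environ["FIREWORKS_API_KEY"] = "fw-..."  # placeholder key

response = litellm.completion(
    model="fireworks_ai/mixtral-8x7b-instruct",  # expanded internally to accounts/fireworks/models/mixtral-8x7b-instruct
    messages=[{"role": "user", "content": "Hello"}],
)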
@@ -5382,6 +5505,9 @@ def get_llm_provider(
## cohere
elif model in litellm.cohere_models or model in litellm.cohere_embedding_models:
custom_llm_provider = "cohere"
## cohere chat models
elif model in litellm.cohere_chat_models:
custom_llm_provider = "cohere_chat"
## replicate
elif model in litellm.replicate_models or (":" in model and len(model) > 64):
model_parts = model.split(":")
@@ -5997,7 +6123,9 @@ def validate_environment(model: Optional[str] = None) -> dict:
def set_callbacks(callback_list, function_id=None):
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, dynamoLogger, s3Logger
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, dynamoLogger, s3Logger, dataDogLogger
try:
for callback in callback_list:
print_verbose(f"callback: {callback}")
@@ -6063,6 +6191,8 @@ def set_callbacks(callback_list, function_id=None):
promptLayerLogger = PromptLayerLogger()
elif callback == "langfuse":
langFuseLogger = LangFuseLogger()
elif callback == "datadog":
dataDogLogger = DataDogLogger()
elif callback == "dynamodb":
dynamoLogger = DyanmoDBLogger()
elif callback == "s3":
@@ -6433,7 +6563,7 @@ def convert_to_model_response_object(
"system_fingerprint"
]
if "model" in response_object:
if "model" in response_object and model_response_object.model is None:
model_response_object.model = response_object["model"]
if start_time is not None and end_time is not None:
@@ -7422,7 +7552,9 @@ def exception_type(
model=model,
response=original_exception.response,
)
elif custom_llm_provider == "cohere": # Cohere
elif (
custom_llm_provider == "cohere" or custom_llm_provider == "cohere_chat"
): # Cohere
if (
"invalid api token" in error_str
or "No API key provided." in error_str
@@ -8201,8 +8333,10 @@ def get_secret(
default_value: Optional[Union[str, bool]] = None,
):
key_management_system = litellm._key_management_system
key_management_settings = litellm._key_management_settings
if secret_name.startswith("os.environ/"):
secret_name = secret_name.replace("os.environ/", "")
try:
if litellm.secret_manager_client is not None:
try:
@@ -8210,6 +8344,13 @@
key_manager = "local"
if key_management_system is not None:
key_manager = key_management_system.value
if key_management_settings is not None:
if (
secret_name not in key_management_settings.hosted_keys
): # allow user to specify which keys to check in hosted key manager
key_manager = "local"
if (
key_manager == KeyManagementSystem.AZURE_KEY_VAULT
or type(client).__module__ + "." + type(client).__name__
@@ -8245,9 +8386,30 @@
secret = response.plaintext.decode(
"utf-8"
) # assumes the original value was encoded with utf-8
elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
try:
get_secret_value_response = client.get_secret_value(
SecretId=secret_name
)
print_verbose(
f"get_secret_value_response: {get_secret_value_response}"
)
except Exception as e:
print_verbose(f"An error occurred - {str(e)}")
# For a list of exceptions thrown, see
# https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
raise e
# assume there is 1 secret per secret_name
secret_dict = json.loads(get_secret_value_response["SecretString"])
print_verbose(f"secret_dict: {secret_dict}")
for k, v in secret_dict.items():
secret = v
print_verbose(f"secret: {secret}")
else: # assume the default is infisical client
secret = client.get_secret(secret_name).secret_value
except Exception as e: # check if it's in os.environ
print_verbose(f"An exception occurred - {str(e)}")
secret = os.getenv(secret_name)
try:
secret_value_as_bool = ast.literal_eval(secret)
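The AWS branch above assumes litellm.secret_manager_client is a boto3 Secrets Manager client and that each secret stores a JSON object whose value is the key material. A hedged wiring sketch (region and secret name are placeholders):

import boto3
import litellm
from litellm.proxy._types import KeyManagementSystem
from litellm.utils import get_secret

litellm.secret_manager_client = boto3.client("secretsmanager", region_name="us-east-1")
litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER

# Resolved via client.get_secret_value(SecretId="OPENAI_API_KEY") per the branch above; if
# _key_management_settings.hosted_keys is set and excludes this name, os.environ is used instead.
openai_key = get_secret("OPENAI_API_KEY")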
@@ -8555,6 +8717,29 @@ class CustomStreamWrapper:
except:
raise ValueError(f"Unable to parse response. Original response: {chunk}")
def handle_cohere_chat_chunk(self, chunk):
chunk = chunk.decode("utf-8")
data_json = json.loads(chunk)
print_verbose(f"chunk: {chunk}")
try:
text = ""
is_finished = False
finish_reason = ""
if "text" in data_json:
text = data_json["text"]
elif "is_finished" in data_json and data_json["is_finished"] == True:
is_finished = data_json["is_finished"]
finish_reason = data_json["finish_reason"]
else:
return
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
except:
raise ValueError(f"Unable to parse response. Original response: {chunk}")
def handle_azure_chunk(self, chunk):
is_finished = False
finish_reason = ""
@@ -8667,6 +8852,27 @@ class CustomStreamWrapper:
traceback.print_exc()
raise e
def handle_azure_text_completion_chunk(self, chunk):
try:
print_verbose(f"\nRaw OpenAI Chunk\n{chunk}\n")
text = ""
is_finished = False
finish_reason = None
choices = getattr(chunk, "choices", [])
if len(choices) > 0:
text = choices[0].text
if choices[0].finish_reason is not None:
is_finished = True
finish_reason = choices[0].finish_reason
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
except Exception as e:
raise e
def handle_openai_text_completion_chunk(self, chunk):
try:
print_verbose(f"\nRaw OpenAI Chunk\n{chunk}\n")
@@ -9032,7 +9238,41 @@ class CustomStreamWrapper:
try:
if hasattr(chunk, "candidates") == True:
try:
completion_obj["content"] = chunk.text
try:
completion_obj["content"] = chunk.text
except Exception as e:
if "Part has no text." in str(e):
## check for function calling
function_call = (
chunk.candidates[0]
.content.parts[0]
.function_call
)
args_dict = {}
for k, v in function_call.args.items():
args_dict[k] = v
args_str = json.dumps(args_dict)
_delta_obj = litellm.utils.Delta(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"arguments": args_str,
"name": function_call.name,
},
"type": "function",
}
],
)
_streaming_response = StreamingChoices(
delta=_delta_obj
)
_model_response = ModelResponse(stream=True)
_model_response.choices = [_streaming_response]
response_obj = {"original_chunk": _model_response}
else:
raise e
if (
hasattr(chunk.candidates[0], "finish_reason")
and chunk.candidates[0].finish_reason.name
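For reference, a hedged sketch of the streaming delta the function-call path above produces; the tool name and arguments are invented, and Delta/StreamingChoices/ModelResponse are the same litellm.utils classes used in the hunk:

import json
import uuid

from litellm.utils import Delta, ModelResponse, StreamingChoices

args_str = json.dumps({"location": "Paris"})  # made-up function-call arguments
delta = Delta(
    content=None,
    tool_calls=[
        {
            "id": f"call_{uuid.uuid4()}",
            "type": "function",
            "function": {"name": "get_weather", "arguments": args_str},
        }
    ],
)
chunk = ModelResponse(stream=True)
chunk.choices = [StreamingChoices(delta=delta)]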
@@ -9043,7 +9283,7 @@ class CustomStreamWrapper:
chunk.candidates[0].finish_reason.name
)
)
except:
except Exception as e:
if chunk.candidates[0].finish_reason.name == "SAFETY":
raise Exception(
f"The response was blocked by VertexAI. {str(chunk)}"
@@ -9063,6 +9303,15 @@
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
elif self.custom_llm_provider == "cohere_chat":
response_obj = self.handle_cohere_chat_chunk(chunk)
if response_obj is None:
return
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
elif self.custom_llm_provider == "bedrock":
if self.sent_last_chunk:
raise StopIteration
@@ -9140,6 +9389,14 @@
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
elif self.custom_llm_provider == "azure_text":
response_obj = self.handle_azure_text_completion_chunk(chunk)
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
elif self.custom_llm_provider == "cached_response":
response_obj = {
"text": chunk.choices[0].delta.content,
@@ -9375,14 +9632,18 @@
def __next__(self):
try:
while True:
if isinstance(self.completion_stream, str) or isinstance(
self.completion_stream, bytes
if (
isinstance(self.completion_stream, str)
or isinstance(self.completion_stream, bytes)
or isinstance(self.completion_stream, ModelResponse)
):
chunk = self.completion_stream
else:
chunk = next(self.completion_stream)
if chunk is not None and chunk != b"":
print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
print_verbose(
f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}; custom_llm_provider: {self.custom_llm_provider}"
)
response: Optional[ModelResponse] = self.chunk_creator(chunk=chunk)
print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}")
@@ -9417,6 +9678,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "azure"
or self.custom_llm_provider == "custom_openai"
or self.custom_llm_provider == "text-completion-openai"
or self.custom_llm_provider == "azure_text"
or self.custom_llm_provider == "huggingface"
or self.custom_llm_provider == "ollama"
or self.custom_llm_provider == "ollama_chat"