forked from phoenix/litellm-mirror
Merge branch 'main' into main
Commit: 1cbfd312fe
133 changed files with 5662 additions and 1062 deletions

litellm/utils.py (292 changes)
@@ -65,6 +65,7 @@ from .integrations.langsmith import LangsmithLogger
from .integrations.weights_biases import WeightsBiasesLogger
from .integrations.custom_logger import CustomLogger
from .integrations.langfuse import LangFuseLogger
from .integrations.datadog import DataDogLogger
from .integrations.dynamodb import DyanmoDBLogger
from .integrations.s3 import S3Logger
from .integrations.clickhouse import ClickhouseLogger
@@ -72,7 +73,7 @@ from .integrations.litedebugger import LiteDebugger
from .proxy._types import KeyManagementSystem
from openai import OpenAIError as OriginalError
from openai._models import BaseModel as OpenAIObject
from .caching import S3Cache, RedisSemanticCache
from .caching import S3Cache, RedisSemanticCache, RedisCache
from .exceptions import (
    AuthenticationError,
    BadRequestError,
@@ -121,6 +122,7 @@ langsmithLogger = None
weightsBiasesLogger = None
customLogger = None
langFuseLogger = None
dataDogLogger = None
dynamoLogger = None
s3Logger = None
genericAPILogger = None
@@ -480,12 +482,12 @@ class ModelResponse(OpenAIObject):
        object=None,
        system_fingerprint=None,
        usage=None,
        stream=False,
        stream=None,
        response_ms=None,
        hidden_params=None,
        **params,
    ):
        if stream:
        if stream is not None and stream == True:
            object = "chat.completion.chunk"
            choices = [StreamingChoices()]
        else:
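The hunk above relaxes the ModelResponse stream default from False to None and makes the streaming check explicit; when stream is truthy, the response is initialized as a streaming chunk. A minimal sketch of the effect, assuming only the constructor shown above (not part of the commit):

    from litellm.utils import ModelResponse

    non_streaming = ModelResponse()          # stream defaults to None -> regular chat.completion object
    streaming = ModelResponse(stream=True)   # -> object "chat.completion.chunk" with StreamingChoices
    print(non_streaming.object, streaming.object)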
@@ -1483,6 +1485,33 @@ class Logging:
                        user_id=kwargs.get("user", None),
                        print_verbose=print_verbose,
                    )
                if callback == "datadog":
                    global dataDogLogger
                    verbose_logger.debug("reaches datadog for success logging!")
                    kwargs = {}
                    for k, v in self.model_call_details.items():
                        if (
                            k != "original_response"
                        ):  # copy.deepcopy raises errors as this could be a coroutine
                            kwargs[k] = v
                    # this only logs streaming once, complete_streaming_response exists i.e when stream ends
                    if self.stream:
                        verbose_logger.debug(
                            f"datadog: is complete_streaming_response in kwargs: {kwargs.get('complete_streaming_response', None)}"
                        )
                        if complete_streaming_response is None:
                            continue
                        else:
                            print_verbose("reaches datadog for streaming logging!")
                            result = kwargs["complete_streaming_response"]
                    dataDogLogger.log_event(
                        kwargs=kwargs,
                        response_obj=result,
                        start_time=start_time,
                        end_time=end_time,
                        user_id=kwargs.get("user", None),
                        print_verbose=print_verbose,
                    )
                if callback == "generic":
                    global genericAPILogger
                    verbose_logger.debug("reaches langfuse for success logging!")
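The new DataDog branch mirrors the other success callbacks: for streaming calls it logs only once the complete_streaming_response is available. A hedged sketch of enabling it from application code, assuming the usual litellm callback interface; the environment variable names are an assumption about the DataDog integration, not something shown in this diff:

    import os
    import litellm

    # Assumed environment for the DataDog integration (names are an assumption).
    os.environ["DD_API_KEY"] = "..."
    os.environ["DD_SITE"] = "us5.datadoghq.com"

    litellm.success_callback = ["datadog"]  # routes success events through dataDogLogger.log_event

    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
    )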
@@ -1805,7 +1834,12 @@ class Logging:
                        )
                        result = kwargs["async_complete_streaming_response"]
                        # only add to cache once we have a complete streaming response
                        litellm.cache.add_cache(result, **kwargs)
                        if litellm.cache is not None and not isinstance(
                            litellm.cache.cache, S3Cache
                        ):
                            await litellm.cache.async_add_cache(result, **kwargs)
                        else:
                            litellm.cache.add_cache(result, **kwargs)
            if isinstance(callback, CustomLogger):  # custom logger class
                print_verbose(
                    f"Running Async success callback: {callback}; self.stream: {self.stream}; async_complete_streaming_response: {self.model_call_details.get('async_complete_streaming_response', None)} result={result}"
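The caching change routes complete streaming responses through async_add_cache unless the backend is S3Cache, which falls back to the synchronous add_cache. A hedged sketch of configuring a cache that takes the async path; the exact Cache constructor arguments are an assumption and should be checked against litellm.caching:

    import litellm
    from litellm.caching import Cache

    # Redis-backed cache: the async success handler above awaits async_add_cache().
    litellm.cache = Cache(type="redis", host="localhost", port="6379", password="...")

    # An S3-backed cache (type="s3") would instead hit the synchronous add_cache() branch.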
@@ -2601,7 +2635,7 @@ def client(original_function):
            if (
                kwargs.get("max_tokens", None) is not None
                and model is not None
                and litellm.drop_params
                and litellm.modify_params
                == True  # user is okay with params being modified
                and (
                    call_type == CallTypes.acompletion.value
@@ -2818,7 +2852,9 @@ def client(original_function):
                    ):
                        if len(cached_result) == 1 and cached_result[0] is None:
                            cached_result = None
                    elif isinstance(litellm.cache.cache, RedisSemanticCache):
                    elif isinstance(
                        litellm.cache.cache, RedisSemanticCache
                    ) or isinstance(litellm.cache.cache, RedisCache):
                        preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
                        kwargs["preset_cache_key"] = (
                            preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
@@ -3853,7 +3889,9 @@ def completion_cost(
                * n
            )
        else:
            raise Exception(f"Model={model} not found in completion cost model map")
            raise Exception(
                f"Model={image_gen_model_name} not found in completion cost model map"
            )
    # Calculate cost based on prompt_tokens, completion_tokens
    if (
        "togethercomputer" in model
@@ -4281,6 +4319,7 @@ def get_optional_params(
        and custom_llm_provider != "together_ai"
        and custom_llm_provider != "mistral"
        and custom_llm_provider != "anthropic"
        and custom_llm_provider != "cohere_chat"
        and custom_llm_provider != "bedrock"
        and custom_llm_provider != "ollama_chat"
    ):
@@ -4412,6 +4451,31 @@ def get_optional_params(
            optional_params["presence_penalty"] = presence_penalty
        if stop is not None:
            optional_params["stop_sequences"] = stop
    elif custom_llm_provider == "cohere_chat":
        ## check if unsupported param passed in
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
        )
        _check_valid_arg(supported_params=supported_params)
        # handle cohere params
        if stream:
            optional_params["stream"] = stream
        if temperature is not None:
            optional_params["temperature"] = temperature
        if max_tokens is not None:
            optional_params["max_tokens"] = max_tokens
        if n is not None:
            optional_params["num_generations"] = n
        if top_p is not None:
            optional_params["p"] = top_p
        if frequency_penalty is not None:
            optional_params["frequency_penalty"] = frequency_penalty
        if presence_penalty is not None:
            optional_params["presence_penalty"] = presence_penalty
        if stop is not None:
            optional_params["stop_sequences"] = stop
        if tools is not None:
            optional_params["tools"] = tools
    elif custom_llm_provider == "maritalk":
        ## check if unsupported param passed in
        supported_params = get_supported_openai_params(
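The cohere_chat branch maps OpenAI-style parameters onto Cohere's chat API names (n becomes num_generations, top_p becomes p) and passes tools through unchanged. A minimal usage sketch, assuming "command-r" is listed in litellm.cohere_chat_models so the provider is auto-detected:

    import litellm

    response = litellm.completion(
        model="command-r",   # detected as cohere_chat via litellm.cohere_chat_models
        messages=[{"role": "user", "content": "Summarize RFC 2616 in one line."}],
        n=2,                 # forwarded to Cohere as num_generations
        top_p=0.9,           # forwarded as p
        max_tokens=64,
    )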
@@ -5095,6 +5159,19 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
            "stop",
            "n",
        ]
    elif custom_llm_provider == "cohere_chat":
        return [
            "stream",
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "presence_penalty",
            "stop",
            "n",
            "tools",
            "tool_choice",
        ]
    elif custom_llm_provider == "maritalk":
        return [
            "stream",
@@ -5155,6 +5232,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
            "max_tokens",
            "tools",
            "tool_choice",
            "response_format",
        ]
    elif custom_llm_provider == "replicate":
        return [
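The new cohere_chat entry added two hunks above can be inspected directly; a small sketch using the function signature shown in these hunks:

    from litellm.utils import get_supported_openai_params

    params = get_supported_openai_params(model="command-r", custom_llm_provider="cohere_chat")
    assert "tools" in params and "tool_choice" in params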
@@ -5271,6 +5349,40 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
        ]


def get_formatted_prompt(
    data: dict,
    call_type: Literal[
        "completion",
        "embedding",
        "image_generation",
        "audio_transcription",
        "moderation",
    ],
) -> str:
    """
    Extracts the prompt from the input data based on the call type.

    Returns a string.
    """
    prompt = ""
    if call_type == "completion":
        for m in data["messages"]:
            if "content" in m and isinstance(m["content"], str):
                prompt += m["content"]
    elif call_type == "embedding" or call_type == "moderation":
        if isinstance(data["input"], str):
            prompt = data["input"]
        elif isinstance(data["input"], list):
            for m in data["input"]:
                prompt += m
    elif call_type == "image_generation":
        prompt = data["prompt"]
    elif call_type == "audio_transcription":
        if "prompt" in data:
            prompt = data["prompt"]
    return prompt


def get_llm_provider(
    model: str,
    custom_llm_provider: Optional[str] = None,
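get_formatted_prompt concatenates the user-visible text from a request payload so downstream consumers (for example cost or content tracking) can work with a single string. A short usage sketch based on the function as added above:

    from litellm.utils import get_formatted_prompt

    data = {
        "messages": [
            {"role": "system", "content": "You are terse."},
            {"role": "user", "content": "Name one ocean."},
        ]
    }
    print(get_formatted_prompt(data=data, call_type="completion"))
    # -> "You are terse.Name one ocean."  (contents are concatenated without separators)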
@@ -5311,6 +5423,17 @@ def get_llm_provider(
            # groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
            api_base = "https://api.groq.com/openai/v1"
            dynamic_api_key = get_secret("GROQ_API_KEY")
        elif custom_llm_provider == "fireworks_ai":
            # fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
            if not model.startswith("accounts/fireworks/models"):
                model = f"accounts/fireworks/models/{model}"
            api_base = "https://api.fireworks.ai/inference/v1"
            dynamic_api_key = (
                get_secret("FIREWORKS_API_KEY")
                or get_secret("FIREWORKS_AI_API_KEY")
                or get_secret("FIREWORKSAI_API_KEY")
                or get_secret("FIREWORKS_AI_TOKEN")
            )
        elif custom_llm_provider == "mistral":
            # mistral is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.mistral.ai
            api_base = (
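For fireworks_ai, the model name is normalized onto the accounts/fireworks/models/ prefix and the API key is resolved from several environment variables. A hedged sketch, assuming get_llm_provider returns a (model, provider, api_key, api_base) tuple as it does elsewhere in this file; the model name is illustrative:

    import os
    from litellm.utils import get_llm_provider

    os.environ["FIREWORKS_AI_API_KEY"] = "fw-..."

    model, provider, api_key, api_base = get_llm_provider(
        model="mixtral-8x7b-instruct", custom_llm_provider="fireworks_ai"
    )
    # model    -> "accounts/fireworks/models/mixtral-8x7b-instruct"
    # api_base -> "https://api.fireworks.ai/inference/v1"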
@@ -5382,6 +5505,9 @@ def get_llm_provider(
        ## cohere
        elif model in litellm.cohere_models or model in litellm.cohere_embedding_models:
            custom_llm_provider = "cohere"
        ## cohere chat models
        elif model in litellm.cohere_chat_models:
            custom_llm_provider = "cohere_chat"
        ## replicate
        elif model in litellm.replicate_models or (":" in model and len(model) > 64):
            model_parts = model.split(":")
@@ -5997,7 +6123,9 @@ def validate_environment(model: Optional[str] = None) -> dict:


def set_callbacks(callback_list, function_id=None):
    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, dynamoLogger, s3Logger
    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, dynamoLogger, s3Logger, dataDogLogger

    try:
        for callback in callback_list:
            print_verbose(f"callback: {callback}")
@@ -6063,6 +6191,8 @@ def set_callbacks(callback_list, function_id=None):
                promptLayerLogger = PromptLayerLogger()
            elif callback == "langfuse":
                langFuseLogger = LangFuseLogger()
            elif callback == "datadog":
                dataDogLogger = DataDogLogger()
            elif callback == "dynamodb":
                dynamoLogger = DyanmoDBLogger()
            elif callback == "s3":
@@ -6433,7 +6563,7 @@ def convert_to_model_response_object(
                    "system_fingerprint"
                ]

            if "model" in response_object:
            if "model" in response_object and model_response_object.model is None:
                model_response_object.model = response_object["model"]

            if start_time is not None and end_time is not None:
@@ -7422,7 +7552,9 @@ def exception_type(
                        model=model,
                        response=original_exception.response,
                    )
        elif custom_llm_provider == "cohere":  # Cohere
        elif (
            custom_llm_provider == "cohere" or custom_llm_provider == "cohere_chat"
        ):  # Cohere
            if (
                "invalid api token" in error_str
                or "No API key provided." in error_str
@@ -8201,8 +8333,10 @@ def get_secret(
    default_value: Optional[Union[str, bool]] = None,
):
    key_management_system = litellm._key_management_system
    key_management_settings = litellm._key_management_settings
    if secret_name.startswith("os.environ/"):
        secret_name = secret_name.replace("os.environ/", "")

    try:
        if litellm.secret_manager_client is not None:
            try:
@@ -8210,6 +8344,13 @@ def get_secret(
                key_manager = "local"
                if key_management_system is not None:
                    key_manager = key_management_system.value

                if key_management_settings is not None:
                    if (
                        secret_name not in key_management_settings.hosted_keys
                    ):  # allow user to specify which keys to check in hosted key manager
                        key_manager = "local"

                if (
                    key_manager == KeyManagementSystem.AZURE_KEY_VAULT
                    or type(client).__module__ + "." + type(client).__name__
@@ -8245,9 +8386,30 @@ def get_secret(
                    secret = response.plaintext.decode(
                        "utf-8"
                    )  # assumes the original value was encoded with utf-8
                elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
                    try:
                        get_secret_value_response = client.get_secret_value(
                            SecretId=secret_name
                        )
                        print_verbose(
                            f"get_secret_value_response: {get_secret_value_response}"
                        )
                    except Exception as e:
                        print_verbose(f"An error occurred - {str(e)}")
                        # For a list of exceptions thrown, see
                        # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
                        raise e

                    # assume there is 1 secret per secret_name
                    secret_dict = json.loads(get_secret_value_response["SecretString"])
                    print_verbose(f"secret_dict: {secret_dict}")
                    for k, v in secret_dict.items():
                        secret = v
                    print_verbose(f"secret: {secret}")
                else:  # assume the default is infisicial client
                    secret = client.get_secret(secret_name).secret_value
            except Exception as e:  # check if it's in os.environ
                print_verbose(f"An exception occurred - {str(e)}")
                secret = os.getenv(secret_name)
            try:
                secret_value_as_bool = ast.literal_eval(secret)
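The AWS Secrets Manager branch expects the SecretString payload to be a JSON object and, per the "assume there is 1 secret per secret_name" comment, keeps the last value it iterates over. A minimal sketch of the payload shape that loop parses (the key and value here are placeholders):

    import json

    # Example SecretString payload as stored in AWS Secrets Manager (placeholder values).
    secret_string = '{"OPENAI_API_KEY": "sk-..."}'

    secret = None
    secret_dict = json.loads(secret_string)
    for k, v in secret_dict.items():  # same loop as in get_secret(): the last value wins
        secret = v
    print(secret)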
@@ -8555,6 +8717,29 @@ class CustomStreamWrapper:
        except:
            raise ValueError(f"Unable to parse response. Original response: {chunk}")

    def handle_cohere_chat_chunk(self, chunk):
        chunk = chunk.decode("utf-8")
        data_json = json.loads(chunk)
        print_verbose(f"chunk: {chunk}")
        try:
            text = ""
            is_finished = False
            finish_reason = ""
            if "text" in data_json:
                text = data_json["text"]
            elif "is_finished" in data_json and data_json["is_finished"] == True:
                is_finished = data_json["is_finished"]
                finish_reason = data_json["finish_reason"]
            else:
                return
            return {
                "text": text,
                "is_finished": is_finished,
                "finish_reason": finish_reason,
            }
        except:
            raise ValueError(f"Unable to parse response. Original response: {chunk}")

    def handle_azure_chunk(self, chunk):
        is_finished = False
        finish_reason = ""
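handle_cohere_chat_chunk recognizes two event shapes in Cohere's chat stream, a text delta and a terminal is_finished event, and drops anything else by returning None. Illustrative payloads, derived from the parser above rather than from Cohere's documentation:

    # Byte-encoded JSON lines the handler above accepts; other shapes return None.
    text_event = b'{"text": "Hello"}'
    done_event = b'{"is_finished": true, "finish_reason": "COMPLETE"}'
    other_event = b'{"citations": []}'   # no "text" and no "is_finished" -> handler returns None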
@@ -8667,6 +8852,27 @@ class CustomStreamWrapper:
            traceback.print_exc()
            raise e

    def handle_azure_text_completion_chunk(self, chunk):
        try:
            print_verbose(f"\nRaw OpenAI Chunk\n{chunk}\n")
            text = ""
            is_finished = False
            finish_reason = None
            choices = getattr(chunk, "choices", [])
            if len(choices) > 0:
                text = choices[0].text
                if choices[0].finish_reason is not None:
                    is_finished = True
                    finish_reason = choices[0].finish_reason
            return {
                "text": text,
                "is_finished": is_finished,
                "finish_reason": finish_reason,
            }

        except Exception as e:
            raise e

    def handle_openai_text_completion_chunk(self, chunk):
        try:
            print_verbose(f"\nRaw OpenAI Chunk\n{chunk}\n")
@@ -9032,7 +9238,41 @@ class CustomStreamWrapper:
            try:
                if hasattr(chunk, "candidates") == True:
                    try:
                        completion_obj["content"] = chunk.text
                        try:
                            completion_obj["content"] = chunk.text
                        except Exception as e:
                            if "Part has no text." in str(e):
                                ## check for function calling
                                function_call = (
                                    chunk.candidates[0]
                                    .content.parts[0]
                                    .function_call
                                )
                                args_dict = {}
                                for k, v in function_call.args.items():
                                    args_dict[k] = v
                                args_str = json.dumps(args_dict)
                                _delta_obj = litellm.utils.Delta(
                                    content=None,
                                    tool_calls=[
                                        {
                                            "id": f"call_{str(uuid.uuid4())}",
                                            "function": {
                                                "arguments": args_str,
                                                "name": function_call.name,
                                            },
                                            "type": "function",
                                        }
                                    ],
                                )
                                _streaming_response = StreamingChoices(
                                    delta=_delta_obj
                                )
                                _model_response = ModelResponse(stream=True)
                                _model_response.choices = [_streaming_response]
                                response_obj = {"original_chunk": _model_response}
                            else:
                                raise e
                        if (
                            hasattr(chunk.candidates[0], "finish_reason")
                            and chunk.candidates[0].finish_reason.name
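When a Vertex AI / Gemini chunk carries a function call instead of text (the "Part has no text." error), the branch above repackages it as an OpenAI-style tool call inside a streaming delta. A sketch of the resulting tool_calls entry, using the same fields the code constructs; the function name and arguments are illustrative:

    # Shape of the tool_calls entry built above (the id is a fresh uuid4 at runtime).
    delta_tool_call = {
        "id": "call_<uuid4>",
        "type": "function",
        "function": {
            "name": "get_current_weather",              # function_call.name (illustrative)
            "arguments": '{"location": "Boston, MA"}',  # json.dumps(function_call.args)
        },
    }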
@@ -9043,7 +9283,7 @@ class CustomStreamWrapper:
                                chunk.candidates[0].finish_reason.name
                            )
                        )
                    except:
                    except Exception as e:
                        if chunk.candidates[0].finish_reason.name == "SAFETY":
                            raise Exception(
                                f"The response was blocked by VertexAI. {str(chunk)}"
@@ -9063,6 +9303,15 @@ class CustomStreamWrapper:
                    model_response.choices[0].finish_reason = response_obj[
                        "finish_reason"
                    ]
            elif self.custom_llm_provider == "cohere_chat":
                response_obj = self.handle_cohere_chat_chunk(chunk)
                if response_obj is None:
                    return
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
                    model_response.choices[0].finish_reason = response_obj[
                        "finish_reason"
                    ]
            elif self.custom_llm_provider == "bedrock":
                if self.sent_last_chunk:
                    raise StopIteration
@@ -9140,6 +9389,14 @@ class CustomStreamWrapper:
                    model_response.choices[0].finish_reason = response_obj[
                        "finish_reason"
                    ]
            elif self.custom_llm_provider == "azure_text":
                response_obj = self.handle_azure_text_completion_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                print_verbose(f"completion obj content: {completion_obj['content']}")
                if response_obj["is_finished"]:
                    model_response.choices[0].finish_reason = response_obj[
                        "finish_reason"
                    ]
            elif self.custom_llm_provider == "cached_response":
                response_obj = {
                    "text": chunk.choices[0].delta.content,
@@ -9375,14 +9632,18 @@ class CustomStreamWrapper:
    def __next__(self):
        try:
            while True:
                if isinstance(self.completion_stream, str) or isinstance(
                    self.completion_stream, bytes
                if (
                    isinstance(self.completion_stream, str)
                    or isinstance(self.completion_stream, bytes)
                    or isinstance(self.completion_stream, ModelResponse)
                ):
                    chunk = self.completion_stream
                else:
                    chunk = next(self.completion_stream)
                if chunk is not None and chunk != b"":
                    print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
                    print_verbose(
                        f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}; custom_llm_provider: {self.custom_llm_provider}"
                    )
                    response: Optional[ModelResponse] = self.chunk_creator(chunk=chunk)
                    print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}")
@@ -9417,6 +9678,7 @@ class CustomStreamWrapper:
                or self.custom_llm_provider == "azure"
                or self.custom_llm_provider == "custom_openai"
                or self.custom_llm_provider == "text-completion-openai"
                or self.custom_llm_provider == "azure_text"
                or self.custom_llm_provider == "huggingface"
                or self.custom_llm_provider == "ollama"
                or self.custom_llm_provider == "ollama_chat"