Merge branch 'main' into feat/friendliai

commit 776c75c1e5
Wonseok Lee (Jack) authored 2024-06-13 09:59:56 +09:00, committed by GitHub
99 changed files with 202794 additions and 632 deletions

@@ -30,7 +30,7 @@ from dataclasses import (
dataclass,
field,
)
import os
import litellm._service_logger # for storing API inputs, outputs, and metadata
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
from litellm.caching import DualCache
@@ -49,9 +49,9 @@ except (ImportError, AttributeError):
filename = pkg_resources.resource_filename(__name__, "llms/tokenizers")
os.environ["TIKTOKEN_CACHE_DIR"] = (
filename # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071
)
os.environ["TIKTOKEN_CACHE_DIR"] = os.getenv(
"CUSTOM_TIKTOKEN_CACHE_DIR", filename
) # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071
encoding = tiktoken.get_encoding("cl100k_base")
from importlib import resources
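Note: with this change the on-disk tiktoken cache can be redirected via the CUSTOM_TIKTOKEN_CACHE_DIR environment variable, falling back to the bundled llms/tokenizers copy as before. A minimal sketch, assuming the variable is set before litellm is first imported; the path is illustrative:

import os

# must be set before importing litellm so the module-level TIKTOKEN_CACHE_DIR assignment sees it
os.environ["CUSTOM_TIKTOKEN_CACHE_DIR"] = "/tmp/tiktoken_cache"

import litellm  # TIKTOKEN_CACHE_DIR now points at the custom directory instead of the bundled copy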
@@ -63,6 +63,11 @@ claude_json_str = json.dumps(json_data)
import importlib.metadata
from ._logging import verbose_logger
from .types.router import LiteLLM_Params
from .types.llms.openai import (
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionDeltaToolCallChunk,
)
from .integrations.traceloop import TraceloopLogger
from .integrations.athina import AthinaLogger
from .integrations.helicone import HeliconeLogger
@@ -933,7 +938,6 @@ class TextCompletionResponse(OpenAIObject):
object=None,
**params,
):
if stream:
object = "text_completion.chunk"
choices = [TextChoices()]
@@ -942,7 +946,6 @@ class TextCompletionResponse(OpenAIObject):
if choices is not None and isinstance(choices, list):
new_choices = []
for choice in choices:
if isinstance(choice, TextChoices):
_new_choice = choice
elif isinstance(choice, dict):
@@ -1018,7 +1021,6 @@ class ImageObject(OpenAIObject):
revised_prompt: Optional[str] = None
def __init__(self, b64_json=None, url=None, revised_prompt=None):
super().__init__(b64_json=b64_json, url=url, revised_prompt=revised_prompt)
def __contains__(self, key):
@@ -1342,28 +1344,29 @@ class Logging:
)
else:
verbose_logger.debug(f"\033[92m{curl_command}\033[0m\n")
# log raw request to provider (like LangFuse)
try:
# [Non-blocking Extra Debug Information in metadata]
_litellm_params = self.model_call_details.get("litellm_params", {})
_metadata = _litellm_params.get("metadata", {}) or {}
if (
litellm.turn_off_message_logging is not None
and litellm.turn_off_message_logging is True
):
# log raw request to provider (like LangFuse) -- if opted in.
if litellm.log_raw_request_response is True:
try:
# [Non-blocking Extra Debug Information in metadata]
_litellm_params = self.model_call_details.get("litellm_params", {})
_metadata = _litellm_params.get("metadata", {}) or {}
if (
litellm.turn_off_message_logging is not None
and litellm.turn_off_message_logging is True
):
_metadata["raw_request"] = (
"redacted by litellm. \
'litellm.turn_off_message_logging=True'"
)
else:
_metadata["raw_request"] = str(curl_command)
except Exception as e:
_metadata["raw_request"] = (
"redacted by litellm. \
'litellm.turn_off_message_logging=True'"
"Unable to Log \
raw request: {}".format(
str(e)
)
)
else:
_metadata["raw_request"] = str(curl_command)
except Exception as e:
_metadata["raw_request"] = (
"Unable to Log \
raw request: {}".format(
str(e)
)
)
if self.logger_fn and callable(self.logger_fn):
try:
self.logger_fn(
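Note: raw request logging to callback integrations (e.g. Langfuse) is now opt-in via litellm.log_raw_request_response. A minimal sketch; the model name is illustrative, the flag name comes from this hunk:

import litellm

litellm.log_raw_request_response = True  # callbacks receive metadata["raw_request"] (the generated curl command)

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)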
@@ -1621,7 +1624,6 @@ class Logging:
end_time=end_time,
)
except Exception as e:
complete_streaming_response = None
else:
self.sync_streaming_chunks.append(result)
@@ -2391,7 +2393,6 @@ class Logging:
"async_complete_streaming_response"
in self.model_call_details
):
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=self.model_call_details[
@@ -2730,7 +2731,7 @@ class Logging:
only redacts when litellm.turn_off_message_logging == True
"""
# check if user opted out of logging message/response to callbacks
if litellm.turn_off_message_logging == True:
if litellm.turn_off_message_logging is True:
# remove messages, prompts, input, response from logging
self.model_call_details["messages"] = [
{"role": "user", "content": "redacted-by-litellm"}
@@ -3250,7 +3251,7 @@ def client(original_function):
stream=kwargs.get("stream", False),
)
if kwargs.get("stream", False) == True:
if kwargs.get("stream", False) is True:
cached_result = CustomStreamWrapper(
completion_stream=cached_result,
model=model,
@@ -4030,7 +4031,10 @@ def openai_token_counter(
"""
print_verbose(f"LiteLLM: Utils - Counting tokens for OpenAI model={model}")
try:
encoding = tiktoken.encoding_for_model(model)
if "gpt-4o" in model:
encoding = tiktoken.get_encoding("o200k_base")
else:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
print_verbose("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
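Note: gpt-4o models use the o200k_base tokenizer, which older tiktoken model maps may not resolve, hence the explicit branch. A standalone sketch of the same selection logic:

import tiktoken

def pick_encoding(model: str) -> tiktoken.Encoding:
    if "gpt-4o" in model:
        return tiktoken.get_encoding("o200k_base")  # gpt-4o family
    try:
        return tiktoken.encoding_for_model(model)
    except KeyError:
        return tiktoken.get_encoding("cl100k_base")  # same fallback as openai_token_counter

print(len(pick_encoding("gpt-4o").encode("hello world")))  # token count for a short string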
@@ -4894,6 +4898,18 @@ def get_optional_params_embeddings(
)
final_params = {**optional_params, **kwargs}
return final_params
if custom_llm_provider == "vertex_ai":
supported_params = get_supported_openai_params(
model=model,
custom_llm_provider="vertex_ai",
request_type="embeddings",
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.VertexAITextEmbeddingConfig().map_openai_params(
non_default_params=non_default_params, optional_params={}
)
final_params = {**optional_params, **kwargs}
return final_params
if custom_llm_provider == "vertex_ai":
if len(non_default_params.keys()) > 0:
if litellm.drop_params is True: # drop the unsupported non-default values
@@ -4927,7 +4943,18 @@ def get_optional_params_embeddings(
message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
)
return {**non_default_params, **kwargs}
if custom_llm_provider == "mistral":
supported_params = get_supported_openai_params(
model=model,
custom_llm_provider="mistral",
request_type="embeddings",
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.MistralEmbeddingConfig().map_openai_params(
non_default_params=non_default_params, optional_params={}
)
final_params = {**optional_params, **kwargs}
return final_params
if (
custom_llm_provider != "openai"
and custom_llm_provider != "azure"
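Note: mistral embedding calls now map their OpenAI-style optional params through MistralEmbeddingConfig instead of falling through to the generic unsupported-param error. A hedged usage sketch; whether a given param (e.g. encoding_format) is accepted depends on the config's supported list:

import litellm

resp = litellm.embedding(
    model="mistral/mistral-embed",
    input=["hello world"],
    encoding_format="float",  # mapped by MistralEmbeddingConfig if supported, otherwise rejected as before
)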
@@ -6166,13 +6193,16 @@ def get_api_base(
if litellm.model_alias_map and model in litellm.model_alias_map:
model = litellm.model_alias_map[model]
try:
model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
get_llm_provider(
model=model,
custom_llm_provider=_optional_params.custom_llm_provider,
api_base=_optional_params.api_base,
api_key=_optional_params.api_key,
)
(
model,
custom_llm_provider,
dynamic_api_key,
dynamic_api_base,
) = get_llm_provider(
model=model,
custom_llm_provider=_optional_params.custom_llm_provider,
api_base=_optional_params.api_base,
api_key=_optional_params.api_key,
)
except Exception as e:
verbose_logger.debug("Error occurred in getting api base - {}".format(str(e)))
@@ -6220,7 +6250,7 @@ def get_first_chars_messages(kwargs: dict) -> str:
def get_supported_openai_params(
model: str,
custom_llm_provider: str,
custom_llm_provider: Optional[str] = None,
request_type: Literal["chat_completion", "embeddings"] = "chat_completion",
) -> Optional[list]:
"""
@@ -6235,6 +6265,11 @@ def get_supported_openai_params(
- List if custom_llm_provider is mapped
- None if unmapped
"""
if not custom_llm_provider:
try:
custom_llm_provider = litellm.get_llm_provider(model=model)[1]
except BadRequestError:
return None
if custom_llm_provider == "bedrock":
return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "ollama":
@@ -6328,7 +6363,10 @@ def get_supported_openai_params(
"max_retries",
]
elif custom_llm_provider == "mistral":
return litellm.MistralConfig().get_supported_openai_params()
if request_type == "chat_completion":
return litellm.MistralConfig().get_supported_openai_params()
elif request_type == "embeddings":
return litellm.MistralEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "replicate":
return [
"stream",
@@ -6370,7 +6408,10 @@ def get_supported_openai_params(
elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
return ["temperature", "top_p", "stream", "n", "stop", "max_tokens"]
elif custom_llm_provider == "vertex_ai":
return litellm.VertexAIConfig().get_supported_openai_params()
if request_type == "chat_completion":
return litellm.VertexAIConfig().get_supported_openai_params()
elif request_type == "embeddings":
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "sagemaker":
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
elif custom_llm_provider == "aleph_alpha":
@@ -6577,6 +6618,9 @@ def get_llm_provider(
or get_secret("FIREWORKSAI_API_KEY")
or get_secret("FIREWORKS_AI_TOKEN")
)
elif custom_llm_provider == "azure_ai":
api_base = api_base or get_secret("AZURE_AI_API_BASE") # type: ignore
dynamic_api_key = api_key or get_secret("AZURE_AI_API_KEY")
elif custom_llm_provider == "mistral":
# mistral is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.mistral.ai
api_base = (
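Note: the new azure_ai provider resolves credentials from AZURE_AI_API_BASE / AZURE_AI_API_KEY. A hedged sketch: only the env var names come from this diff; the azure_ai/<deployment> model prefix and endpoint URL are assumptions for illustration:

import os
import litellm

os.environ["AZURE_AI_API_BASE"] = "https://my-endpoint.inference.ai.azure.com"  # illustrative
os.environ["AZURE_AI_API_KEY"] = "..."  # illustrative

resp = litellm.completion(
    model="azure_ai/my-deployment",  # assumed routing prefix for the new provider
    messages=[{"role": "user", "content": "hi"}],
)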
@@ -7458,7 +7502,6 @@ def validate_environment(model: Optional[str] = None) -> dict:
def set_callbacks(callback_list, function_id=None):
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, logfireLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger
try:
@@ -8767,6 +8810,13 @@ def exception_type(
response=original_exception.response,
litellm_debug_info=extra_information,
)
if "Request failed during generation" in error_str:
# this is an internal server error from predibase
raise litellm.InternalServerError(
message=f"PredibaseException - {error_str}",
llm_provider="predibase",
model=model,
)
elif hasattr(original_exception, "status_code"):
if original_exception.status_code == 500:
exception_mapping_worked = True
@@ -9085,7 +9135,7 @@ def exception_type(
):
exception_mapping_worked = True
raise RateLimitError(
message=f"VertexAIException RateLimitError - {error_str}",
message=f"litellm.RateLimitError: VertexAIException - {error_str}",
model=model,
llm_provider="vertex_ai",
litellm_debug_info=extra_information,
@@ -9097,7 +9147,14 @@ def exception_type(
),
),
)
elif "500 Internal Server Error" in error_str:
exception_mapping_worked = True
raise ServiceUnavailableError(
message=f"litellm.ServiceUnavailableError: VertexAIException - {error_str}",
model=model,
llm_provider="vertex_ai",
litellm_debug_info=extra_information,
)
if hasattr(original_exception, "status_code"):
if original_exception.status_code == 400:
exception_mapping_worked = True
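Note: both new mappings surface as typed litellm exceptions, so callers can branch or retry on them. Minimal sketch; the model name is illustrative:

import litellm

try:
    litellm.completion(
        model="predibase/llama-3-8b-instruct",
        messages=[{"role": "user", "content": "hi"}],
    )
except litellm.InternalServerError as e:      # "Request failed during generation" from Predibase maps here
    print("retryable provider error:", e)
except litellm.ServiceUnavailableError as e:  # Vertex "500 Internal Server Error" maps here
    print("provider unavailable:", e)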
@@ -10048,6 +10105,14 @@ def get_secret(
return oidc_token
else:
raise ValueError("Github OIDC provider failed")
elif oidc_provider == "azure":
# https://azure.github.io/azure-workload-identity/docs/quick-start.html
azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
if azure_federated_token_file is None:
raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
with open(azure_federated_token_file, "r") as f:
oidc_token = f.read()
return oidc_token
else:
raise ValueError("Unsupported OIDC provider")
@@ -11311,7 +11376,6 @@ class CustomStreamWrapper:
raise StopIteration
response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
@@ -11326,6 +11390,10 @@ class CustomStreamWrapper:
completion_tokens=response_obj["usage"]["outputTokens"],
total_tokens=response_obj["usage"]["totalTokens"],
)
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider == "sagemaker":
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
response_obj = self.handle_sagemaker_stream(chunk)
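Note: tool-call chunks surfaced by this GenericStreamingChunk branch (Bedrock converse-style usage keys) are now attached to the streamed delta, so function calls are visible while streaming. A hedged sketch; the model id, tool schema, and delta attribute access are illustrative assumptions:

import litellm

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    },
}]

stream = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "weather in SF?"}],
    tools=tools,
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    if getattr(delta, "tool_calls", None):  # populated from response_obj["tool_use"] per this hunk
        print(delta.tool_calls)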
@@ -11342,7 +11410,6 @@ class CustomStreamWrapper:
new_chunk = self.completion_stream[:chunk_size]
completion_obj["content"] = new_chunk
self.completion_stream = self.completion_stream[chunk_size:]
time.sleep(0.05)
elif self.custom_llm_provider == "palm":
# fake streaming
response_obj = {}
@@ -11355,7 +11422,6 @@ class CustomStreamWrapper:
new_chunk = self.completion_stream[:chunk_size]
completion_obj["content"] = new_chunk
self.completion_stream = self.completion_stream[chunk_size:]
time.sleep(0.05)
elif self.custom_llm_provider == "ollama":
response_obj = self.handle_ollama_stream(chunk)
completion_obj["content"] = response_obj["text"]
@@ -11442,7 +11508,7 @@ class CustomStreamWrapper:
# for azure, we need to pass the model from the original chunk
self.model = chunk.model
response_obj = self.handle_openai_chat_completion_chunk(chunk)
if response_obj == None:
if response_obj is None:
return
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
@@ -11575,7 +11641,7 @@ class CustomStreamWrapper:
else:
if (
self.stream_options is not None
and self.stream_options["include_usage"] == True
and self.stream_options["include_usage"] is True
):
return model_response
return
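Note: the include_usage check only gates whether the trailing usage chunk is returned; requesting it is unchanged. Sketch; the model name is illustrative:

import litellm

stream = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
    stream_options={"include_usage": True},  # final chunk carries token usage
)
for chunk in stream:
    if getattr(chunk, "usage", None):
        print(chunk.usage)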
@@ -11600,8 +11666,14 @@ class CustomStreamWrapper:
return model_response
elif (
"content" in completion_obj
and isinstance(completion_obj["content"], str)
and len(completion_obj["content"]) > 0
and (
isinstance(completion_obj["content"], str)
and len(completion_obj["content"]) > 0
)
or (
"tool_calls" in completion_obj
and len(completion_obj["tool_calls"]) > 0
)
): # cannot set content of an OpenAI Object to be an empty string
hold, model_response_str = self.check_special_tokens(
chunk=completion_obj["content"],
@@ -11657,7 +11729,7 @@ class CustomStreamWrapper:
else:
## else
completion_obj["content"] = model_response_str
if self.sent_first_chunk == False:
if self.sent_first_chunk is False:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj)
@@ -11666,7 +11738,7 @@ class CustomStreamWrapper:
else:
return
elif self.received_finish_reason is not None:
if self.sent_last_chunk == True:
if self.sent_last_chunk is True:
raise StopIteration
# flush any remaining holding chunk
if len(self.holding_chunk) > 0: