forked from phoenix/litellm-mirror

Merge branch 'BerriAI:main' into main
commit 3feb0ef897

160 changed files with 16239 additions and 1783 deletions

litellm/utils.py (393 changed lines)
@@ -20,6 +20,7 @@ import datetime, time
import tiktoken
import uuid
import aiohttp
import textwrap
import logging
import asyncio, httpx, inspect
from inspect import iscoroutine

@@ -28,7 +29,9 @@ from tokenizers import Tokenizer
from dataclasses import (
    dataclass,
    field,
)  # for storing API inputs, outputs, and metadata
)

import litellm._service_logger  # for storing API inputs, outputs, and metadata

try:
    # this works in python 3.8

@@ -53,6 +56,7 @@ os.environ["TIKTOKEN_CACHE_DIR"] = (
encoding = tiktoken.get_encoding("cl100k_base")
import importlib.metadata
from ._logging import verbose_logger
from .types.router import LiteLLM_Params
from .integrations.traceloop import TraceloopLogger
from .integrations.athina import AthinaLogger
from .integrations.helicone import HeliconeLogger

@@ -67,6 +71,7 @@ from .integrations.custom_logger import CustomLogger
from .integrations.langfuse import LangFuseLogger
from .integrations.datadog import DataDogLogger
from .integrations.prometheus import PrometheusLogger
from .integrations.prometheus_services import PrometheusServicesLogger
from .integrations.dynamodb import DyanmoDBLogger
from .integrations.s3 import S3Logger
from .integrations.clickhouse import ClickhouseLogger
@@ -209,6 +214,8 @@ def map_finish_reason(
        return "stop"
    elif finish_reason == "max_tokens":  # anthropic
        return "length"
    elif finish_reason == "tool_use":  # anthropic
        return "tool_calls"
    return finish_reason
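This hunk extends litellm's finish-reason normalization so Anthropic's `max_tokens` and `tool_use` stop reasons come back as the OpenAI-style `length` and `tool_calls`. A minimal standalone sketch of the same mapping idea (illustrative only, not the library's full `map_finish_reason`):

```
def normalize_finish_reason(finish_reason: str) -> str:
    # Illustrative subset of the mapping in litellm.utils.map_finish_reason.
    mapping = {
        "stop_sequence": "stop",   # anthropic
        "max_tokens": "length",    # anthropic
        "tool_use": "tool_calls",  # anthropic
    }
    return mapping.get(finish_reason, finish_reason)


print(normalize_finish_reason("tool_use"))  # -> tool_calls
```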
@@ -235,6 +242,7 @@ class HiddenParams(OpenAIObject):

    class Config:
        extra = "allow"
        protected_namespaces = ()

    def get(self, key, default=None):
        # Custom .get() method to access attributes with a default value if the attribute doesn't exist

@@ -604,7 +612,7 @@ class ModelResponse(OpenAIObject):


class Embedding(OpenAIObject):
    embedding: list = []
    embedding: Union[list, str] = []
    index: int
    object: str
@@ -1075,6 +1083,9 @@ class Logging:
            headers = {}
            data = additional_args.get("complete_input_dict", {})
            api_base = additional_args.get("api_base", "")
            self.model_call_details["litellm_params"]["api_base"] = str(
                api_base
            )  # used for alerting
            masked_headers = {
                k: (v[:-20] + "*" * 20) if (isinstance(v, str) and len(v) > 20) else v
                for k, v in headers.items()
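The masked_headers comprehension added here replaces the last 20 characters of any long string header value with asterisks before it is logged. A self-contained example of the same expression, with dummy header values:

```
headers = {
    "Authorization": "Bearer sk-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",  # dummy value
    "Content-Type": "application/json",
}

# Same comprehension as above: mask the trailing 20 characters of long string values.
masked_headers = {
    k: (v[:-20] + "*" * 20) if (isinstance(v, str) and len(v) > 20) else v
    for k, v in headers.items()
}

print(masked_headers["Authorization"])  # Bearer sk-aaaaaaaaaaaa********************
print(masked_headers["Content-Type"])   # application/json (short values pass through unmasked)
```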
@@ -1100,7 +1111,6 @@ class Logging:
                curl_command = self.model_call_details

            # only print verbose if verbose logger is not set

            if verbose_logger.level == 0:
                # this means verbose logger was not switched on - user is in litellm.set_verbose=True
                print_verbose(f"\033[92m{curl_command}\033[0m\n")

@@ -1203,11 +1213,10 @@ class Logging:
            self.model_call_details["original_response"] = original_response
            self.model_call_details["additional_args"] = additional_args
            self.model_call_details["log_event_type"] = "post_api_call"

            # User Logging -> if you pass in a custom logging function
            print_verbose(
                f"RAW RESPONSE:\n{self.model_call_details.get('original_response', self.model_call_details)}\n\n",
                log_level="INFO",
                log_level="DEBUG",
            )
            if self.logger_fn and callable(self.logger_fn):
                try:

@@ -2013,9 +2022,6 @@ class Logging:
                    else:
                        litellm.cache.add_cache(result, **kwargs)
                if isinstance(callback, CustomLogger):  # custom logger class
                    print_verbose(
                        f"Running Async success callback: {callback}; self.stream: {self.stream}; async_complete_streaming_response: {self.model_call_details.get('async_complete_streaming_response', None)} result={result}"
                    )
                    if self.stream == True:
                        if (
                            "async_complete_streaming_response"

@@ -2399,7 +2405,6 @@ def client(original_function):
        if litellm.use_client or (
            "use_client" in kwargs and kwargs["use_client"] == True
        ):
            print_verbose(f"litedebugger initialized")
            if "lite_debugger" not in litellm.input_callback:
                litellm.input_callback.append("lite_debugger")
            if "lite_debugger" not in litellm.success_callback:

@@ -2526,7 +2531,7 @@ def client(original_function):
            ):
                rules_obj.pre_call_rules(
                    input="".join(
                        m["content"]
                        m.get("content", "")
                        for m in messages
                        if isinstance(m["content"], str)
                    ),

@@ -2573,7 +2578,7 @@ def client(original_function):
                langfuse_secret=kwargs.pop("langfuse_secret", None),
            )
            ## check if metadata is passed in
            litellm_params = {}
            litellm_params = {"api_base": ""}
            if "metadata" in kwargs:
                litellm_params["metadata"] = kwargs["metadata"]
            logging_obj.update_environment_variables(

@@ -3023,7 +3028,7 @@ def client(original_function):
                )
            ):  # allow users to control returning cached responses from the completion function
                # checking cache
                print_verbose(f"INSIDE CHECKING CACHE")
                print_verbose("INSIDE CHECKING CACHE")
                if (
                    litellm.cache is not None
                    and str(original_function.__name__)

@@ -3060,7 +3065,7 @@ def client(original_function):
                        cached_result = await litellm.cache.async_get_cache(
                            *args, **kwargs
                        )
                    else:
                    else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
                        preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
                        kwargs["preset_cache_key"] = (
                            preset_cache_key  # for streaming calls, we need to pass the preset_cache_key

@@ -3103,6 +3108,7 @@ def client(original_function):
                                    "preset_cache_key", None
                                ),
                                "stream_response": kwargs.get("stream_response", {}),
                                "api_base": kwargs.get("api_base", ""),
                            },
                            input=kwargs.get("messages", ""),
                            api_key=kwargs.get("api_key", None),

@@ -3129,6 +3135,22 @@ def client(original_function):
                                response_object=cached_result,
                                model_response_object=ModelResponse(),
                            )
                        if (
                            call_type == CallTypes.atext_completion.value
                            and isinstance(cached_result, dict)
                        ):
                            if kwargs.get("stream", False) == True:
                                cached_result = convert_to_streaming_response_async(
                                    response_object=cached_result,
                                )
                                cached_result = CustomStreamWrapper(
                                    completion_stream=cached_result,
                                    model=model,
                                    custom_llm_provider="cached_response",
                                    logging_obj=logging_obj,
                                )
                            else:
                                cached_result = TextCompletionResponse(**cached_result)
                        elif call_type == CallTypes.aembedding.value and isinstance(
                            cached_result, dict
                        ):

@@ -3197,7 +3219,13 @@ def client(original_function):
                        for val in non_null_list:
                            idx, cr = val  # (idx, cr) tuple
                            if cr is not None:
                                final_embedding_cached_response.data[idx] = cr
                                final_embedding_cached_response.data[idx] = (
                                    Embedding(
                                        embedding=cr["embedding"],
                                        index=idx,
                                        object="embedding",
                                    )
                                )
                    if len(remaining_list) == 0:
                        # LOG SUCCESS
                        cache_hit = True
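Cache hits for embeddings are now rewrapped as `Embedding` objects (instead of raw dicts) before being merged into `final_embedding_cached_response`, so cached and freshly computed items share one shape. A rough standalone sketch of that merge step, using a dataclass as a stand-in for litellm's `Embedding` model (data is made up):

```
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class Embedding:  # stand-in for litellm's Embedding(OpenAIObject)
    embedding: Union[list, str]
    index: int
    object: str


# Cached entries arrive as (index, dict-or-None) pairs; None marks a cache miss.
cached = [(0, {"embedding": [0.1, 0.2]}), (1, None), (2, {"embedding": [0.3, 0.4]})]

data: List[Optional[Embedding]] = [None] * len(cached)
remaining = []
for idx, cr in cached:
    if cr is not None:
        data[idx] = Embedding(embedding=cr["embedding"], index=idx, object="embedding")
    else:
        remaining.append(idx)  # these indices still need a real embedding call

print(data[0])    # Embedding(embedding=[0.1, 0.2], index=0, object='embedding')
print(remaining)  # [1]
```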
@@ -3236,6 +3264,7 @@ def client(original_function):
                                "stream_response": kwargs.get(
                                    "stream_response", {}
                                ),
                                "api_base": "",
                            },
                            input=kwargs.get("messages", ""),
                            api_key=kwargs.get("api_key", None),

@@ -4156,6 +4185,30 @@ def supports_function_calling(model: str):
        )


def supports_vision(model: str):
    """
    Check if the given model supports vision and return a boolean value.

    Parameters:
    model (str): The model name to be checked.

    Returns:
    bool: True if the model supports vision, False otherwise.

    Raises:
    Exception: If the given model is not found in model_prices_and_context_window.json.
    """
    if model in litellm.model_cost:
        model_info = litellm.model_cost[model]
        if model_info.get("supports_vision", False):
            return True
        return False
    else:
        raise Exception(
            f"Model not in model_prices_and_context_window.json. You passed model={model}."
        )


def supports_parallel_function_calling(model: str):
    """
    Check if the given model supports parallel function calling and return True if it does, False otherwise.
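`supports_vision` mirrors the existing `supports_function_calling` helper: it looks up the `supports_vision` flag in `litellm.model_cost` (loaded from model_prices_and_context_window.json) and raises for unknown models. A usage sketch; the helper is assumed to be re-exported at the package level like its siblings (if not, import it from `litellm.utils`), and the exact model keys depend on the local pricing map:

```
import litellm

for model in ["gpt-4-vision-preview", "gpt-3.5-turbo"]:
    try:
        print(model, litellm.supports_vision(model=model))
    except Exception as err:
        # Raised when the model key is missing from model_prices_and_context_window.json
        print(model, "not in pricing map:", err)
```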
@@ -4523,6 +4576,7 @@ def get_optional_params(
        and custom_llm_provider != "vertex_ai"
        and custom_llm_provider != "anyscale"
        and custom_llm_provider != "together_ai"
        and custom_llm_provider != "groq"
        and custom_llm_provider != "mistral"
        and custom_llm_provider != "anthropic"
        and custom_llm_provider != "cohere_chat"

@@ -4859,8 +4913,17 @@ def get_optional_params(
            optional_params["top_p"] = top_p
        if stream:
            optional_params["stream"] = stream
        if n is not None:
            optional_params["candidate_count"] = n
        if stop is not None:
            if isinstance(stop, str):
                optional_params["stop_sequences"] = [stop]
            elif isinstance(stop, list):
                optional_params["stop_sequences"] = stop
        if max_tokens is not None:
            optional_params["max_output_tokens"] = max_tokens
        if response_format is not None and response_format["type"] == "json_object":
            optional_params["response_mime_type"] = "application/json"
        if tools is not None and isinstance(tools, list):
            from vertexai.preview import generative_models
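This Vertex AI (Gemini) branch translates OpenAI-style arguments into Vertex names: `n` becomes `candidate_count`, `stop` becomes `stop_sequences`, `max_tokens` becomes `max_output_tokens`, and `response_format={"type": "json_object"}` turns on `response_mime_type="application/json"`. A standalone sketch of just that translation (illustrative; the real logic lives inside `get_optional_params`):

```
from typing import Optional, Union


def map_openai_params_to_vertex(
    n: Optional[int] = None,
    stop: Optional[Union[str, list]] = None,
    max_tokens: Optional[int] = None,
    response_format: Optional[dict] = None,
) -> dict:
    # Illustrative re-implementation of the mapping shown in the hunk above.
    params: dict = {}
    if n is not None:
        params["candidate_count"] = n
    if stop is not None:
        params["stop_sequences"] = [stop] if isinstance(stop, str) else stop
    if max_tokens is not None:
        params["max_output_tokens"] = max_tokens
    if response_format is not None and response_format.get("type") == "json_object":
        params["response_mime_type"] = "application/json"
    return params


print(map_openai_params_to_vertex(
    n=2, stop="END", max_tokens=256, response_format={"type": "json_object"}
))
```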
@@ -4878,6 +4941,17 @@ def get_optional_params(
        print_verbose(
            f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}"
        )
    elif (
        custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models
    ):
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
        )
        _check_valid_arg(supported_params=supported_params)
        optional_params = litellm.VertexAIAnthropicConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
        )
    elif custom_llm_provider == "sagemaker":
        ## check if unsupported param passed in
        supported_params = get_supported_openai_params(

@@ -5202,6 +5276,29 @@ def get_optional_params(
            optional_params["extra_body"] = (
                extra_body  # openai client supports `extra_body` param
            )
    elif custom_llm_provider == "groq":
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
        )
        _check_valid_arg(supported_params=supported_params)

        if temperature is not None:
            optional_params["temperature"] = temperature
        if max_tokens is not None:
            optional_params["max_tokens"] = max_tokens
        if top_p is not None:
            optional_params["top_p"] = top_p
        if stream is not None:
            optional_params["stream"] = stream
        if stop is not None:
            optional_params["stop"] = stop
        if tools is not None:
            optional_params["tools"] = tools
        if tool_choice is not None:
            optional_params["tool_choice"] = tool_choice
        if response_format is not None:
            optional_params["response_format"] = tool_choice

    elif custom_llm_provider == "openrouter":
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider

@@ -5258,7 +5355,9 @@ def get_optional_params(
                extra_body  # openai client supports `extra_body` param
            )
    else:  # assume passing in params for openai/azure openai
        print_verbose(f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE")
        print_verbose(
            f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model={model}, custom_llm_provider={custom_llm_provider}"
        )
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider="openai"
        )

@@ -5319,6 +5418,55 @@ def get_optional_params(
    return optional_params


def get_api_base(model: str, optional_params: dict) -> Optional[str]:
    """
    Returns the api base used for calling the model.

    Parameters:
    - model: str - the model passed to litellm.completion()
    - optional_params - the additional params passed to litellm.completion - eg. api_base, api_key, etc. See `LiteLLM_Params` - https://github.com/BerriAI/litellm/blob/f09e6ba98d65e035a79f73bc069145002ceafd36/litellm/router.py#L67

    Returns:
    - string (api_base) or None

    Example:
    ```
    from litellm import get_api_base

    get_api_base(model="gemini/gemini-pro")
    ```
    """
    _optional_params = LiteLLM_Params(**optional_params)  # convert to pydantic object
    # get llm provider
    try:
        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
            model=model
        )
    except:
        custom_llm_provider = None
    if _optional_params.api_base is not None:
        return _optional_params.api_base

    if (
        _optional_params.vertex_location is not None
        and _optional_params.vertex_project is not None
    ):
        _api_base = "{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent".format(
            _optional_params.vertex_location,
            _optional_params.vertex_project,
            _optional_params.vertex_location,
            model,
        )
        return _api_base

    if custom_llm_provider is not None and custom_llm_provider == "gemini":
        _api_base = "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent".format(
            model
        )
        return _api_base
    return None


def get_supported_openai_params(model: str, custom_llm_provider: str):
    """
    Returns the supported openai params for a given model + provider
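For reference, the Vertex AI branch of the new `get_api_base` builds the API base from the location, project, and model (note the location appears twice: once as the host prefix and once in the path). A standalone sketch of that string construction with placeholder values, without calling into litellm:

```
vertex_location = "us-central1"    # placeholder
vertex_project = "my-gcp-project"  # placeholder
model = "gemini-pro"               # placeholder

api_base = (
    "{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}"
    "/publishers/google/models/{}:streamGenerateContent"
).format(vertex_location, vertex_project, vertex_location, model)

print(api_base)
# us-central1-aiplatform.googleapis.com/v1/projects/my-gcp-project/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent
```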
@@ -5355,6 +5503,17 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
            "tools",
            "tool_choice",
        ]
    elif custom_llm_provider == "groq":
        return [
            "temperature",
            "max_tokens",
            "top_p",
            "stream",
            "stop",
            "tools",
            "tool_choice",
            "response_format",
        ]
    elif custom_llm_provider == "cohere":
        return [
            "stream",
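With this addition, callers can introspect which OpenAI parameters litellm will forward for Groq. A small usage sketch (importing from `litellm.utils`; the model name below is just a placeholder, since the Groq branch keys off the provider):

```
from litellm.utils import get_supported_openai_params

params = get_supported_openai_params(model="mixtral-8x7b-32768", custom_llm_provider="groq")
print(params)
# per the hunk above:
# ['temperature', 'max_tokens', 'top_p', 'stream', 'stop', 'tools', 'tool_choice', 'response_format']
```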
@@ -5485,6 +5644,9 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
            "stream",
            "tools",
            "tool_choice",
            "response_format",
            "n",
            "stop",
        ]
    elif custom_llm_provider == "sagemaker":
        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]

@@ -5592,6 +5754,19 @@ def get_formatted_prompt(
    return prompt


def _is_non_openai_azure_model(model: str) -> bool:
    try:
        model_name = model.split("/", 1)[1]
        if (
            model_name in litellm.cohere_chat_models
            or f"mistral/{model_name}" in litellm.mistral_chat_models
        ):
            return True
    except:
        return False
    return False


def get_llm_provider(
    model: str,
    custom_llm_provider: Optional[str] = None,

@@ -5602,6 +5777,13 @@ def get_llm_provider(
    dynamic_api_key = None
    # check if llm provider provided

    # AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
    # If User passes azure/command-r-plus -> we should send it to cohere_chat/command-r-plus
    if model.split("/", 1)[0] == "azure":
        if _is_non_openai_azure_model(model):
            custom_llm_provider = "openai"
            return model, custom_llm_provider, dynamic_api_key, api_base

    if custom_llm_provider:
        return model, custom_llm_provider, dynamic_api_key, api_base
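The `_is_non_openai_azure_model` check lets `get_llm_provider` recognize Azure AI Studio deployments of non-OpenAI models (e.g. Cohere's command-r-plus) and route them through the OpenAI-compatible client rather than the Azure OpenAI path, while plain `azure/...` deployments keep resolving to the azure provider. A hedged sketch of the expected behavior; the exact outcome depends on the cohere_chat/mistral model lists shipped at this commit:

```
from litellm.utils import get_llm_provider

# Azure AI Studio serving a Cohere chat model: handled by the new branch above.
model, provider, _, _ = get_llm_provider(model="azure/command-r-plus")
print(provider)  # expected: "openai" (OpenAI-compatible route)

# A regular Azure OpenAI deployment is unaffected.
model, provider, _, _ = get_llm_provider(model="azure/my-gpt-35-deployment")  # placeholder deployment name
print(provider)  # expected: "azure"
```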
@@ -5845,6 +6027,16 @@ def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
    return api_key


def get_utc_datetime():
    import datetime as dt
    from datetime import datetime

    if hasattr(dt, "UTC"):
        return datetime.now(dt.UTC)  # type: ignore
    else:
        return datetime.utcnow()  # type: ignore


def get_max_tokens(model: str):
    """
    Get the maximum number of output tokens allowed for a given model.
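`get_utc_datetime` prefers a timezone-aware `datetime.now(dt.UTC)` where `datetime.UTC` exists (Python 3.11+) and falls back to the older `datetime.utcnow()` elsewhere. The same feature check as a standalone snippet:

```
import datetime as dt
from datetime import datetime


def utc_now():
    # Same pattern as get_utc_datetime above: aware datetime on 3.11+,
    # naive utcnow() fallback on older interpreters.
    if hasattr(dt, "UTC"):
        return datetime.now(dt.UTC)
    return datetime.utcnow()


print(utc_now().isoformat())
```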
@@ -6466,8 +6658,11 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, k
                for detail in additional_details:
                    slack_msg += f"{detail}: {additional_details[detail]}\n"
                slack_msg += f"Traceback: {traceback_exception}"
                truncated_slack_msg = textwrap.shorten(
                    slack_msg, width=512, placeholder="..."
                )
                slack_app.client.chat_postMessage(
                    channel=alerts_channel, text=slack_msg
                    channel=alerts_channel, text=truncated_slack_msg
                )
            elif callback == "sentry":
                capture_exception(exception)
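The Slack alert body is now clipped to 512 characters with `textwrap.shorten` before posting (this is what the new `import textwrap` at the top of the file supports). Worth knowing: `shorten` also collapses whitespace, so multi-line tracebacks arrive as a single line. A quick standalone example:

```
import textwrap

slack_msg = "Model: gpt-3.5-turbo\nTraceback: " + "x" * 1000  # dummy alert text

truncated = textwrap.shorten(slack_msg, width=512, placeholder="...")
print(len(truncated))  # <= 512
print(truncated[:40])  # note: the newline has been collapsed into a space
```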
@@ -7395,7 +7590,6 @@ def exception_type(
                        message=f"AnthropicException - {original_exception.message}",
                        model=model,
                        llm_provider="anthropic",
                        request=original_exception.request,
                    )
                elif original_exception.status_code == 429:
                    exception_mapping_worked = True

@@ -7492,7 +7686,6 @@ def exception_type(
                        message=f"ReplicateException - {original_exception.message}",
                        model=model,
                        llm_provider="replicate",
                        request=original_exception.request,
                    )
                elif original_exception.status_code == 429:
                    exception_mapping_worked = True

@@ -7684,7 +7877,7 @@ def exception_type(
                )
            elif (
                "429 Quota exceeded" in error_str
                or "IndexError: list index out of range"
                or "IndexError: list index out of range" in error_str
            ):
                exception_mapping_worked = True
                raise RateLimitError(

@@ -7901,7 +8094,6 @@ def exception_type(
                        message=f"HuggingfaceException - {original_exception.message}",
                        model=model,
                        llm_provider="huggingface",
                        request=original_exception.request,
                    )
                elif original_exception.status_code == 429:
                    exception_mapping_worked = True

@@ -7961,7 +8153,6 @@ def exception_type(
                        message=f"AI21Exception - {original_exception.message}",
                        model=model,
                        llm_provider="ai21",
                        request=original_exception.request,
                    )
                if original_exception.status_code == 422:
                    exception_mapping_worked = True

@@ -8051,7 +8242,6 @@ def exception_type(
                        message=f"NLPCloudException - {original_exception.message}",
                        model=model,
                        llm_provider="nlp_cloud",
                        request=original_exception.request,
                    )
                elif (
                    original_exception.status_code == 429

@@ -8168,7 +8358,6 @@ def exception_type(
                        message=f"TogetherAIException - {original_exception.message}",
                        model=model,
                        llm_provider="together_ai",
                        request=original_exception.request,
                    )
                elif original_exception.status_code == 422:
                    exception_mapping_worked = True

@@ -8368,7 +8557,6 @@ def exception_type(
                        message=f"AzureException - {original_exception.message}",
                        model=model,
                        llm_provider="azure",
                        request=original_exception.request,
                    )
                if original_exception.status_code == 422:
                    exception_mapping_worked = True
@@ -8706,8 +8894,39 @@ class CustomStreamWrapper:
            self.holding_chunk = ""
            return hold, curr_chunk

    def handle_anthropic_text_chunk(self, chunk):
        str_line = chunk
        if isinstance(chunk, bytes):  # Handle binary data
            str_line = chunk.decode("utf-8")  # Convert bytes to string
        text = ""
        is_finished = False
        finish_reason = None
        if str_line.startswith("data:"):
            data_json = json.loads(str_line[5:])
            type_chunk = data_json.get("type", None)
            if type_chunk == "completion":
                text = data_json.get("completion")
                finish_reason = data_json.get("stop_reason")
                if finish_reason is not None:
                    is_finished = True
            return {
                "text": text,
                "is_finished": is_finished,
                "finish_reason": finish_reason,
            }
        elif "error" in str_line:
            raise ValueError(f"Unable to parse response. Original response: {str_line}")
        else:
            return {
                "text": text,
                "is_finished": is_finished,
                "finish_reason": finish_reason,
            }

    def handle_anthropic_chunk(self, chunk):
        str_line = chunk.decode("utf-8")  # Convert bytes to string
        str_line = chunk
        if isinstance(chunk, bytes):  # Handle binary data
            str_line = chunk.decode("utf-8")  # Convert bytes to string
        text = ""
        is_finished = False
        finish_reason = None
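`handle_anthropic_text_chunk` parses Anthropic text-completion SSE lines: strip the `data:` prefix, read `completion` and `stop_reason`, and report whether the stream finished. A self-contained sketch of the same parsing applied to fabricated chunks (not wired into CustomStreamWrapper):

```
import json


def parse_anthropic_text_line(chunk):
    # Mirrors handle_anthropic_text_chunk above for a single SSE line.
    str_line = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
    text, is_finished, finish_reason = "", False, None
    if str_line.startswith("data:"):
        data_json = json.loads(str_line[5:])
        if data_json.get("type") == "completion":
            text = data_json.get("completion")
            finish_reason = data_json.get("stop_reason")
            is_finished = finish_reason is not None
    return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}


# Fabricated chunks shaped like Anthropic's text-completion stream events.
print(parse_anthropic_text_line(b'data: {"type": "completion", "completion": "Hello", "stop_reason": null}'))
print(parse_anthropic_text_line('data: {"type": "completion", "completion": "", "stop_reason": "stop_sequence"}'))
```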
@@ -8742,6 +8961,58 @@ class CustomStreamWrapper:
                "finish_reason": finish_reason,
            }

    def handle_vertexai_anthropic_chunk(self, chunk):
        """
        - MessageStartEvent(message=Message(id='msg_01LeRRgvX4gwkX3ryBVgtuYZ', content=[], model='claude-3-sonnet-20240229', role='assistant', stop_reason=None, stop_sequence=None, type='message', usage=Usage(input_tokens=8, output_tokens=1)), type='message_start'); custom_llm_provider: vertex_ai
        - ContentBlockStartEvent(content_block=ContentBlock(text='', type='text'), index=0, type='content_block_start'); custom_llm_provider: vertex_ai
        - ContentBlockDeltaEvent(delta=TextDelta(text='Hello', type='text_delta'), index=0, type='content_block_delta'); custom_llm_provider: vertex_ai
        """
        text = ""
        prompt_tokens = None
        completion_tokens = None
        is_finished = False
        finish_reason = None
        type_chunk = getattr(chunk, "type", None)
        if type_chunk == "message_start":
            message = getattr(chunk, "message", None)
            text = ""  # lets us return a chunk with usage to user
            _usage = getattr(message, "usage", None)
            if _usage is not None:
                prompt_tokens = getattr(_usage, "input_tokens", None)
                completion_tokens = getattr(_usage, "output_tokens", None)
        elif type_chunk == "content_block_delta":
            """
            Anthropic content chunk
            chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
            """
            delta = getattr(chunk, "delta", None)
            if delta is not None:
                text = getattr(delta, "text", "")
            else:
                text = ""
        elif type_chunk == "message_delta":
            """
            Anthropic
            chunk = {'type': 'message_delta', 'delta': {'stop_reason': 'max_tokens', 'stop_sequence': None}, 'usage': {'output_tokens': 10}}
            """
            # TODO - get usage from this chunk, set in response
            delta = getattr(chunk, "delta", None)
            if delta is not None:
                finish_reason = getattr(delta, "stop_reason", "stop")
                is_finished = True
            _usage = getattr(chunk, "usage", None)
            if _usage is not None:
                prompt_tokens = getattr(_usage, "input_tokens", None)
                completion_tokens = getattr(_usage, "output_tokens", None)

        return {
            "text": text,
            "is_finished": is_finished,
            "finish_reason": finish_reason,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
        }

    def handle_together_ai_chunk(self, chunk):
        chunk = chunk.decode("utf-8")
        text = ""
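`handle_vertexai_anthropic_chunk` dispatches on the Anthropic SDK event type: `message_start` carries input/output token usage, `content_block_delta` carries text, and `message_delta` carries the stop reason plus output-token usage. The sketch below fakes those SDK event objects with SimpleNamespace just to show what the handler pulls out of each event type (illustrative only):

```
from types import SimpleNamespace as NS

# Stand-ins for the anthropic SDK's streaming event objects.
events = [
    NS(type="message_start", message=NS(usage=NS(input_tokens=8, output_tokens=1))),
    NS(type="content_block_delta", delta=NS(text="Hello")),
    NS(type="message_delta", delta=NS(stop_reason="max_tokens"), usage=NS(output_tokens=10)),
]

for chunk in events:
    text, finish_reason, usage = "", None, None
    if chunk.type == "message_start":
        usage = (chunk.message.usage.input_tokens, chunk.message.usage.output_tokens)
    elif chunk.type == "content_block_delta":
        text = chunk.delta.text
    elif chunk.type == "message_delta":
        finish_reason = chunk.delta.stop_reason
        usage = (None, chunk.usage.output_tokens)
    print(chunk.type, repr(text), finish_reason, usage)
```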
@@ -9339,6 +9610,14 @@ class CustomStreamWrapper:
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
            elif (
                self.custom_llm_provider
                and self.custom_llm_provider == "anthropic_text"
            ):
                response_obj = self.handle_anthropic_text_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.model == "replicate" or self.custom_llm_provider == "replicate":
                response_obj = self.handle_replicate_chunk(chunk)
                completion_obj["content"] = response_obj["text"]

@@ -9409,7 +9688,33 @@ class CustomStreamWrapper:
                else:
                    completion_obj["content"] = str(chunk)
            elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
                if hasattr(chunk, "candidates") == True:
                if self.model.startswith("claude-3"):
                    response_obj = self.handle_vertexai_anthropic_chunk(chunk=chunk)
                    if response_obj is None:
                        return
                    completion_obj["content"] = response_obj["text"]
                    if response_obj.get("prompt_tokens", None) is not None:
                        model_response.usage.prompt_tokens = response_obj[
                            "prompt_tokens"
                        ]
                    if response_obj.get("completion_tokens", None) is not None:
                        model_response.usage.completion_tokens = response_obj[
                            "completion_tokens"
                        ]
                    if hasattr(model_response.usage, "prompt_tokens"):
                        model_response.usage.total_tokens = (
                            getattr(model_response.usage, "total_tokens", 0)
                            + model_response.usage.prompt_tokens
                        )
                    if hasattr(model_response.usage, "completion_tokens"):
                        model_response.usage.total_tokens = (
                            getattr(model_response.usage, "total_tokens", 0)
                            + model_response.usage.completion_tokens
                        )

                    if response_obj["is_finished"]:
                        self.received_finish_reason = response_obj["finish_reason"]
                elif hasattr(chunk, "candidates") == True:
                    try:
                    try:
                        completion_obj["content"] = chunk.text

@@ -9661,6 +9966,18 @@ class CustomStreamWrapper:
            print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
            ## RETURN ARG
            if (
                "content" in completion_obj
                and isinstance(completion_obj["content"], str)
                and len(completion_obj["content"]) == 0
                and hasattr(model_response.usage, "prompt_tokens")
            ):
                if self.sent_first_chunk == False:
                    completion_obj["role"] = "assistant"
                    self.sent_first_chunk = True
                model_response.choices[0].delta = Delta(**completion_obj)
                print_verbose(f"returning model_response: {model_response}")
                return model_response
            elif (
                "content" in completion_obj
                and isinstance(completion_obj["content"], str)
                and len(completion_obj["content"]) > 0

@@ -9877,6 +10194,8 @@ class CustomStreamWrapper:
                or self.custom_llm_provider == "custom_openai"
                or self.custom_llm_provider == "text-completion-openai"
                or self.custom_llm_provider == "azure_text"
                or self.custom_llm_provider == "anthropic"
                or self.custom_llm_provider == "anthropic_text"
                or self.custom_llm_provider == "huggingface"
                or self.custom_llm_provider == "ollama"
                or self.custom_llm_provider == "ollama_chat"

@@ -10482,28 +10801,26 @@ def print_args_passed_to_litellm(original_function, args, kwargs):

        args_str = ", ".join(map(repr, args))
        kwargs_str = ", ".join(f"{key}={repr(value)}" for key, value in kwargs.items())
        print_verbose("\n", log_level="INFO")  # new line before
        print_verbose("\033[92mRequest to litellm:\033[0m", log_level="INFO")
        print_verbose(
            "\n",
        )  # new line before
        print_verbose(
            "\033[92mRequest to litellm:\033[0m",
        )
        if args and kwargs:
            print_verbose(
                f"\033[92mlitellm.{original_function.__name__}({args_str}, {kwargs_str})\033[0m",
                log_level="INFO",
                f"\033[92mlitellm.{original_function.__name__}({args_str}, {kwargs_str})\033[0m"
            )
        elif args:
            print_verbose(
                f"\033[92mlitellm.{original_function.__name__}({args_str})\033[0m",
                log_level="INFO",
                f"\033[92mlitellm.{original_function.__name__}({args_str})\033[0m"
            )
        elif kwargs:
            print_verbose(
                f"\033[92mlitellm.{original_function.__name__}({kwargs_str})\033[0m",
                log_level="INFO",
                f"\033[92mlitellm.{original_function.__name__}({kwargs_str})\033[0m"
            )
        else:
            print_verbose(
                f"\033[92mlitellm.{original_function.__name__}()\033[0m",
                log_level="INFO",
            )
            print_verbose(f"\033[92mlitellm.{original_function.__name__}()\033[0m")
        print_verbose("\n")  # new line after
    except:
        # This should always be non blocking