formatting improvements

ishaan-jaff 2023-08-28 09:20:50 -07:00
parent 3e0a16acf4
commit a69b7ffcfa
17 changed files with 464 additions and 323 deletions

View file

@@ -5,8 +5,12 @@ input_callback: List[str] = []
 success_callback: List[str] = []
 failure_callback: List[str] = []
 set_verbose = False
-email: Optional[str] = None # for hosted dashboard. Learn more - https://docs.litellm.ai/docs/debugging/hosted_debugging
-token: Optional[str] = None # for hosted dashboard. Learn more - https://docs.litellm.ai/docs/debugging/hosted_debugging
+email: Optional[
+    str
+] = None # for hosted dashboard. Learn more - https://docs.litellm.ai/docs/debugging/hosted_debugging
+token: Optional[
+    str
+] = None # for hosted dashboard. Learn more - https://docs.litellm.ai/docs/debugging/hosted_debugging
 telemetry = True
 max_tokens = 256 # OpenAI Defaults
 retry = True

View file

@@ -1,4 +1,5 @@
 import importlib_metadata
+
 try:
     version = importlib_metadata.version("litellm")
 except:

View file

@@ -1,20 +1,21 @@
 ###### LiteLLM Integration with GPT Cache #########
 import gptcache

 # openai.ChatCompletion._llm_handler = litellm.completion
 from gptcache.adapter import openai
 import litellm

+
 class LiteLLMChatCompletion(gptcache.adapter.openai.ChatCompletion):
     @classmethod
     def _llm_handler(cls, *llm_args, **llm_kwargs):
         return litellm.completion(*llm_args, **llm_kwargs)


 completion = LiteLLMChatCompletion.create
 ###### End of LiteLLM Integration with GPT Cache #########

 # ####### Example usage ###############
 # from gptcache import cache
 # completion = LiteLLMChatCompletion.create
@@ -23,9 +24,3 @@ completion = LiteLLMChatCompletion.create
 # cache.set_openai_key()
 # result = completion(model="claude-2", messages=[{"role": "user", "content": "cto of litellm"}])
 # print(result)
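A minimal usage sketch of the adapter above, assembled from the commented example in this file; it assumes gptcache and litellm are installed and that the relevant provider API keys are set in the environment (the model name is whatever litellm supports):

# Hedged sketch: cache-backed completion via the LiteLLMChatCompletion adapter above.
from gptcache import cache
from litellm.cache import completion  # alias for LiteLLMChatCompletion.create

cache.init()            # initialize an in-process GPT Cache
cache.set_openai_key()  # GPT Cache reads OPENAI_API_KEY from the environment

# The first call goes to the provider; a repeated, identical call can be answered from the cache.
result = completion(model="claude-2", messages=[{"role": "user", "content": "cto of litellm"}])
print(result)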

View file

@@ -1,5 +1,6 @@
 import requests, traceback, json, os

+
 class LiteDebugger:
     user_email = None
     dashboard_url = None
@@ -15,7 +16,9 @@ class LiteDebugger:
         self.user_email = os.getenv("LITELLM_EMAIL") or email
         self.dashboard_url = "https://admin.litellm.ai/" + self.user_email
         try:
-            print(f"\033[92mHere's your LiteLLM Dashboard 👉 \033[94m\033[4m{self.dashboard_url}\033[0m")
+            print(
+                f"\033[92mHere's your LiteLLM Dashboard 👉 \033[94m\033[4m{self.dashboard_url}\033[0m"
+            )
         except:
             print(f"Here's your LiteLLM Dashboard 👉 {self.dashboard_url}")
         if self.user_email == None:
@@ -28,17 +31,25 @@
         )

     def input_log_event(
-        self, model, messages, end_user, litellm_call_id, print_verbose, litellm_params, optional_params
+        self,
+        model,
+        messages,
+        end_user,
+        litellm_call_id,
+        print_verbose,
+        litellm_params,
+        optional_params,
     ):
         try:
             print_verbose(
                 f"LiteLLMDebugger: Logging - Enters input logging function for model {model}"
             )

             def remove_key_value(dictionary, key):
                 new_dict = dictionary.copy() # Create a copy of the original dictionary
                 new_dict.pop(key) # Remove the specified key-value pair from the copy
                 return new_dict

             updated_litellm_params = remove_key_value(litellm_params, "logger_fn")
             litellm_data_obj = {
@@ -49,7 +60,7 @@ class LiteDebugger:
                 "litellm_call_id": litellm_call_id,
                 "user_email": self.user_email,
                 "litellm_params": updated_litellm_params,
-                "optional_params": optional_params
+                "optional_params": optional_params,
             }
             print_verbose(
                 f"LiteLLMDebugger: Logging - logged data obj {litellm_data_obj}"
@@ -65,10 +76,8 @@ class LiteDebugger:
                 f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
             )
             pass

-    def post_call_log_event(
-        self, original_response, litellm_call_id, print_verbose
-    ):
+    def post_call_log_event(self, original_response, litellm_call_id, print_verbose):
         try:
             litellm_data_obj = {
                 "status": "received",
@@ -110,7 +119,7 @@ class LiteDebugger:
                     "model": response_obj["model"],
                     "total_cost": total_cost,
                     "messages": messages,
-                    "response": response['choices'][0]['message']['content'],
+                    "response": response["choices"][0]["message"]["content"],
                     "end_user": end_user,
                     "litellm_call_id": litellm_call_id,
                     "status": "success",
@@ -124,7 +133,12 @@ class LiteDebugger:
                     headers={"content-type": "application/json"},
                     data=json.dumps(litellm_data_obj),
                 )
-            elif "data" in response_obj and isinstance(response_obj["data"], list) and len(response_obj["data"]) > 0 and "embedding" in response_obj["data"][0]:
+            elif (
+                "data" in response_obj
+                and isinstance(response_obj["data"], list)
+                and len(response_obj["data"]) > 0
+                and "embedding" in response_obj["data"][0]
+            ):
                 print(f"messages: {messages}")
                 litellm_data_obj = {
                     "response_time": response_time,
@@ -145,7 +159,10 @@ class LiteDebugger:
                     headers={"content-type": "application/json"},
                     data=json.dumps(litellm_data_obj),
                 )
-            elif isinstance(response_obj, object) and response_obj.__class__.__name__ == "CustomStreamWrapper":
+            elif (
+                isinstance(response_obj, object)
+                and response_obj.__class__.__name__ == "CustomStreamWrapper"
+            ):
                 litellm_data_obj = {
                     "response_time": response_time,
                     "total_cost": total_cost,

View file

@@ -12,20 +12,17 @@ dotenv.load_dotenv() # Loading env variables using dotenv
 # convert to {completion: xx, tokens: xx}
 def parse_usage(usage):
     return {
-        "completion":
-        usage["completion_tokens"] if "completion_tokens" in usage else 0,
-        "prompt":
-        usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
+        "completion": usage["completion_tokens"] if "completion_tokens" in usage else 0,
+        "prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
     }


 def parse_messages(input):
     if input is None:
         return None

     def clean_message(message):
-        #if is strin, return as is
+        # if is strin, return as is
         if isinstance(message, str):
             return message
@@ -50,75 +47,78 @@ class LLMonitorLogger:
     # Class variables or attributes
     def __init__(self):
         # Instance variables
-        self.api_url = os.getenv(
-            "LLMONITOR_API_URL") or "https://app.llmonitor.com"
+        self.api_url = os.getenv("LLMONITOR_API_URL") or "https://app.llmonitor.com"
         self.app_id = os.getenv("LLMONITOR_APP_ID")

     def log_event(
         self,
         type,
         event,
         run_id,
         model,
         print_verbose,
         input=None,
         user_id=None,
         response_obj=None,
         start_time=datetime.datetime.now(),
         end_time=datetime.datetime.now(),
         error=None,
     ):
         # Method definition
         try:
-            print_verbose(
-                f"LLMonitor Logging - Logging request for model {model}")
+            print_verbose(f"LLMonitor Logging - Logging request for model {model}")

             if response_obj:
-                usage = parse_usage(
-                    response_obj['usage']) if 'usage' in response_obj else None
-                output = response_obj[
-                    'choices'] if 'choices' in response_obj else None
+                usage = (
+                    parse_usage(response_obj["usage"])
+                    if "usage" in response_obj
+                    else None
+                )
+                output = response_obj["choices"] if "choices" in response_obj else None
             else:
                 usage = None
                 output = None

             if error:
-                error_obj = {'stack': error}
+                error_obj = {"stack": error}
             else:
                 error_obj = None

-            data = [{
-                "type": type,
-                "name": model,
-                "runId": run_id,
-                "app": self.app_id,
-                'event': 'start',
-                "timestamp": start_time.isoformat(),
-                "userId": user_id,
-                "input": parse_messages(input),
-            }, {
-                "type": type,
-                "runId": run_id,
-                "app": self.app_id,
-                "event": event,
-                "error": error_obj,
-                "timestamp": end_time.isoformat(),
-                "userId": user_id,
-                "output": parse_messages(output),
-                "tokensUsage": usage,
-            }]
+            data = [
+                {
+                    "type": type,
+                    "name": model,
+                    "runId": run_id,
+                    "app": self.app_id,
+                    "event": "start",
+                    "timestamp": start_time.isoformat(),
+                    "userId": user_id,
+                    "input": parse_messages(input),
+                },
+                {
+                    "type": type,
+                    "runId": run_id,
+                    "app": self.app_id,
+                    "event": event,
+                    "error": error_obj,
+                    "timestamp": end_time.isoformat(),
+                    "userId": user_id,
+                    "output": parse_messages(output),
+                    "tokensUsage": usage,
+                },
+            ]

             # print_verbose(f"LLMonitor Logging - final data object: {data}")
             response = requests.post(
-                self.api_url + '/api/report',
-                headers={'Content-Type': 'application/json'},
-                json={'events': data})
+                self.api_url + "/api/report",
+                headers={"Content-Type": "application/json"},
+                json={"events": data},
+            )
             print_verbose(f"LLMonitor Logging - response: {response}")
         except:
             # traceback.print_exc()
-            print_verbose(
-                f"LLMonitor Logging Error - {traceback.format_exc()}")
+            print_verbose(f"LLMonitor Logging Error - {traceback.format_exc()}")
             pass

View file

@@ -7,6 +7,7 @@ import requests
 dotenv.load_dotenv() # Loading env variables using dotenv
 import traceback

+
 class PromptLayerLogger:
     # Class variables or attributes
     def __init__(self):
@@ -26,7 +27,9 @@ class PromptLayerLogger:
                     "function_name": "openai.ChatCompletion.create",
                     "kwargs": kwargs,
                     "tags": ["hello", "world"],
-                    "request_response": dict(response_obj), # TODO: Check if we need a dict
+                    "request_response": dict(
+                        response_obj
+                    ), # TODO: Check if we need a dict
                     "request_start_time": int(start_time.timestamp()),
                     "request_end_time": int(end_time.timestamp()),
                     "api_key": self.key,
@@ -34,11 +37,12 @@ class PromptLayerLogger:
                     # "prompt_id": "<PROMPT ID>",
                     # "prompt_input_variables": "<Dictionary of variables for prompt>",
                     # "prompt_version":1,
                 },
             )

-            print_verbose(f"Prompt Layer Logging - final response object: {request_response}")
+            print_verbose(
+                f"Prompt Layer Logging - final response object: {request_response}"
+            )
         except:
             # traceback.print_exc()
             print_verbose(f"Prompt Layer Error - {traceback.format_exc()}")

View file

@@ -94,7 +94,10 @@ class AnthropicLLM:
         ## COMPLETION CALL
         if "stream" in optional_params and optional_params["stream"] == True:
             response = requests.post(
-                self.completion_url, headers=self.headers, data=json.dumps(data), stream=optional_params["stream"]
+                self.completion_url,
+                headers=self.headers,
+                data=json.dumps(data),
+                stream=optional_params["stream"],
             )
             return response.iter_lines()
         else:
@@ -142,4 +145,3 @@ class AnthropicLLM:
         self,
     ): # logic for parsing in - calling - parsing out model embedding calls
         pass
-

View file

@@ -5,6 +5,7 @@ import time
 from typing import Callable
 from litellm.utils import ModelResponse

+
 class BasetenError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -15,9 +16,7 @@ class BasetenError(Exception):


 class BasetenLLM:
-    def __init__(
-        self, encoding, logging_obj, api_key=None
-    ):
+    def __init__(self, encoding, logging_obj, api_key=None):
         self.encoding = encoding
         self.completion_url_fragment_1 = "https://app.baseten.co/models/"
         self.completion_url_fragment_2 = "/predict"
@@ -55,13 +54,9 @@ class BasetenLLM:
         for message in messages:
             if "role" in message:
                 if message["role"] == "user":
-                    prompt += (
-                        f"{message['content']}"
-                    )
+                    prompt += f"{message['content']}"
                 else:
-                    prompt += (
-                        f"{message['content']}"
-                    )
+                    prompt += f"{message['content']}"
             else:
                 prompt += f"{message['content']}"
         data = {
@@ -78,7 +73,9 @@
         )
         ## COMPLETION CALL
         response = requests.post(
-            self.completion_url_fragment_1 + model + self.completion_url_fragment_2, headers=self.headers, data=json.dumps(data)
+            self.completion_url_fragment_1 + model + self.completion_url_fragment_2,
+            headers=self.headers,
+            data=json.dumps(data),
         )
         if "stream" in optional_params and optional_params["stream"] == True:
             return response.iter_lines()
@@ -100,19 +97,33 @@
                 )
             else:
                 if "model_output" in completion_response:
-                    if isinstance(completion_response["model_output"], dict) and "data" in completion_response["model_output"] and isinstance(completion_response["model_output"]["data"], list):
-                        model_response["choices"][0]["message"]["content"] = completion_response["model_output"]["data"][0]
+                    if (
+                        isinstance(completion_response["model_output"], dict)
+                        and "data" in completion_response["model_output"]
+                        and isinstance(
+                            completion_response["model_output"]["data"], list
+                        )
+                    ):
+                        model_response["choices"][0]["message"][
+                            "content"
+                        ] = completion_response["model_output"]["data"][0]
                     elif isinstance(completion_response["model_output"], str):
-                        model_response["choices"][0]["message"]["content"] = completion_response["model_output"]
-                elif "completion" in completion_response and isinstance(completion_response["completion"], str):
-                    model_response["choices"][0]["message"]["content"] = completion_response["completion"]
+                        model_response["choices"][0]["message"][
+                            "content"
+                        ] = completion_response["model_output"]
+                elif "completion" in completion_response and isinstance(
+                    completion_response["completion"], str
+                ):
+                    model_response["choices"][0]["message"][
+                        "content"
+                    ] = completion_response["completion"]
                 else:
-                    raise ValueError(f"Unable to parse response. Original response: {response.text}")
+                    raise ValueError(
+                        f"Unable to parse response. Original response: {response.text}"
+                    )
             ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
-            prompt_tokens = len(
-                self.encoding.encode(prompt)
-            )
+            prompt_tokens = len(self.encoding.encode(prompt))
             completion_tokens = len(
                 self.encoding.encode(model_response["choices"][0]["message"]["content"])
             )

View file

@@ -103,7 +103,9 @@ def completion(
         return completion_with_fallbacks(**args)
     if litellm.model_alias_map and model in litellm.model_alias_map:
         args["model_alias_map"] = litellm.model_alias_map
-        model = litellm.model_alias_map[model] # update the model to the actual value if an alias has been passed in
+        model = litellm.model_alias_map[
+            model
+        ] # update the model to the actual value if an alias has been passed in
     model_response = ModelResponse()
     if azure: # this flag is deprecated, remove once notebooks are also updated.
         custom_llm_provider = "azure"
@@ -146,7 +148,7 @@ def completion(
         custom_llm_provider=custom_llm_provider,
         custom_api_base=custom_api_base,
         litellm_call_id=litellm_call_id,
-        model_alias_map=litellm.model_alias_map
+        model_alias_map=litellm.model_alias_map,
     )
     logging = Logging(
         model=model,
@@ -216,7 +218,10 @@ def completion(
             # note: if a user sets a custom base - we should ensure this works
             # allow for the setting of dynamic and stateful api-bases
             api_base = (
-                custom_api_base or litellm.api_base or get_secret("OPENAI_API_BASE") or "https://api.openai.com/v1"
+                custom_api_base
+                or litellm.api_base
+                or get_secret("OPENAI_API_BASE")
+                or "https://api.openai.com/v1"
             )
             openai.api_base = api_base
             openai.api_version = None
@@ -255,9 +260,11 @@ def completion(
                 original_response=response,
                 additional_args={"headers": litellm.headers},
             )
-    elif (model in litellm.open_ai_text_completion_models or
-          "ft:babbage-002" in model or # support for finetuned completion models
-          "ft:davinci-002" in model):
+    elif (
+        model in litellm.open_ai_text_completion_models
+        or "ft:babbage-002" in model
+        or "ft:davinci-002" in model # support for finetuned completion models
+    ):
         openai.api_type = "openai"
         openai.api_base = (
             litellm.api_base
@@ -544,7 +551,10 @@ def completion(
         logging.pre_call(input=prompt, api_key=TOGETHER_AI_TOKEN)
         print(f"TOGETHER_AI_TOKEN: {TOGETHER_AI_TOKEN}")
-        if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
+        if (
+            "stream_tokens" in optional_params
+            and optional_params["stream_tokens"] == True
+        ):
             res = requests.post(
                 endpoint,
                 json={
@@ -698,9 +708,7 @@ def completion(
     ):
         custom_llm_provider = "baseten"
         baseten_key = (
-            api_key
-            or litellm.baseten_key
-            or os.environ.get("BASETEN_API_KEY")
+            api_key or litellm.baseten_key or os.environ.get("BASETEN_API_KEY")
         )
         baseten_client = BasetenLLM(
             encoding=encoding, api_key=baseten_key, logging_obj=logging
@@ -767,11 +775,14 @@ def completion(
             model=model, custom_llm_provider=custom_llm_provider, original_exception=e
         )

+
 def completion_with_retries(*args, **kwargs):
     import tenacity

     retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(3), reraise=True)
     return retryer(completion, *args, **kwargs)

+
 def batch_completion(*args, **kwargs):
     batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
     completions = []
@@ -865,14 +876,16 @@ def embedding(
         custom_llm_provider="azure" if azure == True else None,
     )


 ###### Text Completion ################
 def text_completion(*args, **kwargs):
-    if 'prompt' in kwargs:
-        messages = [{'role': 'system', 'content': kwargs['prompt']}]
-        kwargs['messages'] = messages
-        kwargs.pop('prompt')
+    if "prompt" in kwargs:
+        messages = [{"role": "system", "content": kwargs["prompt"]}]
+        kwargs["messages"] = messages
+        kwargs.pop("prompt")
     return completion(*args, **kwargs)


 ####### HELPER FUNCTIONS ################
 ## Set verbose to true -> ```litellm.set_verbose = True```
 def print_verbose(print_statement):
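The text_completion shim in this hunk simply rewrites a prompt-style call into a chat-style call before delegating to completion(); a hedged sketch of that equivalence, assuming both helpers are importable from the litellm package the way the tests use them (model and prompt are illustrative):

from litellm import completion, text_completion

# Under the shim, "prompt" is popped and re-sent as a single system message.
r1 = text_completion(model="gpt-3.5-turbo", prompt="What's the weather in SF?")
r2 = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "system", "content": "What's the weather in SF?"}],
)
# r1 and r2 go through the same code path and return the same ModelResponse shape.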

View file

@@ -14,6 +14,7 @@ from litellm import embedding, completion
 messages = [{"role": "user", "content": "who is ishaan Github? "}]

+
 # test if response cached
 def test_caching():
     try:
@@ -50,14 +51,16 @@ def test_caching_with_models():
         print(f"response1: {response1}")
         print(f"response2: {response2}")
         pytest.fail(f"Error occurred:")

+
 # test_caching_with_models()


 def test_gpt_cache():
     # INIT GPT Cache #
     from gptcache import cache
     from litellm.cache import completion

     cache.init()
     cache.set_openai_key()
@@ -67,10 +70,11 @@ def test_gpt_cache():
     print(f"response2: {response2}")
     print(f"response3: {response3}")
-    if response3['choices'] != response2['choices']:
+    if response3["choices"] != response2["choices"]:
         # if models are different, it should not return cached response
         print(f"response2: {response2}")
         print(f"response3: {response3}")
         pytest.fail(f"Error occurred:")

+
 # test_gpt_cache()

View file

@@ -142,9 +142,12 @@ def test_completion_openai():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+
 def test_completion_openai_prompt():
     try:
-        response = text_completion(model="gpt-3.5-turbo", prompt="What's the weather in SF?")
+        response = text_completion(
+            model="gpt-3.5-turbo", prompt="What's the weather in SF?"
+        )
         response_str = response["choices"][0]["message"]["content"]
         response_str_2 = response.choices[0].message.content
         print(response)
@@ -154,6 +157,7 @@ def test_completion_openai_prompt():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+
 def test_completion_text_openai():
     try:
         response = completion(model="text-davinci-003", messages=messages)

View file

@@ -27,7 +27,7 @@
 # # print(f"user_model_dict: {user_model_dict}")
 # pass

 # # normal call
 # def test_completion_custom_provider_model_name():
 # try:
 # response = completion_with_retries(
@@ -41,7 +41,7 @@
 # pytest.fail(f"Error occurred: {e}")

 # # bad call
 # # def test_completion_custom_provider_model_name():
 # # try:
 # # response = completion_with_retries(
@@ -54,7 +54,7 @@
 # # except Exception as e:
 # # pytest.fail(f"Error occurred: {e}")

 # # impact on exception mapping
 # def test_context_window():
 # sample_text = "how does a court case get to the Supreme Court?" * 5000
 # messages = [{"content": sample_text, "role": "user"}]
@@ -83,4 +83,4 @@
 # test_context_window()
 # test_completion_custom_provider_model_name()

View file

@@ -22,4 +22,6 @@ def test_openai_embedding():
         # print(f"response: {str(response)}")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
 test_openai_embedding()

View file

@@ -4,7 +4,7 @@
 import sys
 import os

-sys.path.insert(0, os.path.abspath('../..'))
+sys.path.insert(0, os.path.abspath("../.."))
 from litellm import completion, embedding
 import litellm
@@ -17,11 +17,10 @@ litellm.set_verbose = True
 def test_chat_openai():
     try:
-        response = completion(model="gpt-3.5-turbo",
-                              messages=[{
-                                  "role": "user",
-                                  "content": "Hi 👋 - i'm openai"
-                              }])
+        response = completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
+        )

         print(response)
@@ -31,7 +30,7 @@ def test_chat_openai():
 def test_embedding_openai():
     try:
-        response = embedding(model="text-embedding-ada-002", input=['test'])
+        response = embedding(model="text-embedding-ada-002", input=["test"])
         # Add any assertions here to check the response
         print(f"response: {str(response)[:50]}")
     except Exception as e:
@@ -39,4 +38,4 @@ def test_embedding_openai():

 test_chat_openai()
 test_embedding_openai()

View file

@@ -13,5 +13,19 @@ from litellm import embedding, completion
 litellm.set_verbose = True

 # Test: Check if the alias created via LiteDebugger is mapped correctly
-{"top_p": 0.75, "prompt": "What's the meaning of life?", "num_beams": 4, "temperature": 0.1}
-print(completion("llama2", messages=[{"role": "user", "content": "Hey, how's it going?"}], top_p=0.1, temperature=0, num_beams=4, max_tokens=60))
+{
+    "top_p": 0.75,
+    "prompt": "What's the meaning of life?",
+    "num_beams": 4,
+    "temperature": 0.1,
+}
+print(
+    completion(
+        "llama2",
+        messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        top_p=0.1,
+        temperature=0,
+        num_beams=4,
+        max_tokens=60,
+    )
+)
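This test leans on the alias resolution shown in main.py (model = litellm.model_alias_map[model]); a hedged sketch of setting such an alias by hand follows — the mapping target below is purely illustrative, not the alias the test actually relies on:

import litellm
from litellm import completion

# Hypothetical alias: point the short name at a real provider model string.
litellm.model_alias_map = {"llama2": "replicate/llama-2-70b-chat"}  # illustrative target

# completion() swaps "llama2" for the mapped value before dispatching the request.
response = completion(
    "llama2", messages=[{"role": "user", "content": "Hey, how's it going?"}]
)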

View file

@@ -3,12 +3,14 @@
 import sys, os
 import traceback
 import time

 sys.path.insert(
     0, os.path.abspath("../..")
 ) # Adds the parent directory to the system path
 import litellm
 from litellm import completion

 litellm.logging = False
 litellm.set_verbose = False
@@ -31,11 +33,11 @@ messages = [{"content": user_message, "role": "user"}]
 # complete_response = ""
 # start_time = time.time()
 # for chunk in response:
 # chunk_time = time.time()
 # print(f"time since initial request: {chunk_time - start_time:.5f}")
 # print(chunk["choices"][0]["delta"])
 # complete_response += chunk["choices"][0]["delta"]["content"]
 # if complete_response == "":
 # raise Exception("Empty response received")
 # except:
 # print(f"error occurred: {traceback.format_exc()}")
@@ -50,11 +52,11 @@ messages = [{"content": user_message, "role": "user"}]
 # response = ""
 # start_time = time.time()
 # for chunk in response:
 # chunk_time = time.time()
 # print(f"time since initial request: {chunk_time - start_time:.2f}")
 # print(chunk["choices"][0]["delta"])
 # response += chunk["choices"][0]["delta"]
 # if response == "":
 # raise Exception("Empty response received")
 # except:
 # print(f"error occurred: {traceback.format_exc()}")
@@ -73,7 +75,7 @@ try:
         print(f"time since initial request: {chunk_time - start_time:.5f}")
         print(chunk["choices"][0]["delta"])
         complete_response += chunk["choices"][0]["delta"]["content"]
     if complete_response == "":
         raise Exception("Empty response received")
 except:
     print(f"error occurred: {traceback.format_exc()}")
@@ -88,11 +90,11 @@ except:
 # )
 # complete_response = ""
 # for chunk in response:
 # chunk_time = time.time()
 # print(f"time since initial request: {chunk_time - start_time:.2f}")
 # print(chunk["choices"][0]["delta"])
 # complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else ""
 # if complete_response == "":
 # raise Exception("Empty response received")
 # except:
 # print(f"error occurred: {traceback.format_exc()}")
@@ -102,16 +104,20 @@ except:
 try:
     start_time = time.time()
     response = completion(
-        model="Replit-Code-3B", messages=messages, logger_fn=logger_fn, stream= True
+        model="Replit-Code-3B", messages=messages, logger_fn=logger_fn, stream=True
     )
     complete_response = ""
     print(f"returned response object: {response}")
     for chunk in response:
         chunk_time = time.time()
         print(f"time since initial request: {chunk_time - start_time:.2f}")
         print(chunk["choices"][0]["delta"])
-        complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else ""
+        complete_response += (
+            chunk["choices"][0]["delta"]["content"]
+            if len(chunk["choices"][0]["delta"].keys()) > 0
+            else ""
+        )
     if complete_response == "":
         raise Exception("Empty response received")
 except:
     print(f"error occurred: {traceback.format_exc()}")
@@ -121,16 +127,23 @@ except:
 try:
     start_time = time.time()
     response = completion(
-        model="together_ai/bigcode/starcoder", messages=messages, logger_fn=logger_fn, stream= True
+        model="together_ai/bigcode/starcoder",
+        messages=messages,
+        logger_fn=logger_fn,
+        stream=True,
     )
     complete_response = ""
     print(f"returned response object: {response}")
     for chunk in response:
         chunk_time = time.time()
-        complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else ""
+        complete_response += (
+            chunk["choices"][0]["delta"]["content"]
+            if len(chunk["choices"][0]["delta"].keys()) > 0
+            else ""
+        )
         if len(complete_response) > 0:
             print(complete_response)
     if complete_response == "":
         raise Exception("Empty response received")
 except:
     print(f"error occurred: {traceback.format_exc()}")
@@ -144,11 +157,11 @@ except:
 # )
 # response = ""
 # for chunk in response:
 # chunk_time = time.time()
 # print(f"time since initial request: {chunk_time - start_time:.2f}")
 # print(chunk["choices"][0]["delta"])
 # response += chunk["choices"][0]["delta"]
 # if response == "":
 # raise Exception("Empty response received")
 # except:
 # print(f"error occurred: {traceback.format_exc()}")
@@ -162,11 +175,11 @@ except:
 # )
 # response = ""
 # for chunk in response:
 # chunk_time = time.time()
 # print(f"time since initial request: {chunk_time - start_time:.2f}")
 # print(chunk["choices"][0]["delta"])
 # response += chunk["choices"][0]["delta"]
 # if response == "":
 # raise Exception("Empty response received")
 # except:
 # print(f"error occurred: {traceback.format_exc()}")

View file

@@ -69,7 +69,6 @@ last_fetched_at_keys = None


 class Message(OpenAIObject):
     def __init__(self, content="default", role="assistant", **params):
-
         super(Message, self).__init__(**params)
         self.content = content
@@ -77,12 +76,7 @@ class Message(OpenAIObject):


 class Choices(OpenAIObject):
-
-    def __init__(self,
-                 finish_reason="stop",
-                 index=0,
-                 message=Message(),
-                 **params):
+    def __init__(self, finish_reason="stop", index=0, message=Message(), **params):
         super(Choices, self).__init__(**params)
         self.finish_reason = finish_reason
         self.index = index
@@ -90,22 +84,20 @@ class Choices(OpenAIObject):


 class ModelResponse(OpenAIObject):
-
-    def __init__(self,
-                 choices=None,
-                 created=None,
-                 model=None,
-                 usage=None,
-                 **params):
+    def __init__(self, choices=None, created=None, model=None, usage=None, **params):
         super(ModelResponse, self).__init__(**params)
         self.choices = choices if choices else [Choices()]
         self.created = created
         self.model = model
-        self.usage = (usage if usage else {
-            "prompt_tokens": None,
-            "completion_tokens": None,
-            "total_tokens": None,
-        })
+        self.usage = (
+            usage
+            if usage
+            else {
+                "prompt_tokens": None,
+                "completion_tokens": None,
+                "total_tokens": None,
+            }
+        )

     def to_dict_recursive(self):
         d = super().to_dict_recursive()
@@ -173,7 +165,9 @@ class Logging:
         self.model_call_details["api_key"] = api_key
         self.model_call_details["additional_args"] = additional_args
-        if model: # if model name was changes pre-call, overwrite the initial model call name with the new one
+        if (
+            model
+        ): # if model name was changes pre-call, overwrite the initial model call name with the new one
             self.model_call_details["model"] = model

         # User Logging -> if you pass in a custom logging function
@@ -203,8 +197,7 @@ class Logging:
                         model=model,
                         messages=messages,
                         end_user=litellm._thread_context.user,
-                        litellm_call_id=self.
-                        litellm_params["litellm_call_id"],
+                        litellm_call_id=self.litellm_params["litellm_call_id"],
                         print_verbose=print_verbose,
                     )
@@ -217,8 +210,7 @@ class Logging:
                         model=model,
                         messages=messages,
                         end_user=litellm._thread_context.user,
-                        litellm_call_id=self.
-                        litellm_params["litellm_call_id"],
+                        litellm_call_id=self.litellm_params["litellm_call_id"],
                         litellm_params=self.model_call_details["litellm_params"],
                         optional_params=self.model_call_details["optional_params"],
                         print_verbose=print_verbose,
@@ -263,7 +255,7 @@ class Logging:
                 print_verbose(
                     f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
                 )

         # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
         for callback in litellm.input_callback:
             try:
@@ -274,8 +266,7 @@ class Logging:
                     print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
                     liteDebuggerClient.post_call_log_event(
                         original_response=original_response,
-                        litellm_call_id=self.
-                        litellm_params["litellm_call_id"],
+                        litellm_call_id=self.litellm_params["litellm_call_id"],
                         print_verbose=print_verbose,
                     )
                 except:
@@ -295,6 +286,7 @@ class Logging:
     # Add more methods as needed


+
 def exception_logging(
     additional_args={},
     logger_fn=None,
@@ -329,13 +321,18 @@ def exception_logging(
 # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
 def client(original_function):
     global liteDebuggerClient, get_all_keys

     def function_setup(
         *args, **kwargs
     ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
         try:
             global callback_list, add_breadcrumb, user_logger_fn
-            if litellm.email is not None or os.getenv("LITELLM_EMAIL", None) is not None or litellm.token is not None or os.getenv("LITELLM_TOKEN", None): # add to input, success and failure callbacks if user is using hosted product
+            if (
+                litellm.email is not None
+                or os.getenv("LITELLM_EMAIL", None) is not None
+                or litellm.token is not None
+                or os.getenv("LITELLM_TOKEN", None)
+            ): # add to input, success and failure callbacks if user is using hosted product
                 get_all_keys()
                 if "lite_debugger" not in callback_list and litellm.logging:
                     litellm.input_callback.append("lite_debugger")
@@ -381,11 +378,12 @@ def client(original_function):
         if litellm.telemetry:
             try:
                 model = args[0] if len(args) > 0 else kwargs["model"]
-                exception = kwargs[
-                    "exception"] if "exception" in kwargs else None
-                custom_llm_provider = (kwargs["custom_llm_provider"]
-                                       if "custom_llm_provider" in kwargs else
-                                       None)
+                exception = kwargs["exception"] if "exception" in kwargs else None
+                custom_llm_provider = (
+                    kwargs["custom_llm_provider"]
+                    if "custom_llm_provider" in kwargs
+                    else None
+                )
                 safe_crash_reporting(
                     model=model,
                     exception=exception,
@@ -410,10 +408,10 @@ def client(original_function):
     def check_cache(*args, **kwargs):
         try: # never block execution
             prompt = get_prompt(*args, **kwargs)
-            if (prompt != None): # check if messages / prompt exists
+            if prompt != None: # check if messages / prompt exists
                 if litellm.caching_with_models:
                     # if caching with model names is enabled, key is prompt + model name
-                    if ("model" in kwargs):
+                    if "model" in kwargs:
                         cache_key = prompt + kwargs["model"]
                         if cache_key in local_cache:
                             return local_cache[cache_key]
@@ -423,7 +421,7 @@ def client(original_function):
                     return result
                 else:
                     return None
             return None # default to return None
         except:
             return None
@@ -431,7 +429,7 @@ def client(original_function):
         try: # never block execution
             prompt = get_prompt(*args, **kwargs)
             if litellm.caching_with_models: # caching with model + prompt
-                if ("model" in kwargs):
+                if "model" in kwargs:
                     cache_key = prompt + kwargs["model"]
                     local_cache[cache_key] = result
             else: # caching based only on prompts
@@ -449,7 +447,8 @@ def client(original_function):
             start_time = datetime.datetime.now()
             # [OPTIONAL] CHECK CACHE
             if (litellm.caching or litellm.caching_with_models) and (
-                    cached_result := check_cache(*args, **kwargs)) is not None:
+                cached_result := check_cache(*args, **kwargs)
+            ) is not None:
                 result = cached_result
                 return result
             # MODEL CALL
@@ -458,25 +457,22 @@ def client(original_function):
                 return result
             end_time = datetime.datetime.now()
             # [OPTIONAL] ADD TO CACHE
-            if (litellm.caching or litellm.caching_with_models):
+            if litellm.caching or litellm.caching_with_models:
                 add_cache(result, *args, **kwargs)
             # LOG SUCCESS
             my_thread = threading.Thread(
-                target=handle_success,
-                args=(args, kwargs, result, start_time,
-                      end_time)) # don't interrupt execution of main thread
+                target=handle_success, args=(args, kwargs, result, start_time, end_time)
+            ) # don't interrupt execution of main thread
             my_thread.start()
             # RETURN RESULT
             return result
         except Exception as e:
             traceback_exception = traceback.format_exc()
             crash_reporting(*args, **kwargs, exception=traceback_exception)
             end_time = datetime.datetime.now()
             my_thread = threading.Thread(
                 target=handle_failure,
-                args=(e, traceback_exception, start_time, end_time, args,
-                      kwargs),
+                args=(e, traceback_exception, start_time, end_time, args, kwargs),
             ) # don't interrupt execution of main thread
             my_thread.start()
             if hasattr(e, "message"):
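The check_cache/add_cache helpers above key an in-process local_cache dict on the prompt, or on prompt plus model name when caching_with_models is enabled; a hedged usage sketch (model and message are illustrative):

import litellm
from litellm import completion

litellm.caching = True  # or litellm.caching_with_models = True to key on prompt + model

messages = [{"role": "user", "content": "Hey, how's it going?"}]
first = completion(model="gpt-3.5-turbo", messages=messages)   # real call; result stored via add_cache()
second = completion(model="gpt-3.5-turbo", messages=messages)  # identical prompt; served by check_cache()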
@@ -506,18 +502,18 @@ def token_counter(model, text):
     return num_tokens


-def cost_per_token(model="gpt-3.5-turbo",
-                   prompt_tokens=0,
-                   completion_tokens=0):
+def cost_per_token(model="gpt-3.5-turbo", prompt_tokens=0, completion_tokens=0):
     # given
     prompt_tokens_cost_usd_dollar = 0
     completion_tokens_cost_usd_dollar = 0
     model_cost_ref = litellm.model_cost
     if model in model_cost_ref:
         prompt_tokens_cost_usd_dollar = (
-            model_cost_ref[model]["input_cost_per_token"] * prompt_tokens)
+            model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
+        )
         completion_tokens_cost_usd_dollar = (
-            model_cost_ref[model]["output_cost_per_token"] * completion_tokens)
+            model_cost_ref[model]["output_cost_per_token"] * completion_tokens
+        )
         return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
     else:
         # calculate average input cost
@@ -538,9 +534,8 @@ def completion_cost(model="gpt-3.5-turbo", prompt="", completion=""):
     prompt_tokens = token_counter(model=model, text=prompt)
     completion_tokens = token_counter(model=model, text=completion)
     prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_token(
-        model=model,
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens)
+        model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
+    )
     return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
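completion_cost above is just token counting plus a per-token price lookup from litellm.model_cost; a hedged worked sketch, assuming these helpers are re-exported at the package level (the prices in the comment are placeholders, not the real table values):

from litellm import completion_cost, cost_per_token

# Placeholder arithmetic: with input_cost_per_token = 1.5e-6 and output_cost_per_token = 2.0e-6,
# 100 prompt tokens + 50 completion tokens cost 100 * 1.5e-6 + 50 * 2.0e-6 = 0.00025 USD.
prompt_usd, completion_usd = cost_per_token(
    model="gpt-3.5-turbo", prompt_tokens=100, completion_tokens=50
)
total_usd = completion_cost(
    model="gpt-3.5-turbo", prompt="Hey, how's it going?", completion="Doing well, thanks!"
)
print(prompt_usd + completion_usd, total_usd)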
@@ -558,7 +553,7 @@ def get_litellm_params(
     custom_llm_provider=None,
     custom_api_base=None,
     litellm_call_id=None,
-    model_alias_map=None
+    model_alias_map=None,
 ):
     litellm_params = {
         "return_async": return_async,
@@ -569,13 +564,13 @@ def get_litellm_params(
         "custom_llm_provider": custom_llm_provider,
         "custom_api_base": custom_api_base,
         "litellm_call_id": litellm_call_id,
-        "model_alias_map": model_alias_map
+        "model_alias_map": model_alias_map,
     }

     return litellm_params


 def get_optional_params( # use the openai defaults
     # 12 optional params
     functions=[],
     function_call="",
@@ -588,7 +583,7 @@ def get_optional_params( # use the openai defaults
     presence_penalty=0,
     frequency_penalty=0,
     logit_bias={},
     num_beams=1,
     user="",
     deployment_id=None,
     model=None,
@@ -635,8 +630,9 @@ def get_optional_params( # use the openai defaults
         optional_params["max_tokens"] = max_tokens
         if frequency_penalty != 0:
             optional_params["frequency_penalty"] = frequency_penalty
-    elif (model == "chat-bison"
-          ): # chat-bison has diff args from chat-bison@001 ty Google
+    elif (
+        model == "chat-bison"
+    ): # chat-bison has diff args from chat-bison@001 ty Google
         if temperature != 1:
             optional_params["temperature"] = temperature
         if top_p != 1:
@@ -702,10 +698,7 @@ def load_test_model(
         test_prompt = prompt
     if num_calls:
         test_calls = num_calls
-    messages = [[{
-        "role": "user",
-        "content": test_prompt
-    }] for _ in range(test_calls)]
+    messages = [[{"role": "user", "content": test_prompt}] for _ in range(test_calls)]
     start_time = time.time()
     try:
         litellm.batch_completion(
@@ -743,15 +736,17 @@ def set_callbacks(callback_list):
                 try:
                     import sentry_sdk
                 except ImportError:
-                    print_verbose(
-                        "Package 'sentry_sdk' is missing. Installing it...")
+                    print_verbose("Package 'sentry_sdk' is missing. Installing it...")
                     subprocess.check_call(
-                        [sys.executable, "-m", "pip", "install", "sentry_sdk"])
+                        [sys.executable, "-m", "pip", "install", "sentry_sdk"]
+                    )
                     import sentry_sdk
                 sentry_sdk_instance = sentry_sdk
-                sentry_trace_rate = (os.environ.get("SENTRY_API_TRACE_RATE")
-                                     if "SENTRY_API_TRACE_RATE" in os.environ
-                                     else "1.0")
+                sentry_trace_rate = (
+                    os.environ.get("SENTRY_API_TRACE_RATE")
+                    if "SENTRY_API_TRACE_RATE" in os.environ
+                    else "1.0"
+                )
                 sentry_sdk_instance.init(
                     dsn=os.environ.get("SENTRY_API_URL"),
                     traces_sample_rate=float(sentry_trace_rate),
@@ -762,10 +757,10 @@ def set_callbacks(callback_list):
                 try:
                     from posthog import Posthog
                 except ImportError:
-                    print_verbose(
-                        "Package 'posthog' is missing. Installing it...")
+                    print_verbose("Package 'posthog' is missing. Installing it...")
                     subprocess.check_call(
-                        [sys.executable, "-m", "pip", "install", "posthog"])
+                        [sys.executable, "-m", "pip", "install", "posthog"]
+                    )
                     from posthog import Posthog
                 posthog = Posthog(
                     project_api_key=os.environ.get("POSTHOG_API_KEY"),
@@ -775,10 +770,10 @@ def set_callbacks(callback_list):
                 try:
                     from slack_bolt import App
                 except ImportError:
-                    print_verbose(
-                        "Package 'slack_bolt' is missing. Installing it...")
+                    print_verbose("Package 'slack_bolt' is missing. Installing it...")
                     subprocess.check_call(
-                        [sys.executable, "-m", "pip", "install", "slack_bolt"])
+                        [sys.executable, "-m", "pip", "install", "slack_bolt"]
+                    )
                     from slack_bolt import App
                 slack_app = App(
                     token=os.environ.get("SLACK_API_TOKEN"),
@ -809,8 +804,7 @@ def set_callbacks(callback_list):
raise e raise e
def handle_failure(exception, traceback_exception, start_time, end_time, args, def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs):
kwargs):
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient
try: try:
# print_verbose(f"handle_failure args: {args}") # print_verbose(f"handle_failure args: {args}")
@ -820,7 +814,8 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
failure_handler = additional_details.pop("failure_handler", None) failure_handler = additional_details.pop("failure_handler", None)
additional_details["Event_Name"] = additional_details.pop( additional_details["Event_Name"] = additional_details.pop(
"failed_event_name", "litellm.failed_query") "failed_event_name", "litellm.failed_query"
)
print_verbose(f"self.failure_callback: {litellm.failure_callback}") print_verbose(f"self.failure_callback: {litellm.failure_callback}")
for callback in litellm.failure_callback: for callback in litellm.failure_callback:
try: try:
@ -835,8 +830,9 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
for detail in additional_details: for detail in additional_details:
slack_msg += f"{detail}: {additional_details[detail]}\n" slack_msg += f"{detail}: {additional_details[detail]}\n"
slack_msg += f"Traceback: {traceback_exception}" slack_msg += f"Traceback: {traceback_exception}"
slack_app.client.chat_postMessage(channel=alerts_channel, slack_app.client.chat_postMessage(
text=slack_msg) channel=alerts_channel, text=slack_msg
)
elif callback == "sentry": elif callback == "sentry":
capture_exception(exception) capture_exception(exception)
elif callback == "posthog": elif callback == "posthog":
@ -855,8 +851,9 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
print_verbose(f"ph_obj: {ph_obj}") print_verbose(f"ph_obj: {ph_obj}")
print_verbose(f"PostHog Event Name: {event_name}") print_verbose(f"PostHog Event Name: {event_name}")
if "user_id" in additional_details: if "user_id" in additional_details:
posthog.capture(additional_details["user_id"], posthog.capture(
event_name, ph_obj) additional_details["user_id"], event_name, ph_obj
)
else: # PostHog calls require a unique id to identify a user - https://posthog.com/docs/libraries/python else: # PostHog calls require a unique id to identify a user - https://posthog.com/docs/libraries/python
unique_id = str(uuid.uuid4()) unique_id = str(uuid.uuid4())
posthog.capture(unique_id, event_name) posthog.capture(unique_id, event_name)
@ -870,10 +867,10 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
"created": time.time(), "created": time.time(),
"error": traceback_exception, "error": traceback_exception,
"usage": { "usage": {
"prompt_tokens": "prompt_tokens": prompt_token_calculator(
prompt_token_calculator(model, messages=messages), model, messages=messages
"completion_tokens": ),
0, "completion_tokens": 0,
}, },
} }
berrispendLogger.log_event( berrispendLogger.log_event(
@@ -892,10 +889,10 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
                        "model": model,
                        "created": time.time(),
                        "usage": {
                            "prompt_tokens": prompt_token_calculator(
                                model, messages=messages
                            ),
                            "completion_tokens": 0,
                        },
                    }
                    aispendLogger.log_event(
@@ -910,10 +907,13 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
                    model = args[0] if len(args) > 0 else kwargs["model"]
                    input = (
                        args[1]
                        if len(args) > 1
                        else kwargs.get("messages", kwargs.get("input", None))
                    )

                    type = "embed" if "input" in kwargs else "llm"

                    llmonitorLogger.log_event(
                        type=type,
@@ -937,10 +937,10 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
                        "created": time.time(),
                        "error": traceback_exception,
                        "usage": {
                            "prompt_tokens": prompt_token_calculator(
                                model, messages=messages
                            ),
                            "completion_tokens": 0,
                        },
                    }
                    supabaseClient.log_event(
@@ -957,16 +957,28 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
                    print_verbose("reaches lite_debugger for logging!")
                    print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
                    model = args[0] if len(args) > 0 else kwargs["model"]
                    messages = (
                        args[1]
                        if len(args) > 1
                        else kwargs.get(
                            "messages",
                            [
                                {
                                    "role": "user",
                                    "content": " ".join(kwargs.get("input", "")),
                                }
                            ],
                        )
                    )
                    result = {
                        "model": model,
                        "created": time.time(),
                        "error": traceback_exception,
                        "usage": {
                            "prompt_tokens": prompt_token_calculator(
                                model, messages=messages
                            ),
                            "completion_tokens": 0,
                        },
                    }
                    liteDebuggerClient.log_event(
@@ -1002,11 +1014,16 @@ def handle_success(args, kwargs, result, start_time, end_time):
    global heliconeLogger, aispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger
    try:
        model = args[0] if len(args) > 0 else kwargs["model"]
        input = (
            args[1]
            if len(args) > 1
            else kwargs.get("messages", kwargs.get("input", None))
        )
        success_handler = additional_details.pop("success_handler", None)
        failure_handler = additional_details.pop("failure_handler", None)
        additional_details["Event_Name"] = additional_details.pop(
            "successful_event_name", "litellm.succes_query"
        )
        for callback in litellm.success_callback:
            try:
                if callback == "posthog":
@@ -1015,8 +1032,9 @@ def handle_success(args, kwargs, result, start_time, end_time):
                        ph_obj[detail] = additional_details[detail]
                    event_name = additional_details["Event_Name"]
                    if "user_id" in additional_details:
                        posthog.capture(
                            additional_details["user_id"], event_name, ph_obj
                        )
                    else:  # PostHog calls require a unique id to identify a user - https://posthog.com/docs/libraries/python
                        unique_id = str(uuid.uuid4())
                        posthog.capture(unique_id, event_name, ph_obj)
@@ -1025,8 +1043,9 @@ def handle_success(args, kwargs, result, start_time, end_time):
                    slack_msg = ""
                    for detail in additional_details:
                        slack_msg += f"{detail}: {additional_details[detail]}\n"
                    slack_app.client.chat_postMessage(
                        channel=alerts_channel, text=slack_msg
                    )
                elif callback == "helicone":
                    print_verbose("reaches helicone for logging!")
                    model = args[0] if len(args) > 0 else kwargs["model"]
@@ -1043,11 +1062,14 @@ def handle_success(args, kwargs, result, start_time, end_time):
                    print_verbose("reaches llmonitor for logging!")
                    model = args[0] if len(args) > 0 else kwargs["model"]
                    input = (
                        args[1]
                        if len(args) > 1
                        else kwargs.get("messages", kwargs.get("input", None))
                    )

                    # if contains input, it's 'embedding', otherwise 'llm'
                    type = "embed" if "input" in kwargs else "llm"

                    llmonitorLogger.log_event(
                        type=type,
@@ -1069,7 +1091,6 @@ def handle_success(args, kwargs, result, start_time, end_time):
                        start_time=start_time,
                        end_time=end_time,
                        print_verbose=print_verbose,
                    )
                elif callback == "aispend":
                    print_verbose("reaches aispend for logging!")
@@ -1084,7 +1105,11 @@ def handle_success(args, kwargs, result, start_time, end_time):
                elif callback == "supabase":
                    print_verbose("reaches supabase for logging!")
                    model = args[0] if len(args) > 0 else kwargs["model"]
                    messages = (
                        args[1]
                        if len(args) > 1
                        else kwargs.get("messages", {"role": "user", "content": ""})
                    )
                    print(f"supabaseClient: {supabaseClient}")
                    supabaseClient.log_event(
                        model=model,
@@ -1099,7 +1124,19 @@ def handle_success(args, kwargs, result, start_time, end_time):
                elif callback == "lite_debugger":
                    print_verbose("reaches lite_debugger for logging!")
                    print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
                    messages = (
                        args[1]
                        if len(args) > 1
                        else kwargs.get(
                            "messages",
                            [
                                {
                                    "role": "user",
                                    "content": " ".join(kwargs.get("input", "")),
                                }
                            ],
                        )
                    )
                    liteDebuggerClient.log_event(
                        model=model,
                        messages=messages,
@@ -1129,6 +1166,7 @@ def handle_success(args, kwargs, result, start_time, end_time):
        )
        pass


def acreate(*args, **kwargs):  ## Thin client to handle the acreate langchain call
    return litellm.acompletion(*args, **kwargs)
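A hedged usage sketch for the acreate wrapper above; the import path is an assumption, and the wrapper simply forwards to litellm.acompletion, so any async entry point works.

import asyncio

from litellm.utils import acreate  # assumption: acreate lives in this module


async def main():
    # awaiting the thin wrapper is equivalent to awaiting litellm.acompletion directly
    response = await acreate(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hi from the async client"}],
    )
    print(response)


asyncio.run(main())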
@@ -1170,28 +1208,43 @@ def modify_integration(integration_name, integration_params):
        if "table_name" in integration_params:
            Supabase.supabase_table_name = integration_params["table_name"]


####### [BETA] HOSTED PRODUCT ################ - https://docs.litellm.ai/docs/debugging/hosted_debugging


def get_all_keys(llm_provider=None):
    try:
        global last_fetched_at_keys
        # if user is using hosted product -> instantiate their env with their hosted api keys - refresh every 5 minutes
        print_verbose(f"Reaches get all keys, llm_provider: {llm_provider}")
        user_email = (
            os.getenv("LITELLM_EMAIL")
            or litellm.email
            or litellm.token
            or os.getenv("LITELLM_TOKEN")
        )
        if user_email:
            time_delta = 0
            if last_fetched_at_keys != None:
                current_time = time.time()
                time_delta = current_time - last_fetched_at_keys
            if (
                time_delta > 300 or last_fetched_at_keys == None or llm_provider
            ):  # if the llm provider is passed in , assume this happening due to an AuthError for that provider
                # make the api call
                last_fetched_at = time.time()
                print_verbose(f"last_fetched_at: {last_fetched_at}")
                response = requests.post(
                    url="http://api.litellm.ai/get_all_keys",
                    headers={"content-type": "application/json"},
                    data=json.dumps({"user_email": user_email}),
                )
                print_verbose(f"get model key response: {response.text}")
                data = response.json()
                # update model list
                for key, value in data[
                    "model_keys"
                ].items():  # follows the LITELLM API KEY format - <UPPERCASE_PROVIDER_NAME>_API_KEY - e.g. HUGGINGFACE_API_KEY
                    os.environ[key] = value
                # set model alias map
                for model_alias, value in data["model_alias_map"].items():
@@ -1200,19 +1253,31 @@ def get_all_keys(llm_provider=None):
            return None
        return None
    except:
        print_verbose(
            f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}"
        )
        pass


def get_model_list():
    global last_fetched_at
    try:
        # if user is using hosted product -> get their updated model list
        user_email = (
            os.getenv("LITELLM_EMAIL")
            or litellm.email
            or litellm.token
            or os.getenv("LITELLM_TOKEN")
        )
        if user_email:
            # make the api call
            last_fetched_at = time.time()
            print(f"last_fetched_at: {last_fetched_at}")
            response = requests.post(
                url="http://api.litellm.ai/get_model_list",
                headers={"content-type": "application/json"},
                data=json.dumps({"user_email": user_email}),
            )
            print_verbose(f"get_model_list response: {response.text}")
            data = response.json()
            # update model list
@@ -1224,12 +1289,14 @@ def get_model_list():
                if f"{item.upper()}_API_KEY" not in os.environ:
                    missing_llm_provider = item
                    break
            # update environment - if required
            threading.Thread(target=get_all_keys, args=(missing_llm_provider)).start()
            return model_list
        return []  # return empty list by default
    except:
        print_verbose(
            f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}"
        )


####### EXCEPTION MAPPING ################
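A short sketch of calling the hosted-product helpers above, assuming they are re-exported at the package level (otherwise import them from this module); the email value is a placeholder.

import os
import litellm

os.environ["LITELLM_EMAIL"] = "you@example.com"  # placeholder hosted-dashboard account

model_list = litellm.get_model_list()  # returns [] when no hosted account is configured
litellm.get_all_keys()  # refreshes <PROVIDER>_API_KEY env vars from the hosted dashboard
print(model_list)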
@@ -1253,36 +1320,33 @@ def exception_type(model, original_exception, custom_llm_provider):
            exception_type = ""
        if "claude" in model:  # one of the anthropics
            if hasattr(original_exception, "status_code"):
                print_verbose(f"status_code: {original_exception.status_code}")
                if original_exception.status_code == 401:
                    exception_mapping_worked = True
                    raise AuthenticationError(
                        message=f"AnthropicException - {original_exception.message}",
                        llm_provider="anthropic",
                    )
                elif original_exception.status_code == 400:
                    exception_mapping_worked = True
                    raise InvalidRequestError(
                        message=f"AnthropicException - {original_exception.message}",
                        model=model,
                        llm_provider="anthropic",
                    )
                elif original_exception.status_code == 429:
                    exception_mapping_worked = True
                    raise RateLimitError(
                        message=f"AnthropicException - {original_exception.message}",
                        llm_provider="anthropic",
                    )
            elif (
                "Could not resolve authentication method. Expected either api_key or auth_token to be set."
                in error_str
            ):
                exception_mapping_worked = True
                raise AuthenticationError(
                    message=f"AnthropicException - {original_exception.message}",
                    llm_provider="anthropic",
                )
        elif "replicate" in model:
@@ -1306,36 +1370,35 @@ def exception_type(model, original_exception, custom_llm_provider):
                    llm_provider="replicate",
                )
            elif (
                exception_type == "ReplicateError"
            ):  # ReplicateError implies an error on Replicate server side, not user side
                raise ServiceUnavailableError(
                    message=f"ReplicateException - {error_str}",
                    llm_provider="replicate",
                )
        elif model == "command-nightly":  # Cohere
            if (
                "invalid api token" in error_str
                or "No API key provided." in error_str
            ):
                exception_mapping_worked = True
                raise AuthenticationError(
                    message=f"CohereException - {original_exception.message}",
                    llm_provider="cohere",
                )
            elif "too many tokens" in error_str:
                exception_mapping_worked = True
                raise InvalidRequestError(
                    message=f"CohereException - {original_exception.message}",
                    model=model,
                    llm_provider="cohere",
                )
            elif (
                "CohereConnectionError" in exception_type
            ):  # cohere seems to fire these errors when we load test it (1k+ messages / min)
                exception_mapping_worked = True
                raise RateLimitError(
                    message=f"CohereException - {original_exception.message}",
                    llm_provider="cohere",
                )
        elif custom_llm_provider == "huggingface":
@@ -1343,23 +1406,20 @@ def exception_type(model, original_exception, custom_llm_provider):
                if original_exception.status_code == 401:
                    exception_mapping_worked = True
                    raise AuthenticationError(
                        message=f"HuggingfaceException - {original_exception.message}",
                        llm_provider="huggingface",
                    )
                elif original_exception.status_code == 400:
                    exception_mapping_worked = True
                    raise InvalidRequestError(
                        message=f"HuggingfaceException - {original_exception.message}",
                        model=model,
                        llm_provider="huggingface",
                    )
                elif original_exception.status_code == 429:
                    exception_mapping_worked = True
                    raise RateLimitError(
                        message=f"HuggingfaceException - {original_exception.message}",
                        llm_provider="huggingface",
                    )
        raise original_exception  # base case - return the original exception
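To show what this mapping buys the caller, a hedged sketch of catching the normalized exceptions; the import path for the exception classes is an assumption, but the class names and the llm_provider attribute match the code above.

import litellm
from litellm.exceptions import AuthenticationError, RateLimitError  # assumed import path

try:
    litellm.completion(
        model="claude-2",
        messages=[{"role": "user", "content": "hi"}],
    )
except AuthenticationError as e:
    # provider 401s (Anthropic, Cohere, Hugging Face) surface uniformly
    print(f"bad or missing credentials for {e.llm_provider}")
except RateLimitError:
    # provider 429s surface uniformly, so retry logic stays provider-agnostic
    print("rate limited - back off and retry")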
@@ -1375,8 +1435,10 @@ def exception_type(model, original_exception, custom_llm_provider):
            },
            exception=e,
        )
        ## AUTH ERROR
        if isinstance(e, AuthenticationError) and (
            litellm.email or "LITELLM_EMAIL" in os.environ
        ):
            threading.Thread(target=get_all_keys, args=(e.llm_provider,)).start()
        if exception_mapping_worked:
            raise e
@@ -1391,7 +1453,8 @@ def safe_crash_reporting(model=None, exception=None, custom_llm_provider=None):
        "exception": str(exception),
        "custom_llm_provider": custom_llm_provider,
    }
    threading.Thread(target=litellm_telemetry, args=(data,)).start()


def get_or_generate_uuid():
    uuid_file = "litellm_uuid.txt"
@@ -1445,8 +1508,7 @@ def get_secret(secret_name):
    # TODO: check which secret manager is being used
    # currently only supports Infisical
    try:
        secret = litellm.secret_manager_client.get_secret(secret_name).secret_value
    except:
        secret = None
    return secret
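A hedged sketch of wiring up the secret-manager hook used by get_secret: any client object exposing get_secret(name).secret_value fits; the Infisical import path and constructor shown here are assumptions, not confirmed by this diff.

import litellm
from infisical import InfisicalClient  # assumed import path for the Infisical SDK

from litellm.utils import get_secret  # assumption: same module as the code above

# assumption: a service token with read access to the project's secrets
litellm.secret_manager_client = InfisicalClient(token="<infisical-service-token>")

openai_key = get_secret("OPENAI_API_KEY")  # returns None if the lookup raises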
@@ -1460,7 +1522,6 @@ def get_secret(secret_name):
# wraps the completion stream to return the correct format for the model
# replicate/anthropic/cohere
class CustomStreamWrapper:
    def __init__(self, completion_stream, model, custom_llm_provider=None):
        self.model = model
        self.custom_llm_provider = custom_llm_provider
@@ -1509,8 +1570,9 @@ class CustomStreamWrapper:
        elif self.model == "replicate":
            chunk = next(self.completion_stream)
            completion_obj["content"] = chunk
        elif (
            self.custom_llm_provider and self.custom_llm_provider == "together_ai"
        ) or ("togethercomputer" in self.model):
            chunk = next(self.completion_stream)
            text_data = self.handle_together_ai_chunk(chunk)
            if text_data == "":
@@ -1545,9 +1607,9 @@ def read_config_args(config_path):
########## ollama implementation ############################
async def get_ollama_response_stream(
    api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"
):
    session = aiohttp.ClientSession()
    url = f"{api_base}/api/generate"
    data = {
@@ -1570,11 +1632,7 @@ async def get_ollama_response_stream(api_base="http://localhost:11434",
                        "content": "",
                    }
                    completion_obj["content"] = j["response"]
                    yield {"choices": [{"delta": completion_obj}]}
                    # self.responses.append(j["response"])
                    # yield "blank"
                except Exception as e:
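Finally, a hedged sketch of consuming the ollama streaming helper above; it assumes a local ollama server on the default port and that the function is importable from this module.

import asyncio

from litellm.utils import get_ollama_response_stream  # assumed import path


async def main():
    # each yielded chunk has the {"choices": [{"delta": {...}}]} shape shown above
    async for chunk in get_ollama_response_stream(
        api_base="http://localhost:11434",
        model="llama2",
        prompt="Why is the sky blue?",
    ):
        print(chunk["choices"][0]["delta"]["content"], end="", flush=True)


asyncio.run(main())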