From a69b7ffcfa2dbe2405f8eac8a5e4784baf8e18c0 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 28 Aug 2023 09:20:50 -0700 Subject: [PATCH] formatting improvements --- litellm/__init__.py | 8 +- litellm/_version.py | 1 + litellm/cache.py | 13 +- litellm/integrations/litedebugger.py | 39 +- litellm/integrations/llmonitor.py | 104 ++--- litellm/integrations/prompt_layer.py | 10 +- litellm/llms/anthropic.py | 6 +- litellm/llms/baseten.py | 51 ++- litellm/main.py | 41 +- litellm/tests/test_caching.py | 10 +- litellm/tests/test_completion.py | 6 +- litellm/tests/test_completion_with_retries.py | 8 +- litellm/tests/test_embedding.py | 4 +- litellm/tests/test_llmonitor_integration.py | 15 +- litellm/tests/test_model_alias_map.py | 18 +- litellm/tests/test_streaming.py | 55 ++- litellm/utils.py | 398 ++++++++++-------- 17 files changed, 464 insertions(+), 323 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index e70f1b544d..baa3aa4af8 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -5,8 +5,12 @@ input_callback: List[str] = [] success_callback: List[str] = [] failure_callback: List[str] = [] set_verbose = False -email: Optional[str] = None # for hosted dashboard. Learn more - https://docs.litellm.ai/docs/debugging/hosted_debugging -token: Optional[str] = None # for hosted dashboard. Learn more - https://docs.litellm.ai/docs/debugging/hosted_debugging +email: Optional[ + str +] = None # for hosted dashboard. Learn more - https://docs.litellm.ai/docs/debugging/hosted_debugging +token: Optional[ + str +] = None # for hosted dashboard. Learn more - https://docs.litellm.ai/docs/debugging/hosted_debugging telemetry = True max_tokens = 256 # OpenAI Defaults retry = True diff --git a/litellm/_version.py b/litellm/_version.py index 2614375693..7b38bae2c2 100644 --- a/litellm/_version.py +++ b/litellm/_version.py @@ -1,4 +1,5 @@ import importlib_metadata + try: version = importlib_metadata.version("litellm") except: diff --git a/litellm/cache.py b/litellm/cache.py index 815c1e628b..e2c9d33fc1 100644 --- a/litellm/cache.py +++ b/litellm/cache.py @@ -1,20 +1,21 @@ - ###### LiteLLM Integration with GPT Cache ######### import gptcache + # openai.ChatCompletion._llm_handler = litellm.completion from gptcache.adapter import openai import litellm + class LiteLLMChatCompletion(gptcache.adapter.openai.ChatCompletion): @classmethod def _llm_handler(cls, *llm_args, **llm_kwargs): return litellm.completion(*llm_args, **llm_kwargs) - + + completion = LiteLLMChatCompletion.create ###### End of LiteLLM Integration with GPT Cache ######### - # ####### Example usage ############### # from gptcache import cache # completion = LiteLLMChatCompletion.create @@ -23,9 +24,3 @@ completion = LiteLLMChatCompletion.create # cache.set_openai_key() # result = completion(model="claude-2", messages=[{"role": "user", "content": "cto of litellm"}]) # print(result) - - - - - - diff --git a/litellm/integrations/litedebugger.py b/litellm/integrations/litedebugger.py index bea48061d5..5187d555f8 100644 --- a/litellm/integrations/litedebugger.py +++ b/litellm/integrations/litedebugger.py @@ -1,5 +1,6 @@ import requests, traceback, json, os + class LiteDebugger: user_email = None dashboard_url = None @@ -15,7 +16,9 @@ class LiteDebugger: self.user_email = os.getenv("LITELLM_EMAIL") or email self.dashboard_url = "https://admin.litellm.ai/" + self.user_email try: - print(f"\033[92mHere's your LiteLLM Dashboard 👉 \033[94m\033[4m{self.dashboard_url}\033[0m") + print( + f"\033[92mHere's your LiteLLM Dashboard 👉 \033[94m\033[4m{self.dashboard_url}\033[0m" + ) except: print(f"Here's your LiteLLM Dashboard 👉 {self.dashboard_url}") if self.user_email == None: @@ -28,17 +31,25 @@ class LiteDebugger: ) def input_log_event( - self, model, messages, end_user, litellm_call_id, print_verbose, litellm_params, optional_params + self, + model, + messages, + end_user, + litellm_call_id, + print_verbose, + litellm_params, + optional_params, ): try: print_verbose( f"LiteLLMDebugger: Logging - Enters input logging function for model {model}" ) + def remove_key_value(dictionary, key): new_dict = dictionary.copy() # Create a copy of the original dictionary new_dict.pop(key) # Remove the specified key-value pair from the copy return new_dict - + updated_litellm_params = remove_key_value(litellm_params, "logger_fn") litellm_data_obj = { @@ -49,7 +60,7 @@ class LiteDebugger: "litellm_call_id": litellm_call_id, "user_email": self.user_email, "litellm_params": updated_litellm_params, - "optional_params": optional_params + "optional_params": optional_params, } print_verbose( f"LiteLLMDebugger: Logging - logged data obj {litellm_data_obj}" @@ -65,10 +76,8 @@ class LiteDebugger: f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}" ) pass - - def post_call_log_event( - self, original_response, litellm_call_id, print_verbose - ): + + def post_call_log_event(self, original_response, litellm_call_id, print_verbose): try: litellm_data_obj = { "status": "received", @@ -110,7 +119,7 @@ class LiteDebugger: "model": response_obj["model"], "total_cost": total_cost, "messages": messages, - "response": response['choices'][0]['message']['content'], + "response": response["choices"][0]["message"]["content"], "end_user": end_user, "litellm_call_id": litellm_call_id, "status": "success", @@ -124,7 +133,12 @@ class LiteDebugger: headers={"content-type": "application/json"}, data=json.dumps(litellm_data_obj), ) - elif "data" in response_obj and isinstance(response_obj["data"], list) and len(response_obj["data"]) > 0 and "embedding" in response_obj["data"][0]: + elif ( + "data" in response_obj + and isinstance(response_obj["data"], list) + and len(response_obj["data"]) > 0 + and "embedding" in response_obj["data"][0] + ): print(f"messages: {messages}") litellm_data_obj = { "response_time": response_time, @@ -145,7 +159,10 @@ class LiteDebugger: headers={"content-type": "application/json"}, data=json.dumps(litellm_data_obj), ) - elif isinstance(response_obj, object) and response_obj.__class__.__name__ == "CustomStreamWrapper": + elif ( + isinstance(response_obj, object) + and response_obj.__class__.__name__ == "CustomStreamWrapper" + ): litellm_data_obj = { "response_time": response_time, "total_cost": total_cost, diff --git a/litellm/integrations/llmonitor.py b/litellm/integrations/llmonitor.py index d166e18880..22acf874c0 100644 --- a/litellm/integrations/llmonitor.py +++ b/litellm/integrations/llmonitor.py @@ -12,20 +12,17 @@ dotenv.load_dotenv() # Loading env variables using dotenv # convert to {completion: xx, tokens: xx} def parse_usage(usage): return { - "completion": - usage["completion_tokens"] if "completion_tokens" in usage else 0, - "prompt": - usage["prompt_tokens"] if "prompt_tokens" in usage else 0, + "completion": usage["completion_tokens"] if "completion_tokens" in usage else 0, + "prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0, } def parse_messages(input): - if input is None: return None def clean_message(message): - #if is strin, return as is + # if is strin, return as is if isinstance(message, str): return message @@ -50,75 +47,78 @@ class LLMonitorLogger: # Class variables or attributes def __init__(self): # Instance variables - self.api_url = os.getenv( - "LLMONITOR_API_URL") or "https://app.llmonitor.com" + self.api_url = os.getenv("LLMONITOR_API_URL") or "https://app.llmonitor.com" self.app_id = os.getenv("LLMONITOR_APP_ID") def log_event( - self, - type, - event, - run_id, - model, - print_verbose, - input=None, - user_id=None, - response_obj=None, - start_time=datetime.datetime.now(), - end_time=datetime.datetime.now(), - error=None, + self, + type, + event, + run_id, + model, + print_verbose, + input=None, + user_id=None, + response_obj=None, + start_time=datetime.datetime.now(), + end_time=datetime.datetime.now(), + error=None, ): # Method definition try: - print_verbose( - f"LLMonitor Logging - Logging request for model {model}") + print_verbose(f"LLMonitor Logging - Logging request for model {model}") if response_obj: - usage = parse_usage( - response_obj['usage']) if 'usage' in response_obj else None - output = response_obj[ - 'choices'] if 'choices' in response_obj else None + usage = ( + parse_usage(response_obj["usage"]) + if "usage" in response_obj + else None + ) + output = response_obj["choices"] if "choices" in response_obj else None else: usage = None output = None if error: - error_obj = {'stack': error} + error_obj = {"stack": error} else: error_obj = None - data = [{ - "type": type, - "name": model, - "runId": run_id, - "app": self.app_id, - 'event': 'start', - "timestamp": start_time.isoformat(), - "userId": user_id, - "input": parse_messages(input), - }, { - "type": type, - "runId": run_id, - "app": self.app_id, - "event": event, - "error": error_obj, - "timestamp": end_time.isoformat(), - "userId": user_id, - "output": parse_messages(output), - "tokensUsage": usage, - }] + data = [ + { + "type": type, + "name": model, + "runId": run_id, + "app": self.app_id, + "event": "start", + "timestamp": start_time.isoformat(), + "userId": user_id, + "input": parse_messages(input), + }, + { + "type": type, + "runId": run_id, + "app": self.app_id, + "event": event, + "error": error_obj, + "timestamp": end_time.isoformat(), + "userId": user_id, + "output": parse_messages(output), + "tokensUsage": usage, + }, + ] # print_verbose(f"LLMonitor Logging - final data object: {data}") response = requests.post( - self.api_url + '/api/report', - headers={'Content-Type': 'application/json'}, - json={'events': data}) + self.api_url + "/api/report", + headers={"Content-Type": "application/json"}, + json={"events": data}, + ) print_verbose(f"LLMonitor Logging - response: {response}") except: # traceback.print_exc() - print_verbose( - f"LLMonitor Logging Error - {traceback.format_exc()}") + print_verbose(f"LLMonitor Logging Error - {traceback.format_exc()}") pass diff --git a/litellm/integrations/prompt_layer.py b/litellm/integrations/prompt_layer.py index e1cdb666ee..e725378384 100644 --- a/litellm/integrations/prompt_layer.py +++ b/litellm/integrations/prompt_layer.py @@ -7,6 +7,7 @@ import requests dotenv.load_dotenv() # Loading env variables using dotenv import traceback + class PromptLayerLogger: # Class variables or attributes def __init__(self): @@ -26,7 +27,9 @@ class PromptLayerLogger: "function_name": "openai.ChatCompletion.create", "kwargs": kwargs, "tags": ["hello", "world"], - "request_response": dict(response_obj), # TODO: Check if we need a dict + "request_response": dict( + response_obj + ), # TODO: Check if we need a dict "request_start_time": int(start_time.timestamp()), "request_end_time": int(end_time.timestamp()), "api_key": self.key, @@ -34,11 +37,12 @@ class PromptLayerLogger: # "prompt_id": "", # "prompt_input_variables": "", # "prompt_version":1, - }, ) - print_verbose(f"Prompt Layer Logging - final response object: {request_response}") + print_verbose( + f"Prompt Layer Logging - final response object: {request_response}" + ) except: # traceback.print_exc() print_verbose(f"Prompt Layer Error - {traceback.format_exc()}") diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py index fc695374b7..c4e2ece7c0 100644 --- a/litellm/llms/anthropic.py +++ b/litellm/llms/anthropic.py @@ -94,7 +94,10 @@ class AnthropicLLM: ## COMPLETION CALL if "stream" in optional_params and optional_params["stream"] == True: response = requests.post( - self.completion_url, headers=self.headers, data=json.dumps(data), stream=optional_params["stream"] + self.completion_url, + headers=self.headers, + data=json.dumps(data), + stream=optional_params["stream"], ) return response.iter_lines() else: @@ -142,4 +145,3 @@ class AnthropicLLM: self, ): # logic for parsing in - calling - parsing out model embedding calls pass - diff --git a/litellm/llms/baseten.py b/litellm/llms/baseten.py index 218efa6832..cc0fcec8d1 100644 --- a/litellm/llms/baseten.py +++ b/litellm/llms/baseten.py @@ -5,6 +5,7 @@ import time from typing import Callable from litellm.utils import ModelResponse + class BasetenError(Exception): def __init__(self, status_code, message): self.status_code = status_code @@ -15,9 +16,7 @@ class BasetenError(Exception): class BasetenLLM: - def __init__( - self, encoding, logging_obj, api_key=None - ): + def __init__(self, encoding, logging_obj, api_key=None): self.encoding = encoding self.completion_url_fragment_1 = "https://app.baseten.co/models/" self.completion_url_fragment_2 = "/predict" @@ -55,13 +54,9 @@ class BasetenLLM: for message in messages: if "role" in message: if message["role"] == "user": - prompt += ( - f"{message['content']}" - ) + prompt += f"{message['content']}" else: - prompt += ( - f"{message['content']}" - ) + prompt += f"{message['content']}" else: prompt += f"{message['content']}" data = { @@ -78,7 +73,9 @@ class BasetenLLM: ) ## COMPLETION CALL response = requests.post( - self.completion_url_fragment_1 + model + self.completion_url_fragment_2, headers=self.headers, data=json.dumps(data) + self.completion_url_fragment_1 + model + self.completion_url_fragment_2, + headers=self.headers, + data=json.dumps(data), ) if "stream" in optional_params and optional_params["stream"] == True: return response.iter_lines() @@ -100,19 +97,33 @@ class BasetenLLM: ) else: if "model_output" in completion_response: - if isinstance(completion_response["model_output"], dict) and "data" in completion_response["model_output"] and isinstance(completion_response["model_output"]["data"], list): - model_response["choices"][0]["message"]["content"] = completion_response["model_output"]["data"][0] + if ( + isinstance(completion_response["model_output"], dict) + and "data" in completion_response["model_output"] + and isinstance( + completion_response["model_output"]["data"], list + ) + ): + model_response["choices"][0]["message"][ + "content" + ] = completion_response["model_output"]["data"][0] elif isinstance(completion_response["model_output"], str): - model_response["choices"][0]["message"]["content"] = completion_response["model_output"] - elif "completion" in completion_response and isinstance(completion_response["completion"], str): - model_response["choices"][0]["message"]["content"] = completion_response["completion"] + model_response["choices"][0]["message"][ + "content" + ] = completion_response["model_output"] + elif "completion" in completion_response and isinstance( + completion_response["completion"], str + ): + model_response["choices"][0]["message"][ + "content" + ] = completion_response["completion"] else: - raise ValueError(f"Unable to parse response. Original response: {response.text}") + raise ValueError( + f"Unable to parse response. Original response: {response.text}" + ) - ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. - prompt_tokens = len( - self.encoding.encode(prompt) - ) + ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. + prompt_tokens = len(self.encoding.encode(prompt)) completion_tokens = len( self.encoding.encode(model_response["choices"][0]["message"]["content"]) ) diff --git a/litellm/main.py b/litellm/main.py index f12725e61a..fa8cc847b5 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -103,7 +103,9 @@ def completion( return completion_with_fallbacks(**args) if litellm.model_alias_map and model in litellm.model_alias_map: args["model_alias_map"] = litellm.model_alias_map - model = litellm.model_alias_map[model] # update the model to the actual value if an alias has been passed in + model = litellm.model_alias_map[ + model + ] # update the model to the actual value if an alias has been passed in model_response = ModelResponse() if azure: # this flag is deprecated, remove once notebooks are also updated. custom_llm_provider = "azure" @@ -146,7 +148,7 @@ def completion( custom_llm_provider=custom_llm_provider, custom_api_base=custom_api_base, litellm_call_id=litellm_call_id, - model_alias_map=litellm.model_alias_map + model_alias_map=litellm.model_alias_map, ) logging = Logging( model=model, @@ -216,7 +218,10 @@ def completion( # note: if a user sets a custom base - we should ensure this works # allow for the setting of dynamic and stateful api-bases api_base = ( - custom_api_base or litellm.api_base or get_secret("OPENAI_API_BASE") or "https://api.openai.com/v1" + custom_api_base + or litellm.api_base + or get_secret("OPENAI_API_BASE") + or "https://api.openai.com/v1" ) openai.api_base = api_base openai.api_version = None @@ -255,9 +260,11 @@ def completion( original_response=response, additional_args={"headers": litellm.headers}, ) - elif (model in litellm.open_ai_text_completion_models or - "ft:babbage-002" in model or # support for finetuned completion models - "ft:davinci-002" in model): + elif ( + model in litellm.open_ai_text_completion_models + or "ft:babbage-002" in model + or "ft:davinci-002" in model # support for finetuned completion models + ): openai.api_type = "openai" openai.api_base = ( litellm.api_base @@ -544,7 +551,10 @@ def completion( logging.pre_call(input=prompt, api_key=TOGETHER_AI_TOKEN) print(f"TOGETHER_AI_TOKEN: {TOGETHER_AI_TOKEN}") - if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True: + if ( + "stream_tokens" in optional_params + and optional_params["stream_tokens"] == True + ): res = requests.post( endpoint, json={ @@ -698,9 +708,7 @@ def completion( ): custom_llm_provider = "baseten" baseten_key = ( - api_key - or litellm.baseten_key - or os.environ.get("BASETEN_API_KEY") + api_key or litellm.baseten_key or os.environ.get("BASETEN_API_KEY") ) baseten_client = BasetenLLM( encoding=encoding, api_key=baseten_key, logging_obj=logging @@ -767,11 +775,14 @@ def completion( model=model, custom_llm_provider=custom_llm_provider, original_exception=e ) + def completion_with_retries(*args, **kwargs): import tenacity + retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(3), reraise=True) return retryer(completion, *args, **kwargs) + def batch_completion(*args, **kwargs): batch_messages = args[1] if len(args) > 1 else kwargs.get("messages") completions = [] @@ -865,14 +876,16 @@ def embedding( custom_llm_provider="azure" if azure == True else None, ) + ###### Text Completion ################ def text_completion(*args, **kwargs): - if 'prompt' in kwargs: - messages = [{'role': 'system', 'content': kwargs['prompt']}] - kwargs['messages'] = messages - kwargs.pop('prompt') + if "prompt" in kwargs: + messages = [{"role": "system", "content": kwargs["prompt"]}] + kwargs["messages"] = messages + kwargs.pop("prompt") return completion(*args, **kwargs) + ####### HELPER FUNCTIONS ################ ## Set verbose to true -> ```litellm.set_verbose = True``` def print_verbose(print_statement): diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index 8c4a428dfd..ef9b24a436 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -14,6 +14,7 @@ from litellm import embedding, completion messages = [{"role": "user", "content": "who is ishaan Github? "}] + # test if response cached def test_caching(): try: @@ -50,14 +51,16 @@ def test_caching_with_models(): print(f"response1: {response1}") print(f"response2: {response2}") pytest.fail(f"Error occurred:") -# test_caching_with_models() +# test_caching_with_models() + def test_gpt_cache(): # INIT GPT Cache # from gptcache import cache from litellm.cache import completion + cache.init() cache.set_openai_key() @@ -67,10 +70,11 @@ def test_gpt_cache(): print(f"response2: {response2}") print(f"response3: {response3}") - if response3['choices'] != response2['choices']: + if response3["choices"] != response2["choices"]: # if models are different, it should not return cached response print(f"response2: {response2}") print(f"response3: {response3}") pytest.fail(f"Error occurred:") -# test_gpt_cache() + +# test_gpt_cache() diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index e59448450d..ca066db245 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -142,9 +142,12 @@ def test_completion_openai(): except Exception as e: pytest.fail(f"Error occurred: {e}") + def test_completion_openai_prompt(): try: - response = text_completion(model="gpt-3.5-turbo", prompt="What's the weather in SF?") + response = text_completion( + model="gpt-3.5-turbo", prompt="What's the weather in SF?" + ) response_str = response["choices"][0]["message"]["content"] response_str_2 = response.choices[0].message.content print(response) @@ -154,6 +157,7 @@ def test_completion_openai_prompt(): except Exception as e: pytest.fail(f"Error occurred: {e}") + def test_completion_text_openai(): try: response = completion(model="text-davinci-003", messages=messages) diff --git a/litellm/tests/test_completion_with_retries.py b/litellm/tests/test_completion_with_retries.py index 4d3d553990..bfc077b1d2 100644 --- a/litellm/tests/test_completion_with_retries.py +++ b/litellm/tests/test_completion_with_retries.py @@ -27,7 +27,7 @@ # # print(f"user_model_dict: {user_model_dict}") # pass -# # normal call +# # normal call # def test_completion_custom_provider_model_name(): # try: # response = completion_with_retries( @@ -41,7 +41,7 @@ # pytest.fail(f"Error occurred: {e}") -# # bad call +# # bad call # # def test_completion_custom_provider_model_name(): # # try: # # response = completion_with_retries( @@ -54,7 +54,7 @@ # # except Exception as e: # # pytest.fail(f"Error occurred: {e}") -# # impact on exception mapping +# # impact on exception mapping # def test_context_window(): # sample_text = "how does a court case get to the Supreme Court?" * 5000 # messages = [{"content": sample_text, "role": "user"}] @@ -83,4 +83,4 @@ # test_context_window() -# test_completion_custom_provider_model_name() \ No newline at end of file +# test_completion_custom_provider_model_name() diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index faa5760b28..705828a73e 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -22,4 +22,6 @@ def test_openai_embedding(): # print(f"response: {str(response)}") except Exception as e: pytest.fail(f"Error occurred: {e}") -test_openai_embedding() \ No newline at end of file + + +test_openai_embedding() diff --git a/litellm/tests/test_llmonitor_integration.py b/litellm/tests/test_llmonitor_integration.py index 3597019585..a3183bf8f3 100644 --- a/litellm/tests/test_llmonitor_integration.py +++ b/litellm/tests/test_llmonitor_integration.py @@ -4,7 +4,7 @@ import sys import os -sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath("../..")) from litellm import completion, embedding import litellm @@ -17,11 +17,10 @@ litellm.set_verbose = True def test_chat_openai(): try: - response = completion(model="gpt-3.5-turbo", - messages=[{ - "role": "user", - "content": "Hi 👋 - i'm openai" - }]) + response = completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}], + ) print(response) @@ -31,7 +30,7 @@ def test_chat_openai(): def test_embedding_openai(): try: - response = embedding(model="text-embedding-ada-002", input=['test']) + response = embedding(model="text-embedding-ada-002", input=["test"]) # Add any assertions here to check the response print(f"response: {str(response)[:50]}") except Exception as e: @@ -39,4 +38,4 @@ def test_embedding_openai(): test_chat_openai() -test_embedding_openai() \ No newline at end of file +test_embedding_openai() diff --git a/litellm/tests/test_model_alias_map.py b/litellm/tests/test_model_alias_map.py index b49a9a9d80..a7254df6d5 100644 --- a/litellm/tests/test_model_alias_map.py +++ b/litellm/tests/test_model_alias_map.py @@ -13,5 +13,19 @@ from litellm import embedding, completion litellm.set_verbose = True # Test: Check if the alias created via LiteDebugger is mapped correctly -{"top_p": 0.75, "prompt": "What's the meaning of life?", "num_beams": 4, "temperature": 0.1} -print(completion("llama2", messages=[{"role": "user", "content": "Hey, how's it going?"}], top_p=0.1, temperature=0, num_beams=4, max_tokens=60)) \ No newline at end of file +{ + "top_p": 0.75, + "prompt": "What's the meaning of life?", + "num_beams": 4, + "temperature": 0.1, +} +print( + completion( + "llama2", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + top_p=0.1, + temperature=0, + num_beams=4, + max_tokens=60, + ) +) diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index b6e37a7e86..f6a676177a 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -3,12 +3,14 @@ import sys, os import traceback -import time +import time + sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path import litellm from litellm import completion + litellm.logging = False litellm.set_verbose = False @@ -31,11 +33,11 @@ messages = [{"content": user_message, "role": "user"}] # complete_response = "" # start_time = time.time() # for chunk in response: -# chunk_time = time.time() +# chunk_time = time.time() # print(f"time since initial request: {chunk_time - start_time:.5f}") # print(chunk["choices"][0]["delta"]) # complete_response += chunk["choices"][0]["delta"]["content"] -# if complete_response == "": +# if complete_response == "": # raise Exception("Empty response received") # except: # print(f"error occurred: {traceback.format_exc()}") @@ -50,11 +52,11 @@ messages = [{"content": user_message, "role": "user"}] # response = "" # start_time = time.time() # for chunk in response: -# chunk_time = time.time() +# chunk_time = time.time() # print(f"time since initial request: {chunk_time - start_time:.2f}") # print(chunk["choices"][0]["delta"]) # response += chunk["choices"][0]["delta"] -# if response == "": +# if response == "": # raise Exception("Empty response received") # except: # print(f"error occurred: {traceback.format_exc()}") @@ -73,7 +75,7 @@ try: print(f"time since initial request: {chunk_time - start_time:.5f}") print(chunk["choices"][0]["delta"]) complete_response += chunk["choices"][0]["delta"]["content"] - if complete_response == "": + if complete_response == "": raise Exception("Empty response received") except: print(f"error occurred: {traceback.format_exc()}") @@ -88,11 +90,11 @@ except: # ) # complete_response = "" # for chunk in response: -# chunk_time = time.time() +# chunk_time = time.time() # print(f"time since initial request: {chunk_time - start_time:.2f}") # print(chunk["choices"][0]["delta"]) -# complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else "" -# if complete_response == "": +# complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else "" +# if complete_response == "": # raise Exception("Empty response received") # except: # print(f"error occurred: {traceback.format_exc()}") @@ -102,16 +104,20 @@ except: try: start_time = time.time() response = completion( - model="Replit-Code-3B", messages=messages, logger_fn=logger_fn, stream= True + model="Replit-Code-3B", messages=messages, logger_fn=logger_fn, stream=True ) complete_response = "" print(f"returned response object: {response}") for chunk in response: - chunk_time = time.time() + chunk_time = time.time() print(f"time since initial request: {chunk_time - start_time:.2f}") print(chunk["choices"][0]["delta"]) - complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else "" - if complete_response == "": + complete_response += ( + chunk["choices"][0]["delta"]["content"] + if len(chunk["choices"][0]["delta"].keys()) > 0 + else "" + ) + if complete_response == "": raise Exception("Empty response received") except: print(f"error occurred: {traceback.format_exc()}") @@ -121,16 +127,23 @@ except: try: start_time = time.time() response = completion( - model="together_ai/bigcode/starcoder", messages=messages, logger_fn=logger_fn, stream= True + model="together_ai/bigcode/starcoder", + messages=messages, + logger_fn=logger_fn, + stream=True, ) complete_response = "" print(f"returned response object: {response}") for chunk in response: - chunk_time = time.time() - complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else "" + chunk_time = time.time() + complete_response += ( + chunk["choices"][0]["delta"]["content"] + if len(chunk["choices"][0]["delta"].keys()) > 0 + else "" + ) if len(complete_response) > 0: print(complete_response) - if complete_response == "": + if complete_response == "": raise Exception("Empty response received") except: print(f"error occurred: {traceback.format_exc()}") @@ -144,11 +157,11 @@ except: # ) # response = "" # for chunk in response: -# chunk_time = time.time() +# chunk_time = time.time() # print(f"time since initial request: {chunk_time - start_time:.2f}") # print(chunk["choices"][0]["delta"]) # response += chunk["choices"][0]["delta"] -# if response == "": +# if response == "": # raise Exception("Empty response received") # except: # print(f"error occurred: {traceback.format_exc()}") @@ -162,11 +175,11 @@ except: # ) # response = "" # for chunk in response: -# chunk_time = time.time() +# chunk_time = time.time() # print(f"time since initial request: {chunk_time - start_time:.2f}") # print(chunk["choices"][0]["delta"]) # response += chunk["choices"][0]["delta"] -# if response == "": +# if response == "": # raise Exception("Empty response received") # except: # print(f"error occurred: {traceback.format_exc()}") diff --git a/litellm/utils.py b/litellm/utils.py index eac0079cf8..7e85decf5f 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -69,7 +69,6 @@ last_fetched_at_keys = None class Message(OpenAIObject): - def __init__(self, content="default", role="assistant", **params): super(Message, self).__init__(**params) self.content = content @@ -77,12 +76,7 @@ class Message(OpenAIObject): class Choices(OpenAIObject): - - def __init__(self, - finish_reason="stop", - index=0, - message=Message(), - **params): + def __init__(self, finish_reason="stop", index=0, message=Message(), **params): super(Choices, self).__init__(**params) self.finish_reason = finish_reason self.index = index @@ -90,22 +84,20 @@ class Choices(OpenAIObject): class ModelResponse(OpenAIObject): - - def __init__(self, - choices=None, - created=None, - model=None, - usage=None, - **params): + def __init__(self, choices=None, created=None, model=None, usage=None, **params): super(ModelResponse, self).__init__(**params) self.choices = choices if choices else [Choices()] self.created = created self.model = model - self.usage = (usage if usage else { - "prompt_tokens": None, - "completion_tokens": None, - "total_tokens": None, - }) + self.usage = ( + usage + if usage + else { + "prompt_tokens": None, + "completion_tokens": None, + "total_tokens": None, + } + ) def to_dict_recursive(self): d = super().to_dict_recursive() @@ -173,7 +165,9 @@ class Logging: self.model_call_details["api_key"] = api_key self.model_call_details["additional_args"] = additional_args - if model: # if model name was changes pre-call, overwrite the initial model call name with the new one + if ( + model + ): # if model name was changes pre-call, overwrite the initial model call name with the new one self.model_call_details["model"] = model # User Logging -> if you pass in a custom logging function @@ -203,8 +197,7 @@ class Logging: model=model, messages=messages, end_user=litellm._thread_context.user, - litellm_call_id=self. - litellm_params["litellm_call_id"], + litellm_call_id=self.litellm_params["litellm_call_id"], print_verbose=print_verbose, ) @@ -217,8 +210,7 @@ class Logging: model=model, messages=messages, end_user=litellm._thread_context.user, - litellm_call_id=self. - litellm_params["litellm_call_id"], + litellm_call_id=self.litellm_params["litellm_call_id"], litellm_params=self.model_call_details["litellm_params"], optional_params=self.model_call_details["optional_params"], print_verbose=print_verbose, @@ -263,7 +255,7 @@ class Logging: print_verbose( f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}" ) - + # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made for callback in litellm.input_callback: try: @@ -274,8 +266,7 @@ class Logging: print_verbose(f"liteDebuggerClient: {liteDebuggerClient}") liteDebuggerClient.post_call_log_event( original_response=original_response, - litellm_call_id=self. - litellm_params["litellm_call_id"], + litellm_call_id=self.litellm_params["litellm_call_id"], print_verbose=print_verbose, ) except: @@ -295,6 +286,7 @@ class Logging: # Add more methods as needed + def exception_logging( additional_args={}, logger_fn=None, @@ -329,13 +321,18 @@ def exception_logging( # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking def client(original_function): global liteDebuggerClient, get_all_keys - + def function_setup( *args, **kwargs ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc. try: global callback_list, add_breadcrumb, user_logger_fn - if litellm.email is not None or os.getenv("LITELLM_EMAIL", None) is not None or litellm.token is not None or os.getenv("LITELLM_TOKEN", None): # add to input, success and failure callbacks if user is using hosted product + if ( + litellm.email is not None + or os.getenv("LITELLM_EMAIL", None) is not None + or litellm.token is not None + or os.getenv("LITELLM_TOKEN", None) + ): # add to input, success and failure callbacks if user is using hosted product get_all_keys() if "lite_debugger" not in callback_list and litellm.logging: litellm.input_callback.append("lite_debugger") @@ -381,11 +378,12 @@ def client(original_function): if litellm.telemetry: try: model = args[0] if len(args) > 0 else kwargs["model"] - exception = kwargs[ - "exception"] if "exception" in kwargs else None - custom_llm_provider = (kwargs["custom_llm_provider"] - if "custom_llm_provider" in kwargs else - None) + exception = kwargs["exception"] if "exception" in kwargs else None + custom_llm_provider = ( + kwargs["custom_llm_provider"] + if "custom_llm_provider" in kwargs + else None + ) safe_crash_reporting( model=model, exception=exception, @@ -410,10 +408,10 @@ def client(original_function): def check_cache(*args, **kwargs): try: # never block execution prompt = get_prompt(*args, **kwargs) - if (prompt != None): # check if messages / prompt exists + if prompt != None: # check if messages / prompt exists if litellm.caching_with_models: # if caching with model names is enabled, key is prompt + model name - if ("model" in kwargs): + if "model" in kwargs: cache_key = prompt + kwargs["model"] if cache_key in local_cache: return local_cache[cache_key] @@ -423,7 +421,7 @@ def client(original_function): return result else: return None - return None # default to return None + return None # default to return None except: return None @@ -431,7 +429,7 @@ def client(original_function): try: # never block execution prompt = get_prompt(*args, **kwargs) if litellm.caching_with_models: # caching with model + prompt - if ("model" in kwargs): + if "model" in kwargs: cache_key = prompt + kwargs["model"] local_cache[cache_key] = result else: # caching based only on prompts @@ -449,7 +447,8 @@ def client(original_function): start_time = datetime.datetime.now() # [OPTIONAL] CHECK CACHE if (litellm.caching or litellm.caching_with_models) and ( - cached_result := check_cache(*args, **kwargs)) is not None: + cached_result := check_cache(*args, **kwargs) + ) is not None: result = cached_result return result # MODEL CALL @@ -458,25 +457,22 @@ def client(original_function): return result end_time = datetime.datetime.now() # [OPTIONAL] ADD TO CACHE - if (litellm.caching or litellm.caching_with_models): + if litellm.caching or litellm.caching_with_models: add_cache(result, *args, **kwargs) # LOG SUCCESS my_thread = threading.Thread( - target=handle_success, - args=(args, kwargs, result, start_time, - end_time)) # don't interrupt execution of main thread + target=handle_success, args=(args, kwargs, result, start_time, end_time) + ) # don't interrupt execution of main thread my_thread.start() # RETURN RESULT return result except Exception as e: - traceback_exception = traceback.format_exc() crash_reporting(*args, **kwargs, exception=traceback_exception) end_time = datetime.datetime.now() my_thread = threading.Thread( target=handle_failure, - args=(e, traceback_exception, start_time, end_time, args, - kwargs), + args=(e, traceback_exception, start_time, end_time, args, kwargs), ) # don't interrupt execution of main thread my_thread.start() if hasattr(e, "message"): @@ -506,18 +502,18 @@ def token_counter(model, text): return num_tokens -def cost_per_token(model="gpt-3.5-turbo", - prompt_tokens=0, - completion_tokens=0): +def cost_per_token(model="gpt-3.5-turbo", prompt_tokens=0, completion_tokens=0): # given prompt_tokens_cost_usd_dollar = 0 completion_tokens_cost_usd_dollar = 0 model_cost_ref = litellm.model_cost if model in model_cost_ref: prompt_tokens_cost_usd_dollar = ( - model_cost_ref[model]["input_cost_per_token"] * prompt_tokens) + model_cost_ref[model]["input_cost_per_token"] * prompt_tokens + ) completion_tokens_cost_usd_dollar = ( - model_cost_ref[model]["output_cost_per_token"] * completion_tokens) + model_cost_ref[model]["output_cost_per_token"] * completion_tokens + ) return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar else: # calculate average input cost @@ -538,9 +534,8 @@ def completion_cost(model="gpt-3.5-turbo", prompt="", completion=""): prompt_tokens = token_counter(model=model, text=prompt) completion_tokens = token_counter(model=model, text=completion) prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_token( - model=model, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar @@ -558,7 +553,7 @@ def get_litellm_params( custom_llm_provider=None, custom_api_base=None, litellm_call_id=None, - model_alias_map=None + model_alias_map=None, ): litellm_params = { "return_async": return_async, @@ -569,13 +564,13 @@ def get_litellm_params( "custom_llm_provider": custom_llm_provider, "custom_api_base": custom_api_base, "litellm_call_id": litellm_call_id, - "model_alias_map": model_alias_map + "model_alias_map": model_alias_map, } return litellm_params -def get_optional_params( # use the openai defaults +def get_optional_params( # use the openai defaults # 12 optional params functions=[], function_call="", @@ -588,7 +583,7 @@ def get_optional_params( # use the openai defaults presence_penalty=0, frequency_penalty=0, logit_bias={}, - num_beams=1, + num_beams=1, user="", deployment_id=None, model=None, @@ -635,8 +630,9 @@ def get_optional_params( # use the openai defaults optional_params["max_tokens"] = max_tokens if frequency_penalty != 0: optional_params["frequency_penalty"] = frequency_penalty - elif (model == "chat-bison" - ): # chat-bison has diff args from chat-bison@001 ty Google + elif ( + model == "chat-bison" + ): # chat-bison has diff args from chat-bison@001 ty Google if temperature != 1: optional_params["temperature"] = temperature if top_p != 1: @@ -702,10 +698,7 @@ def load_test_model( test_prompt = prompt if num_calls: test_calls = num_calls - messages = [[{ - "role": "user", - "content": test_prompt - }] for _ in range(test_calls)] + messages = [[{"role": "user", "content": test_prompt}] for _ in range(test_calls)] start_time = time.time() try: litellm.batch_completion( @@ -743,15 +736,17 @@ def set_callbacks(callback_list): try: import sentry_sdk except ImportError: - print_verbose( - "Package 'sentry_sdk' is missing. Installing it...") + print_verbose("Package 'sentry_sdk' is missing. Installing it...") subprocess.check_call( - [sys.executable, "-m", "pip", "install", "sentry_sdk"]) + [sys.executable, "-m", "pip", "install", "sentry_sdk"] + ) import sentry_sdk sentry_sdk_instance = sentry_sdk - sentry_trace_rate = (os.environ.get("SENTRY_API_TRACE_RATE") - if "SENTRY_API_TRACE_RATE" in os.environ - else "1.0") + sentry_trace_rate = ( + os.environ.get("SENTRY_API_TRACE_RATE") + if "SENTRY_API_TRACE_RATE" in os.environ + else "1.0" + ) sentry_sdk_instance.init( dsn=os.environ.get("SENTRY_API_URL"), traces_sample_rate=float(sentry_trace_rate), @@ -762,10 +757,10 @@ def set_callbacks(callback_list): try: from posthog import Posthog except ImportError: - print_verbose( - "Package 'posthog' is missing. Installing it...") + print_verbose("Package 'posthog' is missing. Installing it...") subprocess.check_call( - [sys.executable, "-m", "pip", "install", "posthog"]) + [sys.executable, "-m", "pip", "install", "posthog"] + ) from posthog import Posthog posthog = Posthog( project_api_key=os.environ.get("POSTHOG_API_KEY"), @@ -775,10 +770,10 @@ def set_callbacks(callback_list): try: from slack_bolt import App except ImportError: - print_verbose( - "Package 'slack_bolt' is missing. Installing it...") + print_verbose("Package 'slack_bolt' is missing. Installing it...") subprocess.check_call( - [sys.executable, "-m", "pip", "install", "slack_bolt"]) + [sys.executable, "-m", "pip", "install", "slack_bolt"] + ) from slack_bolt import App slack_app = App( token=os.environ.get("SLACK_API_TOKEN"), @@ -809,8 +804,7 @@ def set_callbacks(callback_list): raise e -def handle_failure(exception, traceback_exception, start_time, end_time, args, - kwargs): +def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs): global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient try: # print_verbose(f"handle_failure args: {args}") @@ -820,7 +814,8 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, failure_handler = additional_details.pop("failure_handler", None) additional_details["Event_Name"] = additional_details.pop( - "failed_event_name", "litellm.failed_query") + "failed_event_name", "litellm.failed_query" + ) print_verbose(f"self.failure_callback: {litellm.failure_callback}") for callback in litellm.failure_callback: try: @@ -835,8 +830,9 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, for detail in additional_details: slack_msg += f"{detail}: {additional_details[detail]}\n" slack_msg += f"Traceback: {traceback_exception}" - slack_app.client.chat_postMessage(channel=alerts_channel, - text=slack_msg) + slack_app.client.chat_postMessage( + channel=alerts_channel, text=slack_msg + ) elif callback == "sentry": capture_exception(exception) elif callback == "posthog": @@ -855,8 +851,9 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, print_verbose(f"ph_obj: {ph_obj}") print_verbose(f"PostHog Event Name: {event_name}") if "user_id" in additional_details: - posthog.capture(additional_details["user_id"], - event_name, ph_obj) + posthog.capture( + additional_details["user_id"], event_name, ph_obj + ) else: # PostHog calls require a unique id to identify a user - https://posthog.com/docs/libraries/python unique_id = str(uuid.uuid4()) posthog.capture(unique_id, event_name) @@ -870,10 +867,10 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, "created": time.time(), "error": traceback_exception, "usage": { - "prompt_tokens": - prompt_token_calculator(model, messages=messages), - "completion_tokens": - 0, + "prompt_tokens": prompt_token_calculator( + model, messages=messages + ), + "completion_tokens": 0, }, } berrispendLogger.log_event( @@ -892,10 +889,10 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, "model": model, "created": time.time(), "usage": { - "prompt_tokens": - prompt_token_calculator(model, messages=messages), - "completion_tokens": - 0, + "prompt_tokens": prompt_token_calculator( + model, messages=messages + ), + "completion_tokens": 0, }, } aispendLogger.log_event( @@ -910,10 +907,13 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, model = args[0] if len(args) > 0 else kwargs["model"] - input = args[1] if len(args) > 1 else kwargs.get( - "messages", kwargs.get("input", None)) + input = ( + args[1] + if len(args) > 1 + else kwargs.get("messages", kwargs.get("input", None)) + ) - type = 'embed' if 'input' in kwargs else 'llm' + type = "embed" if "input" in kwargs else "llm" llmonitorLogger.log_event( type=type, @@ -937,10 +937,10 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, "created": time.time(), "error": traceback_exception, "usage": { - "prompt_tokens": - prompt_token_calculator(model, messages=messages), - "completion_tokens": - 0, + "prompt_tokens": prompt_token_calculator( + model, messages=messages + ), + "completion_tokens": 0, }, } supabaseClient.log_event( @@ -957,16 +957,28 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, print_verbose("reaches lite_debugger for logging!") print_verbose(f"liteDebuggerClient: {liteDebuggerClient}") model = args[0] if len(args) > 0 else kwargs["model"] - messages = args[1] if len(args) > 1 else kwargs.get("messages", [{"role": "user", "content": ' '.join(kwargs.get("input", ""))}]) + messages = ( + args[1] + if len(args) > 1 + else kwargs.get( + "messages", + [ + { + "role": "user", + "content": " ".join(kwargs.get("input", "")), + } + ], + ) + ) result = { "model": model, "created": time.time(), "error": traceback_exception, "usage": { - "prompt_tokens": - prompt_token_calculator(model, messages=messages), - "completion_tokens": - 0, + "prompt_tokens": prompt_token_calculator( + model, messages=messages + ), + "completion_tokens": 0, }, } liteDebuggerClient.log_event( @@ -1002,11 +1014,16 @@ def handle_success(args, kwargs, result, start_time, end_time): global heliconeLogger, aispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger try: model = args[0] if len(args) > 0 else kwargs["model"] - input = args[1] if len(args) > 1 else kwargs.get("messages", kwargs.get("input", None)) + input = ( + args[1] + if len(args) > 1 + else kwargs.get("messages", kwargs.get("input", None)) + ) success_handler = additional_details.pop("success_handler", None) failure_handler = additional_details.pop("failure_handler", None) additional_details["Event_Name"] = additional_details.pop( - "successful_event_name", "litellm.succes_query") + "successful_event_name", "litellm.succes_query" + ) for callback in litellm.success_callback: try: if callback == "posthog": @@ -1015,8 +1032,9 @@ def handle_success(args, kwargs, result, start_time, end_time): ph_obj[detail] = additional_details[detail] event_name = additional_details["Event_Name"] if "user_id" in additional_details: - posthog.capture(additional_details["user_id"], - event_name, ph_obj) + posthog.capture( + additional_details["user_id"], event_name, ph_obj + ) else: # PostHog calls require a unique id to identify a user - https://posthog.com/docs/libraries/python unique_id = str(uuid.uuid4()) posthog.capture(unique_id, event_name, ph_obj) @@ -1025,8 +1043,9 @@ def handle_success(args, kwargs, result, start_time, end_time): slack_msg = "" for detail in additional_details: slack_msg += f"{detail}: {additional_details[detail]}\n" - slack_app.client.chat_postMessage(channel=alerts_channel, - text=slack_msg) + slack_app.client.chat_postMessage( + channel=alerts_channel, text=slack_msg + ) elif callback == "helicone": print_verbose("reaches helicone for logging!") model = args[0] if len(args) > 0 else kwargs["model"] @@ -1043,11 +1062,14 @@ def handle_success(args, kwargs, result, start_time, end_time): print_verbose("reaches llmonitor for logging!") model = args[0] if len(args) > 0 else kwargs["model"] - input = args[1] if len(args) > 1 else kwargs.get( - "messages", kwargs.get("input", None)) + input = ( + args[1] + if len(args) > 1 + else kwargs.get("messages", kwargs.get("input", None)) + ) - #if contains input, it's 'embedding', otherwise 'llm' - type = 'embed' if 'input' in kwargs else 'llm' + # if contains input, it's 'embedding', otherwise 'llm' + type = "embed" if "input" in kwargs else "llm" llmonitorLogger.log_event( type=type, @@ -1069,7 +1091,6 @@ def handle_success(args, kwargs, result, start_time, end_time): start_time=start_time, end_time=end_time, print_verbose=print_verbose, - ) elif callback == "aispend": print_verbose("reaches aispend for logging!") @@ -1084,7 +1105,11 @@ def handle_success(args, kwargs, result, start_time, end_time): elif callback == "supabase": print_verbose("reaches supabase for logging!") model = args[0] if len(args) > 0 else kwargs["model"] - messages = args[1] if len(args) > 1 else kwargs.get("messages", {"role": "user", "content": ""}) + messages = ( + args[1] + if len(args) > 1 + else kwargs.get("messages", {"role": "user", "content": ""}) + ) print(f"supabaseClient: {supabaseClient}") supabaseClient.log_event( model=model, @@ -1099,7 +1124,19 @@ def handle_success(args, kwargs, result, start_time, end_time): elif callback == "lite_debugger": print_verbose("reaches lite_debugger for logging!") print_verbose(f"liteDebuggerClient: {liteDebuggerClient}") - messages = args[1] if len(args) > 1 else kwargs.get("messages", [{"role": "user", "content": ' '.join(kwargs.get("input", ""))}]) + messages = ( + args[1] + if len(args) > 1 + else kwargs.get( + "messages", + [ + { + "role": "user", + "content": " ".join(kwargs.get("input", "")), + } + ], + ) + ) liteDebuggerClient.log_event( model=model, messages=messages, @@ -1129,6 +1166,7 @@ def handle_success(args, kwargs, result, start_time, end_time): ) pass + def acreate(*args, **kwargs): ## Thin client to handle the acreate langchain call return litellm.acompletion(*args, **kwargs) @@ -1170,28 +1208,43 @@ def modify_integration(integration_name, integration_params): if "table_name" in integration_params: Supabase.supabase_table_name = integration_params["table_name"] + ####### [BETA] HOSTED PRODUCT ################ - https://docs.litellm.ai/docs/debugging/hosted_debugging + def get_all_keys(llm_provider=None): try: global last_fetched_at_keys # if user is using hosted product -> instantiate their env with their hosted api keys - refresh every 5 minutes print_verbose(f"Reaches get all keys, llm_provider: {llm_provider}") - user_email = os.getenv("LITELLM_EMAIL") or litellm.email or litellm.token or os.getenv("LITELLM_TOKEN") + user_email = ( + os.getenv("LITELLM_EMAIL") + or litellm.email + or litellm.token + or os.getenv("LITELLM_TOKEN") + ) if user_email: time_delta = 0 if last_fetched_at_keys != None: current_time = time.time() time_delta = current_time - last_fetched_at_keys - if time_delta > 300 or last_fetched_at_keys == None or llm_provider: # if the llm provider is passed in , assume this happening due to an AuthError for that provider + if ( + time_delta > 300 or last_fetched_at_keys == None or llm_provider + ): # if the llm provider is passed in , assume this happening due to an AuthError for that provider # make the api call last_fetched_at = time.time() print_verbose(f"last_fetched_at: {last_fetched_at}") - response = requests.post(url="http://api.litellm.ai/get_all_keys", headers={"content-type": "application/json"}, data=json.dumps({"user_email": user_email})) + response = requests.post( + url="http://api.litellm.ai/get_all_keys", + headers={"content-type": "application/json"}, + data=json.dumps({"user_email": user_email}), + ) print_verbose(f"get model key response: {response.text}") data = response.json() # update model list - for key, value in data["model_keys"].items(): # follows the LITELLM API KEY format - _API_KEY - e.g. HUGGINGFACE_API_KEY + for key, value in data[ + "model_keys" + ].items(): # follows the LITELLM API KEY format - _API_KEY - e.g. HUGGINGFACE_API_KEY os.environ[key] = value # set model alias map for model_alias, value in data["model_alias_map"].items(): @@ -1200,19 +1253,31 @@ def get_all_keys(llm_provider=None): return None return None except: - print_verbose(f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}") + print_verbose( + f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}" + ) pass + def get_model_list(): global last_fetched_at try: # if user is using hosted product -> get their updated model list - user_email = os.getenv("LITELLM_EMAIL") or litellm.email or litellm.token or os.getenv("LITELLM_TOKEN") + user_email = ( + os.getenv("LITELLM_EMAIL") + or litellm.email + or litellm.token + or os.getenv("LITELLM_TOKEN") + ) if user_email: # make the api call last_fetched_at = time.time() print(f"last_fetched_at: {last_fetched_at}") - response = requests.post(url="http://api.litellm.ai/get_model_list", headers={"content-type": "application/json"}, data=json.dumps({"user_email": user_email})) + response = requests.post( + url="http://api.litellm.ai/get_model_list", + headers={"content-type": "application/json"}, + data=json.dumps({"user_email": user_email}), + ) print_verbose(f"get_model_list response: {response.text}") data = response.json() # update model list @@ -1224,12 +1289,14 @@ def get_model_list(): if f"{item.upper()}_API_KEY" not in os.environ: missing_llm_provider = item break - # update environment - if required + # update environment - if required threading.Thread(target=get_all_keys, args=(missing_llm_provider)).start() return model_list - return [] # return empty list by default + return [] # return empty list by default except: - print_verbose(f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}") + print_verbose( + f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}" + ) ####### EXCEPTION MAPPING ################ @@ -1253,36 +1320,33 @@ def exception_type(model, original_exception, custom_llm_provider): exception_type = "" if "claude" in model: # one of the anthropics if hasattr(original_exception, "status_code"): - print_verbose( - f"status_code: {original_exception.status_code}") + print_verbose(f"status_code: {original_exception.status_code}") if original_exception.status_code == 401: exception_mapping_worked = True raise AuthenticationError( - message= - f"AnthropicException - {original_exception.message}", + message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic", ) elif original_exception.status_code == 400: exception_mapping_worked = True raise InvalidRequestError( - message= - f"AnthropicException - {original_exception.message}", + message=f"AnthropicException - {original_exception.message}", model=model, llm_provider="anthropic", ) elif original_exception.status_code == 429: exception_mapping_worked = True raise RateLimitError( - message= - f"AnthropicException - {original_exception.message}", + message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic", ) - elif ("Could not resolve authentication method. Expected either api_key or auth_token to be set." - in error_str): + elif ( + "Could not resolve authentication method. Expected either api_key or auth_token to be set." + in error_str + ): exception_mapping_worked = True raise AuthenticationError( - message= - f"AnthropicException - {original_exception.message}", + message=f"AnthropicException - {original_exception.message}", llm_provider="anthropic", ) elif "replicate" in model: @@ -1306,36 +1370,35 @@ def exception_type(model, original_exception, custom_llm_provider): llm_provider="replicate", ) elif ( - exception_type == "ReplicateError" + exception_type == "ReplicateError" ): # ReplicateError implies an error on Replicate server side, not user side raise ServiceUnavailableError( message=f"ReplicateException - {error_str}", llm_provider="replicate", ) elif model == "command-nightly": # Cohere - if ("invalid api token" in error_str - or "No API key provided." in error_str): + if ( + "invalid api token" in error_str + or "No API key provided." in error_str + ): exception_mapping_worked = True raise AuthenticationError( - message= - f"CohereException - {original_exception.message}", + message=f"CohereException - {original_exception.message}", llm_provider="cohere", ) elif "too many tokens" in error_str: exception_mapping_worked = True raise InvalidRequestError( - message= - f"CohereException - {original_exception.message}", + message=f"CohereException - {original_exception.message}", model=model, llm_provider="cohere", ) elif ( - "CohereConnectionError" in exception_type + "CohereConnectionError" in exception_type ): # cohere seems to fire these errors when we load test it (1k+ messages / min) exception_mapping_worked = True raise RateLimitError( - message= - f"CohereException - {original_exception.message}", + message=f"CohereException - {original_exception.message}", llm_provider="cohere", ) elif custom_llm_provider == "huggingface": @@ -1343,23 +1406,20 @@ def exception_type(model, original_exception, custom_llm_provider): if original_exception.status_code == 401: exception_mapping_worked = True raise AuthenticationError( - message= - f"HuggingfaceException - {original_exception.message}", + message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface", ) elif original_exception.status_code == 400: exception_mapping_worked = True raise InvalidRequestError( - message= - f"HuggingfaceException - {original_exception.message}", + message=f"HuggingfaceException - {original_exception.message}", model=model, llm_provider="huggingface", ) elif original_exception.status_code == 429: exception_mapping_worked = True raise RateLimitError( - message= - f"HuggingfaceException - {original_exception.message}", + message=f"HuggingfaceException - {original_exception.message}", llm_provider="huggingface", ) raise original_exception # base case - return the original exception @@ -1375,8 +1435,10 @@ def exception_type(model, original_exception, custom_llm_provider): }, exception=e, ) - ## AUTH ERROR - if isinstance(e, AuthenticationError) and (litellm.email or "LITELLM_EMAIL" in os.environ): + ## AUTH ERROR + if isinstance(e, AuthenticationError) and ( + litellm.email or "LITELLM_EMAIL" in os.environ + ): threading.Thread(target=get_all_keys, args=(e.llm_provider,)).start() if exception_mapping_worked: raise e @@ -1391,7 +1453,8 @@ def safe_crash_reporting(model=None, exception=None, custom_llm_provider=None): "exception": str(exception), "custom_llm_provider": custom_llm_provider, } - threading.Thread(target=litellm_telemetry, args=(data, )).start() + threading.Thread(target=litellm_telemetry, args=(data,)).start() + def get_or_generate_uuid(): uuid_file = "litellm_uuid.txt" @@ -1445,8 +1508,7 @@ def get_secret(secret_name): # TODO: check which secret manager is being used # currently only supports Infisical try: - secret = litellm.secret_manager_client.get_secret( - secret_name).secret_value + secret = litellm.secret_manager_client.get_secret(secret_name).secret_value except: secret = None return secret @@ -1460,7 +1522,6 @@ def get_secret(secret_name): # wraps the completion stream to return the correct format for the model # replicate/anthropic/cohere class CustomStreamWrapper: - def __init__(self, completion_stream, model, custom_llm_provider=None): self.model = model self.custom_llm_provider = custom_llm_provider @@ -1509,8 +1570,9 @@ class CustomStreamWrapper: elif self.model == "replicate": chunk = next(self.completion_stream) completion_obj["content"] = chunk - elif (self.custom_llm_provider and self.custom_llm_provider == "together_ai") or ("togethercomputer" - in self.model): + elif ( + self.custom_llm_provider and self.custom_llm_provider == "together_ai" + ) or ("togethercomputer" in self.model): chunk = next(self.completion_stream) text_data = self.handle_together_ai_chunk(chunk) if text_data == "": @@ -1545,9 +1607,9 @@ def read_config_args(config_path): ########## ollama implementation ############################ -async def get_ollama_response_stream(api_base="http://localhost:11434", - model="llama2", - prompt="Why is the sky blue?"): +async def get_ollama_response_stream( + api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?" +): session = aiohttp.ClientSession() url = f"{api_base}/api/generate" data = { @@ -1570,11 +1632,7 @@ async def get_ollama_response_stream(api_base="http://localhost:11434", "content": "", } completion_obj["content"] = j["response"] - yield { - "choices": [{ - "delta": completion_obj - }] - } + yield {"choices": [{"delta": completion_obj}]} # self.responses.append(j["response"]) # yield "blank" except Exception as e: