diff --git a/litellm/__init__.py b/litellm/__init__.py index a9fd326347..3073f27196 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -1,5 +1,6 @@ import threading from typing import Callable, List, Optional + input_callback: List[str] = [] success_callback: List[str] = [] failure_callback: List[str] = [] @@ -20,7 +21,8 @@ vertex_project: Optional[str] = None vertex_location: Optional[str] = None togetherai_api_key: Optional[str] = None caching = False -caching_with_models = False # if you want the caching key to be model + prompt +caching_with_models = False # if you want the caching key to be model + prompt +debugger = False model_cost = { "gpt-3.5-turbo": { "max_tokens": 4000, @@ -156,7 +158,7 @@ replicate_models = [ "a16z-infra/llama-2-7b-chat:7b0bfc9aff140d5b75bacbed23e91fd3c34b01a1e958d32132de6e0a19796e2c", "replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b", "daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f", - "replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad" + "replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad", ] # placeholder, to make sure we accept any replicate model in our model_list openrouter_models = [ @@ -196,14 +198,10 @@ ai21_models = ["j2-ultra", "j2-mid", "j2-light"] together_ai_models = [ "togethercomputer/llama-2-70b-chat", "togethercomputer/Llama-2-7B-32K-Instruct", - "togethercomputer/llama-2-7b" + "togethercomputer/llama-2-7b", ] -baseten_models = [ - "qvv0xeq", # FALCON 7B - "q841o8w", # WizardLM - "31dxrj3" # Mosaic ML -] +baseten_models = ["qvv0xeq", "q841o8w", "31dxrj3"] # FALCON 7B # WizardLM # Mosaic ML model_list = ( open_ai_chat_completion_models @@ -231,12 +229,11 @@ provider_list = [ "openrouter", "vertex_ai", "ai21", - "baseten" + "baseten", ] models_by_provider = { - "openai": open_ai_chat_completion_models - + open_ai_text_completion_models, + "openai": open_ai_chat_completion_models + open_ai_text_completion_models, "cohere": cohere_models, "anthropic": anthropic_models, "replicate": replicate_models, @@ -263,8 +260,8 @@ from .utils import ( completion_cost, get_litellm_params, Logging, acreate, get_model_list ) from .main import * # type: ignore from .integrations import * diff --git a/litellm/integrations/litedebugger.py b/litellm/integrations/litedebugger.py index 1880c4694f..74b6ec4ded 100644 --- a/litellm/integrations/litedebugger.py +++ b/litellm/integrations/litedebugger.py @@ -1,5 +1,6 @@ import requests, traceback, json, os + class LiteDebugger: user_email = None dashboard_url = None @@ -14,43 +15,57 @@ class LiteDebugger: self.dashboard_url = 'https://admin.litellm.ai/' + self.user_email print(f"Here's your free Dashboard 👉 {self.dashboard_url}") if self.user_email == None: - raise Exception("[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_EMAIL. Set it in your environment. Eg.: os.environ['LITELLM_EMAIL']= ") + raise Exception( + "[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_EMAIL. Set it in your environment. Eg.: os.environ['LITELLM_EMAIL']= " + ) except Exception as e: - raise ValueError("[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_EMAIL. Set it in your environment. Eg.: os.environ['LITELLM_EMAIL']= ") + raise ValueError( + "[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_EMAIL. Set it in your environment.
Eg.: os.environ['LITELLM_EMAIL']= " + ) - - def input_log_event(self, model, messages, end_user, litellm_call_id, print_verbose): + def input_log_event( + self, model, messages, end_user, litellm_call_id, print_verbose + ): try: print_verbose( f"LiteLLMDebugger: Logging - Enters input logging function for model {model}" ) litellm_data_obj = { - "model": model, - "messages": messages, - "end_user": end_user, + "model": model, + "messages": messages, + "end_user": end_user, "status": "initiated", "litellm_call_id": litellm_call_id, - "user_email": self.user_email + "user_email": self.user_email, } - response = requests.post(url=self.api_url, headers={"content-type": "application/json"}, data=json.dumps(litellm_data_obj)) + response = requests.post( + url=self.api_url, + headers={"content-type": "application/json"}, + data=json.dumps(litellm_data_obj), + ) print_verbose(f"LiteDebugger: api response - {response.text}") except: - print_verbose(f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}") + print_verbose( + f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}" + ) pass - - def log_event(self, model, + + def log_event( + self, + model, messages, end_user, response_obj, start_time, end_time, litellm_call_id, - print_verbose,): + print_verbose, + ): try: print_verbose( f"LiteLLMDebugger: Logging - Enters input logging function for model {model}" ) - total_cost = 0 # [TODO] implement cost tracking + total_cost = 0 # [TODO] implement cost tracking response_time = (end_time - start_time).total_seconds() if "choices" in response_obj: litellm_data_obj = { @@ -62,12 +77,16 @@ class LiteDebugger: "end_user": end_user, "litellm_call_id": litellm_call_id, "status": "success", - "user_email": self.user_email + "user_email": self.user_email, } print_verbose( f"LiteDebugger: Logging - final data object: {litellm_data_obj}" ) - response = requests.post(url=self.api_url, headers={"content-type": "application/json"}, data=json.dumps(litellm_data_obj)) + response = requests.post( + url=self.api_url, + headers={"content-type": "application/json"}, + data=json.dumps(litellm_data_obj), + ) elif "error" in response_obj: if "Unable to map your input to a model." 
in response_obj["error"]: total_cost = 0 @@ -80,13 +99,19 @@ class LiteDebugger: "end_user": end_user, "litellm_call_id": litellm_call_id, "status": "failure", - "user_email": self.user_email + "user_email": self.user_email, } print_verbose( f"LiteDebugger: Logging - final data object: {litellm_data_obj}" ) - response = requests.post(url=self.api_url, headers={"content-type": "application/json"}, data=json.dumps(litellm_data_obj)) + response = requests.post( + url=self.api_url, + headers={"content-type": "application/json"}, + data=json.dumps(litellm_data_obj), + ) print_verbose(f"LiteDebugger: api response - {response.text}") except: - print_verbose(f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}") - pass \ No newline at end of file + print_verbose( + f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}" + ) + pass diff --git a/litellm/integrations/supabase.py b/litellm/integrations/supabase.py index 3ea63f5c67..1091fd15ab 100644 --- a/litellm/integrations/supabase.py +++ b/litellm/integrations/supabase.py @@ -144,23 +144,25 @@ class Supabase: ) return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar - def input_log_event(self, model, messages, end_user, litellm_call_id, print_verbose): + def input_log_event( + self, model, messages, end_user, litellm_call_id, print_verbose + ): try: print_verbose( f"Supabase Logging - Enters input logging function for model {model}" ) supabase_data_obj = { - "model": model, - "messages": messages, - "end_user": end_user, + "model": model, + "messages": messages, + "end_user": end_user, "status": "initiated", - "litellm_call_id": litellm_call_id + "litellm_call_id": litellm_call_id, } data, count = ( - self.supabase_client.table(self.supabase_table_name) - .insert(supabase_data_obj) - .execute() - ) + self.supabase_client.table(self.supabase_table_name) + .insert(supabase_data_obj) + .execute() + ) print(f"data: {data}") except: print_verbose(f"Supabase Logging Error - {traceback.format_exc()}") @@ -200,7 +202,7 @@ class Supabase: "response": response_obj["choices"][0]["message"]["content"], "end_user": end_user, "litellm_call_id": litellm_call_id, - "status": "success" + "status": "success", } print_verbose( f"Supabase Logging - final data object: {supabase_data_obj}" @@ -221,7 +223,7 @@ class Supabase: "error": response_obj["error"], "end_user": end_user, "litellm_call_id": litellm_call_id, - "status": "failure" + "status": "failure", } print_verbose( f"Supabase Logging - final data object: {supabase_data_obj}" diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py index fecc655f2f..983d4ca0e5 100644 --- a/litellm/llms/anthropic.py +++ b/litellm/llms/anthropic.py @@ -21,7 +21,9 @@ class AnthropicError(Exception): class AnthropicLLM: - def __init__(self, encoding, default_max_tokens_to_sample, logging_obj, api_key=None): + def __init__( + self, encoding, default_max_tokens_to_sample, logging_obj, api_key=None + ): self.encoding = encoding self.default_max_tokens_to_sample = default_max_tokens_to_sample self.completion_url = "https://api.anthropic.com/v1/complete" @@ -84,7 +86,11 @@ class AnthropicLLM: } ## LOGGING - self.logging_obj.pre_call(input=prompt, api_key=self.api_key, additional_args={"complete_input_dict": data}) + self.logging_obj.pre_call( + input=prompt, + api_key=self.api_key, + additional_args={"complete_input_dict": data}, + ) ## COMPLETION CALL response = requests.post( self.completion_url, headers=self.headers, data=json.dumps(data) @@ -93,7 +99,12 @@ 
class AnthropicLLM: return response.iter_lines() else: ## LOGGING - self.logging_obj.post_call(input=prompt, api_key=self.api_key, original_response=response.text, additional_args={"complete_input_dict": data}) + self.logging_obj.post_call( + input=prompt, + api_key=self.api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) print_verbose(f"raw model_response: {response.text}") ## RESPONSE OBJECT completion_response = response.json() diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py index 624fb4f055..fbde78334e 100644 --- a/litellm/llms/huggingface_restapi.py +++ b/litellm/llms/huggingface_restapi.py @@ -72,12 +72,13 @@ class HuggingfaceRestAPILLM: if "max_tokens" in optional_params: value = optional_params.pop("max_tokens") optional_params["max_new_tokens"] = value - data = { - "inputs": prompt, - "parameters": optional_params - } + data = {"inputs": prompt, "parameters": optional_params} ## LOGGING - self.logging_obj.pre_call(input=prompt, api_key=self.api_key, additional_args={"complete_input_dict": data}) + self.logging_obj.pre_call( + input=prompt, + api_key=self.api_key, + additional_args={"complete_input_dict": data}, + ) ## COMPLETION CALL response = requests.post( completion_url, headers=self.headers, data=json.dumps(data) @@ -86,7 +87,12 @@ class HuggingfaceRestAPILLM: return response.iter_lines() else: ## LOGGING - self.logging_obj.post_call(input=prompt, api_key=self.api_key, original_response=response.text, additional_args={"complete_input_dict": data}) + self.logging_obj.post_call( + input=prompt, + api_key=self.api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) ## RESPONSE OBJECT completion_response = response.json() print_verbose(f"response: {completion_response}") diff --git a/litellm/main.py b/litellm/main.py index 3fb8da19f2..298d2ce956 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -10,7 +10,7 @@ from litellm import ( # type: ignore timeout, get_optional_params, get_litellm_params, - Logging + Logging, ) from litellm.utils import ( get_secret, @@ -96,10 +96,14 @@ def completion( model_response = ModelResponse() if azure: # this flag is deprecated, remove once notebooks are also updated. custom_llm_provider = "azure" - elif model.split("/", 1)[0] in litellm.provider_list: # allow custom provider to be passed in via the model name "azure/chatgpt-test" + elif ( + model.split("/", 1)[0] in litellm.provider_list + ): # allow custom provider to be passed in via the model name "azure/chatgpt-test" custom_llm_provider = model.split("/", 1)[0] model = model.split("/", 1)[1] - if "replicate" == custom_llm_provider and "/" not in model: # handle the "replicate/llama2..." edge-case + if ( + "replicate" == custom_llm_provider and "/" not in model + ): # handle the "replicate/llama2..." 
edge-case model = custom_llm_provider + "/" + model # check if user passed in any of the OpenAI optional params optional_params = get_optional_params( @@ -130,9 +134,14 @@ def completion( verbose=verbose, custom_llm_provider=custom_llm_provider, custom_api_base=custom_api_base, - litellm_call_id=litellm_call_id + litellm_call_id=litellm_call_id, + ) + logging = Logging( + model=model, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, ) - logging = Logging(model=model, messages=messages, optional_params=optional_params, litellm_params=litellm_params) if custom_llm_provider == "azure": # azure configs openai.api_type = "azure" @@ -153,7 +162,15 @@ def completion( # set key openai.api_key = api_key ## LOGGING - logging.pre_call(input=messages, api_key=openai.api_key, additional_args={"litellm.headers": litellm.headers, "api_version": openai.api_version, "api_base": openai.api_base}) + logging.pre_call( + input=messages, + api_key=openai.api_key, + additional_args={ + "headers": litellm.headers, + "api_version": openai.api_version, + "api_base": openai.api_base, + }, + ) ## COMPLETION CALL if litellm.headers: response = openai.ChatCompletion.create( @@ -168,7 +185,16 @@ def completion( ) ## LOGGING - logging.post_call(input=messages, api_key=openai.api_key, original_response=response, additional_args={"headers": litellm.headers, "api_version": openai.api_version, "api_base": openai.api_base}) + logging.post_call( + input=messages, + api_key=openai.api_key, + original_response=response, + additional_args={ + "headers": litellm.headers, + "api_version": openai.api_version, + "api_base": openai.api_base, + }, + ) elif ( model in litellm.open_ai_chat_completion_models or custom_llm_provider == "custom_openai" @@ -193,7 +219,11 @@ def completion( openai.api_key = api_key ## LOGGING - logging.pre_call(input=messages, api_key=api_key, additional_args={"headers": litellm.headers, "api_base": api_base}) + logging.pre_call( + input=messages, + api_key=api_key, + additional_args={"headers": litellm.headers, "api_base": api_base}, + ) ## COMPLETION CALL if litellm.headers: response = openai.ChatCompletion.create( @@ -207,7 +237,12 @@ def completion( model=model, messages=messages, **optional_params ) ## LOGGING - logging.post_call(input=messages, api_key=api_key, original_response=response, additional_args={"headers": litellm.headers}) + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + additional_args={"headers": litellm.headers}, + ) elif model in litellm.open_ai_text_completion_models: openai.api_type = "openai" openai.api_base = ( @@ -228,7 +263,16 @@ def completion( openai.organization = litellm.organization prompt = " ".join([message["content"] for message in messages]) ## LOGGING - logging.pre_call(input=prompt, api_key=api_key, additional_args={"openai_organization": litellm.organization, "headers": litellm.headers, "api_base": openai.api_base, "api_type": openai.api_type}) + logging.pre_call( + input=prompt, + api_key=api_key, + additional_args={ + "openai_organization": litellm.organization, + "headers": litellm.headers, + "api_base": openai.api_base, + "api_type": openai.api_type, + }, + ) ## COMPLETION CALL if litellm.headers: response = openai.Completion.create( @@ -239,7 +283,17 @@ def completion( else: response = openai.Completion.create(model=model, prompt=prompt) ## LOGGING - logging.post_call(input=prompt, api_key=api_key, original_response=response, additional_args={"openai_organization": litellm.organization, 
"headers": litellm.headers, "api_base": openai.api_base, "api_type": openai.api_type}) + logging.post_call( + input=prompt, + api_key=api_key, + original_response=response, + additional_args={ + "openai_organization": litellm.organization, + "headers": litellm.headers, + "api_base": openai.api_base, + "api_type": openai.api_type, + }, + ) ## RESPONSE OBJECT completion_response = response["choices"][0]["text"] model_response["choices"][0]["message"]["content"] = completion_response @@ -270,7 +324,14 @@ def completion( input["max_length"] = max_tokens # for t5 models input["max_new_tokens"] = max_tokens # for llama2 models ## LOGGING - logging.pre_call(input=prompt, api_key=replicate_key, additional_args={"complete_input_dict": input, "max_tokens": max_tokens}) + logging.pre_call( + input=prompt, + api_key=replicate_key, + additional_args={ + "complete_input_dict": input, + "max_tokens": max_tokens, + }, + ) ## COMPLETION CALL output = replicate.run(model, input=input) if "stream" in optional_params and optional_params["stream"] == True: @@ -283,7 +344,15 @@ def completion( response += item completion_response = response ## LOGGING - logging.post_call(input=prompt, api_key=replicate_key, original_response=completion_response, additional_args={"complete_input_dict": input, "max_tokens": max_tokens}) + logging.post_call( + input=prompt, + api_key=replicate_key, + original_response=completion_response, + additional_args={ + "complete_input_dict": input, + "max_tokens": max_tokens, + }, + ) ## USAGE prompt_tokens = len(encoding.encode(prompt)) completion_tokens = len(encoding.encode(completion_response)) @@ -305,7 +374,7 @@ def completion( encoding=encoding, default_max_tokens_to_sample=litellm.max_tokens, api_key=anthropic_key, - logging_obj = logging # model call logging done inside the class as we make need to modify I/O to fit anthropic's requirements + logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit anthropic's requirements ) model_response = anthropic_client.completion( model=model, @@ -369,7 +438,9 @@ def completion( **optional_params, ) ## LOGGING - logging.post_call(input=messages, api_key=openai.api_key, original_response=response) + logging.post_call( + input=messages, api_key=openai.api_key, original_response=response + ) elif model in litellm.cohere_models: # import cohere/if it fails then pip install cohere install_and_import("cohere") @@ -392,7 +463,9 @@ def completion( response = CustomStreamWrapper(response, model) return response ## LOGGING - logging.post_call(input=prompt, api_key=cohere_key, original_response=response) + logging.post_call( + input=prompt, api_key=cohere_key, original_response=response + ) ## USAGE completion_response = response[0].text prompt_tokens = len(encoding.encode(prompt)) @@ -475,7 +548,9 @@ def completion( headers=headers, ) ## LOGGING - logging.post_call(input=prompt, api_key=TOGETHER_AI_TOKEN, original_response=res.text) + logging.post_call( + input=prompt, api_key=TOGETHER_AI_TOKEN, original_response=res.text + ) # make this safe for reading, if output does not exist raise an error json_response = res.json() if "output" not in json_response: @@ -516,7 +591,9 @@ def completion( completion_response = chat.send_message(prompt, **optional_params) ## LOGGING - logging.post_call(input=prompt, api_key=None, original_response=completion_response) + logging.post_call( + input=prompt, api_key=None, original_response=completion_response + ) ## RESPONSE OBJECT 
model_response["choices"][0]["message"]["content"] = completion_response @@ -541,7 +618,9 @@ def completion( completion_response = vertex_model.predict(prompt, **optional_params) ## LOGGING - logging.post_call(input=prompt, api_key=None, original_response=completion_response) + logging.post_call( + input=prompt, api_key=None, original_response=completion_response + ) ## RESPONSE OBJECT model_response["choices"][0]["message"]["content"] = completion_response model_response["created"] = time.time() @@ -564,7 +643,11 @@ def completion( completion_response = ai21_response["completions"][0]["data"]["text"] ## LOGGING - logging.post_call(input=prompt, api_key=ai21.api_key, original_response=completion_response) + logging.post_call( + input=prompt, + api_key=ai21.api_key, + original_response=completion_response, + ) ## RESPONSE OBJECT model_response["choices"][0]["message"]["content"] = completion_response @@ -578,7 +661,9 @@ def completion( prompt = " ".join([message["content"] for message in messages]) ## LOGGING - logging.pre_call(input=prompt, api_key=None, additional_args={"endpoint": endpoint}) + logging.pre_call( + input=prompt, api_key=None, additional_args={"endpoint": endpoint} + ) generator = get_ollama_response_stream(endpoint, model, prompt) # assume all responses are streamed @@ -605,7 +690,11 @@ def completion( completion_response = completion_response["generated_text"] ## LOGGING - logging.post_call(input=prompt, api_key=base_ten_key, original_response=completion_response) + logging.post_call( + input=prompt, + api_key=base_ten_key, + original_response=completion_response, + ) ## RESPONSE OBJECT model_response["choices"][0]["message"]["content"] = completion_response @@ -622,13 +711,22 @@ def completion( prompt = " ".join([message["content"] for message in messages]) ## LOGGING - logging.pre_call(input=prompt, api_key=None, additional_args={"url": url, "max_new_tokens": 100}) + logging.pre_call( + input=prompt, + api_key=None, + additional_args={"url": url, "max_new_tokens": 100}, + ) response = requests.post( url, data={"inputs": prompt, "max_new_tokens": 100, "model": model} ) ## LOGGING - logging.post_call(input=prompt, api_key=None, original_response=response.text, additional_args={"url": url, "max_new_tokens": 100}) + logging.post_call( + input=prompt, + api_key=None, + original_response=response.text, + additional_args={"url": url, "max_new_tokens": 100}, + ) completion_response = response.json()["outputs"] @@ -676,10 +774,22 @@ def batch_completion(*args, **kwargs): @timeout( # type: ignore 60 ) ## set timeouts, in case calls hang (e.g. 
Azure) - default is 60s, override with `force_timeout` -def embedding(model, input=[], azure=False, force_timeout=60, litellm_call_id=None, logger_fn=None): +def embedding( + model, input=[], azure=False, force_timeout=60, litellm_call_id=None, logger_fn=None +): try: response = None - logging = Logging(model=model, messages=input, optional_params={}, litellm_params={"azure": azure, "force_timeout": force_timeout, "logger_fn": logger_fn, "litellm_call_id": litellm_call_id}) + logging = Logging( + model=model, + messages=input, + optional_params={}, + litellm_params={ + "azure": azure, + "force_timeout": force_timeout, + "logger_fn": logger_fn, + "litellm_call_id": litellm_call_id, + }, + ) if azure == True: # azure configs openai.api_type = "azure" @@ -687,7 +797,15 @@ def embedding(model, input=[], azure=False, force_timeout=60, litellm_call_id=No openai.api_version = get_secret("AZURE_API_VERSION") openai.api_key = get_secret("AZURE_API_KEY") ## LOGGING - logging.pre_call(input=input, api_key=openai.api_key, additional_args={"api_type": openai.api_type, "api_base": openai.api_base, "api_version": openai.api_version}) + logging.pre_call( + input=input, + api_key=openai.api_key, + additional_args={ + "api_type": openai.api_type, + "api_base": openai.api_base, + "api_version": openai.api_version, + }, + ) ## EMBEDDING CALL response = openai.Embedding.create(input=input, engine=model) print_verbose(f"response_value: {str(response)[:50]}") @@ -697,7 +815,15 @@ def embedding(model, input=[], azure=False, force_timeout=60, litellm_call_id=No openai.api_version = None openai.api_key = get_secret("OPENAI_API_KEY") ## LOGGING - logging.pre_call(input=input, api_key=openai.api_key, additional_args={"api_type": openai.api_type, "api_base": openai.api_base, "api_version": openai.api_version}) + logging.pre_call( + input=input, + api_key=openai.api_key, + additional_args={ + "api_type": openai.api_type, + "api_base": openai.api_base, + "api_version": openai.api_version, + }, + ) ## EMBEDDING CALL response = openai.Embedding.create(input=input, model=model) print_verbose(f"response_value: {str(response)[:50]}") @@ -710,7 +836,11 @@ def embedding(model, input=[], azure=False, force_timeout=60, litellm_call_id=No ## LOGGING logging.post_call(input=input, api_key=openai.api_key, original_response=e) ## Map to OpenAI Exception - raise exception_type(model=model, original_exception=e, custom_llm_provider="azure" if azure==True else None) + raise exception_type( + model=model, + original_exception=e, + custom_llm_provider="azure" if azure == True else None, + ) ####### HELPER FUNCTIONS ################ diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index 87ae02b58f..8365937f6f 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -34,7 +34,6 @@ def test_caching(): pytest.fail(f"Error occurred: {e}") - def test_caching_with_models(): litellm.caching_with_models = True response2 = completion(model="gpt-3.5-turbo", messages=messages) @@ -47,6 +46,3 @@ def test_caching_with_models(): print(f"response2: {response2}") print(f"response3: {response3}") pytest.fail(f"Error occurred: {e}") - - - diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 9717269820..6fd1964b8f 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -28,15 +28,19 @@ def logger_fn(user_model_dict): def test_completion_custom_provider_model_name(): try: response = completion( - 
model="together_ai/togethercomputer/llama-2-70b-chat", messages=messages, logger_fn=logger_fn + model="together_ai/togethercomputer/llama-2-70b-chat", + messages=messages, + logger_fn=logger_fn, ) # Add any assertions here to check the response print(response) except Exception as e: pytest.fail(f"Error occurred: {e}") + test_completion_custom_provider_model_name() + def test_completion_claude(): try: response = completion( @@ -89,7 +93,10 @@ def test_completion_claude_stream(): def test_completion_cohere(): try: response = completion( - model="command-nightly", messages=messages, max_tokens=100, logit_bias={40: 10} + model="command-nightly", + messages=messages, + max_tokens=100, + logit_bias={40: 10}, ) # Add any assertions here to check the response print(response) @@ -103,6 +110,7 @@ def test_completion_cohere(): except Exception as e: pytest.fail(f"Error occurred: {e}") + def test_completion_cohere_stream(): try: messages = [ diff --git a/litellm/utils.py b/litellm/utils.py index 4fbbb73eaf..1684bb7579 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -139,6 +139,7 @@ def install_and_import(package: str): # Logging function -> log the exact model details + what's being sent | Non-Blocking class Logging: global supabaseClient, liteDebuggerClient + def __init__(self, model, messages, optional_params, litellm_params): self.model = model self.messages = messages @@ -146,19 +147,19 @@ class Logging: self.litellm_params = litellm_params self.logger_fn = litellm_params["logger_fn"] self.model_call_details = { - "model": model, - "messages": messages, + "model": model, + "messages": messages, "optional_params": self.optional_params, "litellm_params": self.litellm_params, } - + def pre_call(self, input, api_key, additional_args={}): try: print_verbose(f"logging pre call for model: {self.model}") self.model_call_details["input"] = input self.model_call_details["api_key"] = api_key self.model_call_details["additional_args"] = additional_args - + ## User Logging -> if you pass in a custom logging function print_verbose( f"Logging Details: logger_fn - {self.logger_fn} | callable(logger_fn) - {callable(self.logger_fn)}" @@ -173,7 +174,7 @@ class Logging: f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}" ) - ## Input Integration Logging -> If you want to log the fact that an attempt to call the model was made + ## Input Integration Logging -> If you want to log the fact that an attempt to call the model was made for callback in litellm.input_callback: try: if callback == "supabase": @@ -201,11 +202,13 @@ class Logging: print_verbose=print_verbose, ) except Exception as e: - print_verbose(f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while input logging with integrations {traceback.format_exc()}") + print_verbose( + f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while input logging with integrations {traceback.format_exc()}" + ) print_verbose( f"LiteLLM.Logging: is sentry capture exception initialized {capture_exception}" ) - if capture_exception: # log this error to sentry for debugging + if capture_exception: # log this error to sentry for debugging capture_exception(e) except: print_verbose( @@ -214,9 +217,9 @@ class Logging: print_verbose( f"LiteLLM.Logging: is sentry capture exception initialized {capture_exception}" ) - if capture_exception: # log this error to sentry for debugging + if capture_exception: # log this error to sentry for debugging capture_exception(e) - + def post_call(self, input, api_key, 
original_response, additional_args={}): # Do something here try: @@ -224,7 +227,7 @@ class Logging: self.model_call_details["api_key"] = api_key self.model_call_details["original_response"] = original_response self.model_call_details["additional_args"] = additional_args - + ## User Logging -> if you pass in a custom logging function print_verbose( f"Logging Details: logger_fn - {self.logger_fn} | callable(logger_fn) - {callable(self.logger_fn)}" @@ -244,6 +247,7 @@ class Logging: ) pass + def exception_logging( additional_args={}, logger_fn=None, @@ -278,6 +282,7 @@ def exception_logging( # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking def client(original_function): global liteDebuggerClient + def function_setup( *args, **kwargs ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc. @@ -288,10 +293,16 @@ def client(original_function): litellm.success_callback.append("lite_debugger") litellm.failure_callback.append("lite_debugger") if ( - len(litellm.input_callback) > 0 or len(litellm.success_callback) > 0 or len(litellm.failure_callback) > 0 + len(litellm.input_callback) > 0 + or len(litellm.success_callback) > 0 + or len(litellm.failure_callback) > 0 ) and len(callback_list) == 0: callback_list = list( - set(litellm.input_callback + litellm.success_callback + litellm.failure_callback) + set( + litellm.input_callback + + litellm.success_callback + + litellm.failure_callback + ) ) set_callbacks( callback_list=callback_list, @@ -413,7 +424,9 @@ def client(original_function): ) # don't interrupt execution of main thread my_thread.start() if hasattr(e, "message"): - if liteDebuggerClient and liteDebuggerClient.dashboard_url != None: # make it easy to get to the debugger logs if you've initialized it + if ( + liteDebuggerClient and liteDebuggerClient.dashboard_url != None + ): # make it easy to get to the debugger logs if you've initialized it e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}" raise e @@ -497,7 +510,7 @@ def get_litellm_params( "verbose": verbose, "custom_llm_provider": custom_llm_provider, "custom_api_base": custom_api_base, - "litellm_call_id": litellm_call_id + "litellm_call_id": litellm_call_id, } return litellm_params @@ -1052,14 +1065,18 @@ def prompt_token_calculator(model, messages): def valid_model(model): try: - # for a given model name, check if the user has the right permissions to access the model - if model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_text_completion_models: + # for a given model name, check if the user has the right permissions to access the model + if ( + model in litellm.open_ai_chat_completion_models + or model in litellm.open_ai_text_completion_models + ): openai.Model.retrieve(model) else: messages = [{"role": "user", "content": "Hello World"}] litellm.completion(model=model, messages=messages) except: - raise InvalidRequestError(message="", model=model, llm_provider="") + raise InvalidRequestError(message="", model=model, llm_provider="") + # integration helper function def modify_integration(integration_name, integration_params): @@ -1410,7 +1427,7 @@ async def stream_to_string(generator): return response -########## Together AI streaming ############################# [TODO] move together ai to it's own llm class +########## Together AI streaming ############################# [TODO] move together ai to it's own llm class async def together_ai_completion_streaming(json_data, 
headers): session = aiohttp.ClientSession() url = "https://api.together.xyz/inference"