Formatting improvements

ishaan-jaff 2023-08-28 09:20:50 -07:00
parent 70b323e0f5
commit b713acb0a4
17 changed files with 464 additions and 323 deletions

View file

@ -5,8 +5,12 @@ input_callback: List[str] = []
success_callback: List[str] = []
failure_callback: List[str] = []
set_verbose = False
email: Optional[str] = None # for hosted dashboard. Learn more - https://docs.litellm.ai/docs/debugging/hosted_debugging
token: Optional[str] = None # for hosted dashboard. Learn more - https://docs.litellm.ai/docs/debugging/hosted_debugging
email: Optional[
str
] = None # for hosted dashboard. Learn more - https://docs.litellm.ai/docs/debugging/hosted_debugging
token: Optional[
str
] = None # for hosted dashboard. Learn more - https://docs.litellm.ai/docs/debugging/hosted_debugging
telemetry = True
max_tokens = 256 # OpenAI Defaults
retry = True
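A minimal sketch of how the module-level settings above are used (the email address is hypothetical, shown only for illustration):

import litellm

litellm.email = "you@example.com"  # opts into the hosted dashboard described in the docs link above
litellm.set_verbose = True         # print debug output
litellm.max_tokens = 256           # the OpenAI default shown above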

View file

@ -1,4 +1,5 @@
import importlib_metadata
try:
version = importlib_metadata.version("litellm")
except:

View file

@ -1,20 +1,21 @@
###### LiteLLM Integration with GPT Cache #########
import gptcache
# openai.ChatCompletion._llm_handler = litellm.completion
from gptcache.adapter import openai
import litellm
class LiteLLMChatCompletion(gptcache.adapter.openai.ChatCompletion):
@classmethod
def _llm_handler(cls, *llm_args, **llm_kwargs):
return litellm.completion(*llm_args, **llm_kwargs)
completion = LiteLLMChatCompletion.create
###### End of LiteLLM Integration with GPT Cache #########
# ####### Example usage ###############
# from gptcache import cache
# completion = LiteLLMChatCompletion.create
@ -23,9 +24,3 @@ completion = LiteLLMChatCompletion.create
# cache.set_openai_key()
# result = completion(model="claude-2", messages=[{"role": "user", "content": "cto of litellm"}])
# print(result)
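A runnable sketch of the integration above, mirroring the commented example (assumes gptcache is installed and the relevant provider API keys are set in the environment):

from gptcache import cache
from litellm.cache import completion  # LiteLLMChatCompletion.create

cache.init()            # initialize the local GPT Cache store
cache.set_openai_key()  # GPT Cache reads OPENAI_API_KEY internally

messages = [{"role": "user", "content": "cto of litellm"}]
result1 = completion(model="claude-2", messages=messages)
result2 = completion(model="claude-2", messages=messages)  # identical prompt, served from the cache
print(result1)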

View file

@ -1,5 +1,6 @@
import requests, traceback, json, os
class LiteDebugger:
user_email = None
dashboard_url = None
@ -15,7 +16,9 @@ class LiteDebugger:
self.user_email = os.getenv("LITELLM_EMAIL") or email
self.dashboard_url = "https://admin.litellm.ai/" + self.user_email
try:
print(f"\033[92mHere's your LiteLLM Dashboard 👉 \033[94m\033[4m{self.dashboard_url}\033[0m")
print(
f"\033[92mHere's your LiteLLM Dashboard 👉 \033[94m\033[4m{self.dashboard_url}\033[0m"
)
except:
print(f"Here's your LiteLLM Dashboard 👉 {self.dashboard_url}")
if self.user_email == None:
@ -28,17 +31,25 @@ class LiteDebugger:
)
def input_log_event(
self, model, messages, end_user, litellm_call_id, print_verbose, litellm_params, optional_params
self,
model,
messages,
end_user,
litellm_call_id,
print_verbose,
litellm_params,
optional_params,
):
try:
print_verbose(
f"LiteLLMDebugger: Logging - Enters input logging function for model {model}"
)
def remove_key_value(dictionary, key):
new_dict = dictionary.copy() # Create a copy of the original dictionary
new_dict.pop(key) # Remove the specified key-value pair from the copy
return new_dict
updated_litellm_params = remove_key_value(litellm_params, "logger_fn")
litellm_data_obj = {
@ -49,7 +60,7 @@ class LiteDebugger:
"litellm_call_id": litellm_call_id,
"user_email": self.user_email,
"litellm_params": updated_litellm_params,
"optional_params": optional_params
"optional_params": optional_params,
}
print_verbose(
f"LiteLLMDebugger: Logging - logged data obj {litellm_data_obj}"
@ -65,10 +76,8 @@ class LiteDebugger:
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
)
pass
def post_call_log_event(
self, original_response, litellm_call_id, print_verbose
):
def post_call_log_event(self, original_response, litellm_call_id, print_verbose):
try:
litellm_data_obj = {
"status": "received",
@ -110,7 +119,7 @@ class LiteDebugger:
"model": response_obj["model"],
"total_cost": total_cost,
"messages": messages,
"response": response['choices'][0]['message']['content'],
"response": response["choices"][0]["message"]["content"],
"end_user": end_user,
"litellm_call_id": litellm_call_id,
"status": "success",
@ -124,7 +133,12 @@ class LiteDebugger:
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
elif "data" in response_obj and isinstance(response_obj["data"], list) and len(response_obj["data"]) > 0 and "embedding" in response_obj["data"][0]:
elif (
"data" in response_obj
and isinstance(response_obj["data"], list)
and len(response_obj["data"]) > 0
and "embedding" in response_obj["data"][0]
):
print(f"messages: {messages}")
litellm_data_obj = {
"response_time": response_time,
@ -145,7 +159,10 @@ class LiteDebugger:
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
elif isinstance(response_obj, object) and response_obj.__class__.__name__ == "CustomStreamWrapper":
elif (
isinstance(response_obj, object)
and response_obj.__class__.__name__ == "CustomStreamWrapper"
):
litellm_data_obj = {
"response_time": response_time,
"total_cost": total_cost,

View file

@ -12,20 +12,17 @@ dotenv.load_dotenv() # Loading env variables using dotenv
# convert to {completion: xx, tokens: xx}
def parse_usage(usage):
return {
"completion":
usage["completion_tokens"] if "completion_tokens" in usage else 0,
"prompt":
usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
"completion": usage["completion_tokens"] if "completion_tokens" in usage else 0,
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
}
def parse_messages(input):
if input is None:
return None
def clean_message(message):
#if is string, return as is
# if is string, return as is
if isinstance(message, str):
return message
@ -50,75 +47,78 @@ class LLMonitorLogger:
# Class variables or attributes
def __init__(self):
# Instance variables
self.api_url = os.getenv(
"LLMONITOR_API_URL") or "https://app.llmonitor.com"
self.api_url = os.getenv("LLMONITOR_API_URL") or "https://app.llmonitor.com"
self.app_id = os.getenv("LLMONITOR_APP_ID")
def log_event(
self,
type,
event,
run_id,
model,
print_verbose,
input=None,
user_id=None,
response_obj=None,
start_time=datetime.datetime.now(),
end_time=datetime.datetime.now(),
error=None,
self,
type,
event,
run_id,
model,
print_verbose,
input=None,
user_id=None,
response_obj=None,
start_time=datetime.datetime.now(),
end_time=datetime.datetime.now(),
error=None,
):
# Method definition
try:
print_verbose(
f"LLMonitor Logging - Logging request for model {model}")
print_verbose(f"LLMonitor Logging - Logging request for model {model}")
if response_obj:
usage = parse_usage(
response_obj['usage']) if 'usage' in response_obj else None
output = response_obj[
'choices'] if 'choices' in response_obj else None
usage = (
parse_usage(response_obj["usage"])
if "usage" in response_obj
else None
)
output = response_obj["choices"] if "choices" in response_obj else None
else:
usage = None
output = None
if error:
error_obj = {'stack': error}
error_obj = {"stack": error}
else:
error_obj = None
data = [{
"type": type,
"name": model,
"runId": run_id,
"app": self.app_id,
'event': 'start',
"timestamp": start_time.isoformat(),
"userId": user_id,
"input": parse_messages(input),
}, {
"type": type,
"runId": run_id,
"app": self.app_id,
"event": event,
"error": error_obj,
"timestamp": end_time.isoformat(),
"userId": user_id,
"output": parse_messages(output),
"tokensUsage": usage,
}]
data = [
{
"type": type,
"name": model,
"runId": run_id,
"app": self.app_id,
"event": "start",
"timestamp": start_time.isoformat(),
"userId": user_id,
"input": parse_messages(input),
},
{
"type": type,
"runId": run_id,
"app": self.app_id,
"event": event,
"error": error_obj,
"timestamp": end_time.isoformat(),
"userId": user_id,
"output": parse_messages(output),
"tokensUsage": usage,
},
]
# print_verbose(f"LLMonitor Logging - final data object: {data}")
response = requests.post(
self.api_url + '/api/report',
headers={'Content-Type': 'application/json'},
json={'events': data})
self.api_url + "/api/report",
headers={"Content-Type": "application/json"},
json={"events": data},
)
print_verbose(f"LLMonitor Logging - response: {response}")
except:
# traceback.print_exc()
print_verbose(
f"LLMonitor Logging Error - {traceback.format_exc()}")
print_verbose(f"LLMonitor Logging Error - {traceback.format_exc()}")
pass
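A quick check of parse_usage above: it maps an OpenAI-style usage block onto the {completion, prompt} shape LLMonitor expects, defaulting missing counts to 0.

print(parse_usage({"prompt_tokens": 12, "completion_tokens": 34}))
# {'completion': 34, 'prompt': 12}
print(parse_usage({}))
# {'completion': 0, 'prompt': 0}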

View file

@ -7,6 +7,7 @@ import requests
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
class PromptLayerLogger:
# Class variables or attributes
def __init__(self):
@ -26,7 +27,9 @@ class PromptLayerLogger:
"function_name": "openai.ChatCompletion.create",
"kwargs": kwargs,
"tags": ["hello", "world"],
"request_response": dict(response_obj), # TODO: Check if we need a dict
"request_response": dict(
response_obj
), # TODO: Check if we need a dict
"request_start_time": int(start_time.timestamp()),
"request_end_time": int(end_time.timestamp()),
"api_key": self.key,
@ -34,11 +37,12 @@ class PromptLayerLogger:
# "prompt_id": "<PROMPT ID>",
# "prompt_input_variables": "<Dictionary of variables for prompt>",
# "prompt_version":1,
},
)
print_verbose(f"Prompt Layer Logging - final response object: {request_response}")
print_verbose(
f"Prompt Layer Logging - final response object: {request_response}"
)
except:
# traceback.print_exc()
print_verbose(f"Prompt Layer Error - {traceback.format_exc()}")

View file

@ -94,7 +94,10 @@ class AnthropicLLM:
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
response = requests.post(
self.completion_url, headers=self.headers, data=json.dumps(data), stream=optional_params["stream"]
self.completion_url,
headers=self.headers,
data=json.dumps(data),
stream=optional_params["stream"],
)
return response.iter_lines()
else:
@ -142,4 +145,3 @@ class AnthropicLLM:
self,
): # logic for parsing in - calling - parsing out model embedding calls
pass

View file

@ -5,6 +5,7 @@ import time
from typing import Callable
from litellm.utils import ModelResponse
class BasetenError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@ -15,9 +16,7 @@ class BasetenError(Exception):
class BasetenLLM:
def __init__(
self, encoding, logging_obj, api_key=None
):
def __init__(self, encoding, logging_obj, api_key=None):
self.encoding = encoding
self.completion_url_fragment_1 = "https://app.baseten.co/models/"
self.completion_url_fragment_2 = "/predict"
@ -55,13 +54,9 @@ class BasetenLLM:
for message in messages:
if "role" in message:
if message["role"] == "user":
prompt += (
f"{message['content']}"
)
prompt += f"{message['content']}"
else:
prompt += (
f"{message['content']}"
)
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
data = {
@ -78,7 +73,9 @@ class BasetenLLM:
)
## COMPLETION CALL
response = requests.post(
self.completion_url_fragment_1 + model + self.completion_url_fragment_2, headers=self.headers, data=json.dumps(data)
self.completion_url_fragment_1 + model + self.completion_url_fragment_2,
headers=self.headers,
data=json.dumps(data),
)
if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines()
@ -100,19 +97,33 @@ class BasetenLLM:
)
else:
if "model_output" in completion_response:
if isinstance(completion_response["model_output"], dict) and "data" in completion_response["model_output"] and isinstance(completion_response["model_output"]["data"], list):
model_response["choices"][0]["message"]["content"] = completion_response["model_output"]["data"][0]
if (
isinstance(completion_response["model_output"], dict)
and "data" in completion_response["model_output"]
and isinstance(
completion_response["model_output"]["data"], list
)
):
model_response["choices"][0]["message"][
"content"
] = completion_response["model_output"]["data"][0]
elif isinstance(completion_response["model_output"], str):
model_response["choices"][0]["message"]["content"] = completion_response["model_output"]
elif "completion" in completion_response and isinstance(completion_response["completion"], str):
model_response["choices"][0]["message"]["content"] = completion_response["completion"]
model_response["choices"][0]["message"][
"content"
] = completion_response["model_output"]
elif "completion" in completion_response and isinstance(
completion_response["completion"], str
):
model_response["choices"][0]["message"][
"content"
] = completion_response["completion"]
else:
raise ValueError(f"Unable to parse response. Original response: {response.text}")
raise ValueError(
f"Unable to parse response. Original response: {response.text}"
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(
self.encoding.encode(prompt)
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(self.encoding.encode(prompt))
completion_tokens = len(
self.encoding.encode(model_response["choices"][0]["message"]["content"])
)
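A standalone sketch of the parsing branches above, for reference only (the response shapes are assumptions taken from this code, not documented Baseten formats):

def parse_baseten_output(completion_response: dict) -> str:
    out = completion_response.get("model_output")
    if isinstance(out, dict) and isinstance(out.get("data"), list):
        return out["data"][0]                      # {"model_output": {"data": ["text", ...]}}
    if isinstance(out, str):
        return out                                 # {"model_output": "text"}
    if isinstance(completion_response.get("completion"), str):
        return completion_response["completion"]   # {"completion": "text"}
    raise ValueError("Unable to parse response")

assert parse_baseten_output({"model_output": {"data": ["hi"]}}) == "hi"
assert parse_baseten_output({"completion": "hi"}) == "hi"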

View file

@ -103,7 +103,9 @@ def completion(
return completion_with_fallbacks(**args)
if litellm.model_alias_map and model in litellm.model_alias_map:
args["model_alias_map"] = litellm.model_alias_map
model = litellm.model_alias_map[model] # update the model to the actual value if an alias has been passed in
model = litellm.model_alias_map[
model
] # update the model to the actual value if an alias has been passed in
model_response = ModelResponse()
if azure: # this flag is deprecated, remove once notebooks are also updated.
custom_llm_provider = "azure"
@ -146,7 +148,7 @@ def completion(
custom_llm_provider=custom_llm_provider,
custom_api_base=custom_api_base,
litellm_call_id=litellm_call_id,
model_alias_map=litellm.model_alias_map
model_alias_map=litellm.model_alias_map,
)
logging = Logging(
model=model,
@ -216,7 +218,10 @@ def completion(
# note: if a user sets a custom base - we should ensure this works
# allow for the setting of dynamic and stateful api-bases
api_base = (
custom_api_base or litellm.api_base or get_secret("OPENAI_API_BASE") or "https://api.openai.com/v1"
custom_api_base
or litellm.api_base
or get_secret("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
openai.api_base = api_base
openai.api_version = None
@ -255,9 +260,11 @@ def completion(
original_response=response,
additional_args={"headers": litellm.headers},
)
elif (model in litellm.open_ai_text_completion_models or
"ft:babbage-002" in model or # support for finetuned completion models
"ft:davinci-002" in model):
elif (
model in litellm.open_ai_text_completion_models
or "ft:babbage-002" in model
or "ft:davinci-002" in model # support for finetuned completion models
):
openai.api_type = "openai"
openai.api_base = (
litellm.api_base
@ -544,7 +551,10 @@ def completion(
logging.pre_call(input=prompt, api_key=TOGETHER_AI_TOKEN)
print(f"TOGETHER_AI_TOKEN: {TOGETHER_AI_TOKEN}")
if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
if (
"stream_tokens" in optional_params
and optional_params["stream_tokens"] == True
):
res = requests.post(
endpoint,
json={
@ -698,9 +708,7 @@ def completion(
):
custom_llm_provider = "baseten"
baseten_key = (
api_key
or litellm.baseten_key
or os.environ.get("BASETEN_API_KEY")
api_key or litellm.baseten_key or os.environ.get("BASETEN_API_KEY")
)
baseten_client = BasetenLLM(
encoding=encoding, api_key=baseten_key, logging_obj=logging
@ -767,11 +775,14 @@ def completion(
model=model, custom_llm_provider=custom_llm_provider, original_exception=e
)
def completion_with_retries(*args, **kwargs):
import tenacity
retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(3), reraise=True)
return retryer(completion, *args, **kwargs)
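A usage sketch for completion_with_retries above, assuming it is exported at the package level like completion; tenacity must be installed, the call is retried up to three times, and the last error is re-raised:

from litellm import completion_with_retries

response = completion_with_retries(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
)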
def batch_completion(*args, **kwargs):
batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
completions = []
@ -865,14 +876,16 @@ def embedding(
custom_llm_provider="azure" if azure == True else None,
)
###### Text Completion ################
def text_completion(*args, **kwargs):
if 'prompt' in kwargs:
messages = [{'role': 'system', 'content': kwargs['prompt']}]
kwargs['messages'] = messages
kwargs.pop('prompt')
if "prompt" in kwargs:
messages = [{"role": "system", "content": kwargs["prompt"]}]
kwargs["messages"] = messages
kwargs.pop("prompt")
return completion(*args, **kwargs)
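A usage sketch for text_completion above, matching the test further down: the prompt string is wrapped into a single "system" message and forwarded to completion().

from litellm import text_completion

response = text_completion(model="gpt-3.5-turbo", prompt="What's the weather in SF?")
print(response["choices"][0]["message"]["content"])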
####### HELPER FUNCTIONS ################
## Set verbose to true -> ```litellm.set_verbose = True```
def print_verbose(print_statement):

View file

@ -14,6 +14,7 @@ from litellm import embedding, completion
messages = [{"role": "user", "content": "who is ishaan Github? "}]
# test if response cached
def test_caching():
try:
@ -50,14 +51,16 @@ def test_caching_with_models():
print(f"response1: {response1}")
print(f"response2: {response2}")
pytest.fail(f"Error occurred:")
# test_caching_with_models()
# test_caching_with_models()
def test_gpt_cache():
# INIT GPT Cache #
from gptcache import cache
from litellm.cache import completion
cache.init()
cache.set_openai_key()
@ -67,10 +70,11 @@ def test_gpt_cache():
print(f"response2: {response2}")
print(f"response3: {response3}")
if response3['choices'] != response2['choices']:
if response3["choices"] != response2["choices"]:
# if models are different, it should not return cached response
print(f"response2: {response2}")
print(f"response3: {response3}")
pytest.fail(f"Error occurred:")
# test_gpt_cache()
# test_gpt_cache()
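A minimal sketch of the flag these caching tests exercise: enabling litellm.caching (or litellm.caching_with_models, which keys the cache on prompt + model name) makes a repeated prompt return the locally cached result.

import litellm
from litellm import completion

litellm.caching = True
messages = [{"role": "user", "content": "Tell me a joke."}]
response1 = completion(model="gpt-3.5-turbo", messages=messages)
response2 = completion(model="gpt-3.5-turbo", messages=messages)  # identical prompt, returned from cache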

View file

@ -142,9 +142,12 @@ def test_completion_openai():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_completion_openai_prompt():
try:
response = text_completion(model="gpt-3.5-turbo", prompt="What's the weather in SF?")
response = text_completion(
model="gpt-3.5-turbo", prompt="What's the weather in SF?"
)
response_str = response["choices"][0]["message"]["content"]
response_str_2 = response.choices[0].message.content
print(response)
@ -154,6 +157,7 @@ def test_completion_openai_prompt():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_completion_text_openai():
try:
response = completion(model="text-davinci-003", messages=messages)

View file

@ -27,7 +27,7 @@
# # print(f"user_model_dict: {user_model_dict}")
# pass
# # normal call
# # normal call
# def test_completion_custom_provider_model_name():
# try:
# response = completion_with_retries(
@ -41,7 +41,7 @@
# pytest.fail(f"Error occurred: {e}")
# # bad call
# # bad call
# # def test_completion_custom_provider_model_name():
# # try:
# # response = completion_with_retries(
@ -54,7 +54,7 @@
# # except Exception as e:
# # pytest.fail(f"Error occurred: {e}")
# # impact on exception mapping
# # impact on exception mapping
# def test_context_window():
# sample_text = "how does a court case get to the Supreme Court?" * 5000
# messages = [{"content": sample_text, "role": "user"}]
@ -83,4 +83,4 @@
# test_context_window()
# test_completion_custom_provider_model_name()
# test_completion_custom_provider_model_name()

View file

@ -22,4 +22,6 @@ def test_openai_embedding():
# print(f"response: {str(response)}")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
test_openai_embedding()
test_openai_embedding()

View file

@ -4,7 +4,7 @@
import sys
import os
sys.path.insert(0, os.path.abspath('../..'))
sys.path.insert(0, os.path.abspath("../.."))
from litellm import completion, embedding
import litellm
@ -17,11 +17,10 @@ litellm.set_verbose = True
def test_chat_openai():
try:
response = completion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}])
response = completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)
print(response)
@ -31,7 +30,7 @@ def test_chat_openai():
def test_embedding_openai():
try:
response = embedding(model="text-embedding-ada-002", input=['test'])
response = embedding(model="text-embedding-ada-002", input=["test"])
# Add any assertions here to check the response
print(f"response: {str(response)[:50]}")
except Exception as e:
@ -39,4 +38,4 @@ def test_embedding_openai():
test_chat_openai()
test_embedding_openai()
test_embedding_openai()

View file

@ -13,5 +13,19 @@ from litellm import embedding, completion
litellm.set_verbose = True
# Test: Check if the alias created via LiteDebugger is mapped correctly
{"top_p": 0.75, "prompt": "What's the meaning of life?", "num_beams": 4, "temperature": 0.1}
print(completion("llama2", messages=[{"role": "user", "content": "Hey, how's it going?"}], top_p=0.1, temperature=0, num_beams=4, max_tokens=60))
{
"top_p": 0.75,
"prompt": "What's the meaning of life?",
"num_beams": 4,
"temperature": 0.1,
}
print(
completion(
"llama2",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
top_p=0.1,
temperature=0,
num_beams=4,
max_tokens=60,
)
)
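A sketch of how the "llama2" alias used above is resolved: completion() consults litellm.model_alias_map (populated locally or via the hosted debugger's get_all_keys) and swaps in the real model id before the call. The mapping target below is hypothetical, for illustration only.

import litellm

litellm.model_alias_map = {"llama2": "replicate/llama-2-70b-chat"}  # hypothetical target id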

View file

@ -3,12 +3,14 @@
import sys, os
import traceback
import time
import time
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm import completion
litellm.logging = False
litellm.set_verbose = False
@ -31,11 +33,11 @@ messages = [{"content": user_message, "role": "user"}]
# complete_response = ""
# start_time = time.time()
# for chunk in response:
# chunk_time = time.time()
# chunk_time = time.time()
# print(f"time since initial request: {chunk_time - start_time:.5f}")
# print(chunk["choices"][0]["delta"])
# complete_response += chunk["choices"][0]["delta"]["content"]
# if complete_response == "":
# if complete_response == "":
# raise Exception("Empty response received")
# except:
# print(f"error occurred: {traceback.format_exc()}")
@ -50,11 +52,11 @@ messages = [{"content": user_message, "role": "user"}]
# response = ""
# start_time = time.time()
# for chunk in response:
# chunk_time = time.time()
# chunk_time = time.time()
# print(f"time since initial request: {chunk_time - start_time:.2f}")
# print(chunk["choices"][0]["delta"])
# response += chunk["choices"][0]["delta"]
# if response == "":
# if response == "":
# raise Exception("Empty response received")
# except:
# print(f"error occurred: {traceback.format_exc()}")
@ -73,7 +75,7 @@ try:
print(f"time since initial request: {chunk_time - start_time:.5f}")
print(chunk["choices"][0]["delta"])
complete_response += chunk["choices"][0]["delta"]["content"]
if complete_response == "":
if complete_response == "":
raise Exception("Empty response received")
except:
print(f"error occurred: {traceback.format_exc()}")
@ -88,11 +90,11 @@ except:
# )
# complete_response = ""
# for chunk in response:
# chunk_time = time.time()
# chunk_time = time.time()
# print(f"time since initial request: {chunk_time - start_time:.2f}")
# print(chunk["choices"][0]["delta"])
# complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else ""
# if complete_response == "":
# complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else ""
# if complete_response == "":
# raise Exception("Empty response received")
# except:
# print(f"error occurred: {traceback.format_exc()}")
@ -102,16 +104,20 @@ except:
try:
start_time = time.time()
response = completion(
model="Replit-Code-3B", messages=messages, logger_fn=logger_fn, stream= True
model="Replit-Code-3B", messages=messages, logger_fn=logger_fn, stream=True
)
complete_response = ""
print(f"returned response object: {response}")
for chunk in response:
chunk_time = time.time()
chunk_time = time.time()
print(f"time since initial request: {chunk_time - start_time:.2f}")
print(chunk["choices"][0]["delta"])
complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else ""
if complete_response == "":
complete_response += (
chunk["choices"][0]["delta"]["content"]
if len(chunk["choices"][0]["delta"].keys()) > 0
else ""
)
if complete_response == "":
raise Exception("Empty response received")
except:
print(f"error occurred: {traceback.format_exc()}")
@ -121,16 +127,23 @@ except:
try:
start_time = time.time()
response = completion(
model="together_ai/bigcode/starcoder", messages=messages, logger_fn=logger_fn, stream= True
model="together_ai/bigcode/starcoder",
messages=messages,
logger_fn=logger_fn,
stream=True,
)
complete_response = ""
print(f"returned response object: {response}")
for chunk in response:
chunk_time = time.time()
complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else ""
chunk_time = time.time()
complete_response += (
chunk["choices"][0]["delta"]["content"]
if len(chunk["choices"][0]["delta"].keys()) > 0
else ""
)
if len(complete_response) > 0:
print(complete_response)
if complete_response == "":
if complete_response == "":
raise Exception("Empty response received")
except:
print(f"error occurred: {traceback.format_exc()}")
@ -144,11 +157,11 @@ except:
# )
# response = ""
# for chunk in response:
# chunk_time = time.time()
# chunk_time = time.time()
# print(f"time since initial request: {chunk_time - start_time:.2f}")
# print(chunk["choices"][0]["delta"])
# response += chunk["choices"][0]["delta"]
# if response == "":
# if response == "":
# raise Exception("Empty response received")
# except:
# print(f"error occurred: {traceback.format_exc()}")
@ -162,11 +175,11 @@ except:
# )
# response = ""
# for chunk in response:
# chunk_time = time.time()
# chunk_time = time.time()
# print(f"time since initial request: {chunk_time - start_time:.2f}")
# print(chunk["choices"][0]["delta"])
# response += chunk["choices"][0]["delta"]
# if response == "":
# if response == "":
# raise Exception("Empty response received")
# except:
# print(f"error occurred: {traceback.format_exc()}")

View file

@ -69,7 +69,6 @@ last_fetched_at_keys = None
class Message(OpenAIObject):
def __init__(self, content="default", role="assistant", **params):
super(Message, self).__init__(**params)
self.content = content
@ -77,12 +76,7 @@ class Message(OpenAIObject):
class Choices(OpenAIObject):
def __init__(self,
finish_reason="stop",
index=0,
message=Message(),
**params):
def __init__(self, finish_reason="stop", index=0, message=Message(), **params):
super(Choices, self).__init__(**params)
self.finish_reason = finish_reason
self.index = index
@ -90,22 +84,20 @@ class Choices(OpenAIObject):
class ModelResponse(OpenAIObject):
def __init__(self,
choices=None,
created=None,
model=None,
usage=None,
**params):
def __init__(self, choices=None, created=None, model=None, usage=None, **params):
super(ModelResponse, self).__init__(**params)
self.choices = choices if choices else [Choices()]
self.created = created
self.model = model
self.usage = (usage if usage else {
"prompt_tokens": None,
"completion_tokens": None,
"total_tokens": None,
})
self.usage = (
usage
if usage
else {
"prompt_tokens": None,
"completion_tokens": None,
"total_tokens": None,
}
)
def to_dict_recursive(self):
d = super().to_dict_recursive()
@ -173,7 +165,9 @@ class Logging:
self.model_call_details["api_key"] = api_key
self.model_call_details["additional_args"] = additional_args
if model: # if model name was changed pre-call, overwrite the initial model call name with the new one
if (
model
): # if model name was changed pre-call, overwrite the initial model call name with the new one
self.model_call_details["model"] = model
# User Logging -> if you pass in a custom logging function
@ -203,8 +197,7 @@ class Logging:
model=model,
messages=messages,
end_user=litellm._thread_context.user,
litellm_call_id=self.
litellm_params["litellm_call_id"],
litellm_call_id=self.litellm_params["litellm_call_id"],
print_verbose=print_verbose,
)
@ -217,8 +210,7 @@ class Logging:
model=model,
messages=messages,
end_user=litellm._thread_context.user,
litellm_call_id=self.
litellm_params["litellm_call_id"],
litellm_call_id=self.litellm_params["litellm_call_id"],
litellm_params=self.model_call_details["litellm_params"],
optional_params=self.model_call_details["optional_params"],
print_verbose=print_verbose,
@ -263,7 +255,7 @@ class Logging:
print_verbose(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
)
# Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
for callback in litellm.input_callback:
try:
@ -274,8 +266,7 @@ class Logging:
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
liteDebuggerClient.post_call_log_event(
original_response=original_response,
litellm_call_id=self.
litellm_params["litellm_call_id"],
litellm_call_id=self.litellm_params["litellm_call_id"],
print_verbose=print_verbose,
)
except:
@ -295,6 +286,7 @@ class Logging:
# Add more methods as needed
def exception_logging(
additional_args={},
logger_fn=None,
@ -329,13 +321,18 @@ def exception_logging(
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def client(original_function):
global liteDebuggerClient, get_all_keys
def function_setup(
*args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
try:
global callback_list, add_breadcrumb, user_logger_fn
if litellm.email is not None or os.getenv("LITELLM_EMAIL", None) is not None or litellm.token is not None or os.getenv("LITELLM_TOKEN", None): # add to input, success and failure callbacks if user is using hosted product
if (
litellm.email is not None
or os.getenv("LITELLM_EMAIL", None) is not None
or litellm.token is not None
or os.getenv("LITELLM_TOKEN", None)
): # add to input, success and failure callbacks if user is using hosted product
get_all_keys()
if "lite_debugger" not in callback_list and litellm.logging:
litellm.input_callback.append("lite_debugger")
@ -381,11 +378,12 @@ def client(original_function):
if litellm.telemetry:
try:
model = args[0] if len(args) > 0 else kwargs["model"]
exception = kwargs[
"exception"] if "exception" in kwargs else None
custom_llm_provider = (kwargs["custom_llm_provider"]
if "custom_llm_provider" in kwargs else
None)
exception = kwargs["exception"] if "exception" in kwargs else None
custom_llm_provider = (
kwargs["custom_llm_provider"]
if "custom_llm_provider" in kwargs
else None
)
safe_crash_reporting(
model=model,
exception=exception,
@ -410,10 +408,10 @@ def client(original_function):
def check_cache(*args, **kwargs):
try: # never block execution
prompt = get_prompt(*args, **kwargs)
if (prompt != None): # check if messages / prompt exists
if prompt != None: # check if messages / prompt exists
if litellm.caching_with_models:
# if caching with model names is enabled, key is prompt + model name
if ("model" in kwargs):
if "model" in kwargs:
cache_key = prompt + kwargs["model"]
if cache_key in local_cache:
return local_cache[cache_key]
@ -423,7 +421,7 @@ def client(original_function):
return result
else:
return None
return None # default to return None
return None # default to return None
except:
return None
@ -431,7 +429,7 @@ def client(original_function):
try: # never block execution
prompt = get_prompt(*args, **kwargs)
if litellm.caching_with_models: # caching with model + prompt
if ("model" in kwargs):
if "model" in kwargs:
cache_key = prompt + kwargs["model"]
local_cache[cache_key] = result
else: # caching based only on prompts
@ -449,7 +447,8 @@ def client(original_function):
start_time = datetime.datetime.now()
# [OPTIONAL] CHECK CACHE
if (litellm.caching or litellm.caching_with_models) and (
cached_result := check_cache(*args, **kwargs)) is not None:
cached_result := check_cache(*args, **kwargs)
) is not None:
result = cached_result
return result
# MODEL CALL
@ -458,25 +457,22 @@ def client(original_function):
return result
end_time = datetime.datetime.now()
# [OPTIONAL] ADD TO CACHE
if (litellm.caching or litellm.caching_with_models):
if litellm.caching or litellm.caching_with_models:
add_cache(result, *args, **kwargs)
# LOG SUCCESS
my_thread = threading.Thread(
target=handle_success,
args=(args, kwargs, result, start_time,
end_time)) # don't interrupt execution of main thread
target=handle_success, args=(args, kwargs, result, start_time, end_time)
) # don't interrupt execution of main thread
my_thread.start()
# RETURN RESULT
return result
except Exception as e:
traceback_exception = traceback.format_exc()
crash_reporting(*args, **kwargs, exception=traceback_exception)
end_time = datetime.datetime.now()
my_thread = threading.Thread(
target=handle_failure,
args=(e, traceback_exception, start_time, end_time, args,
kwargs),
args=(e, traceback_exception, start_time, end_time, args, kwargs),
) # don't interrupt execution of main thread
my_thread.start()
if hasattr(e, "message"):
@ -506,18 +502,18 @@ def token_counter(model, text):
return num_tokens
def cost_per_token(model="gpt-3.5-turbo",
prompt_tokens=0,
completion_tokens=0):
def cost_per_token(model="gpt-3.5-turbo", prompt_tokens=0, completion_tokens=0):
# given
prompt_tokens_cost_usd_dollar = 0
completion_tokens_cost_usd_dollar = 0
model_cost_ref = litellm.model_cost
if model in model_cost_ref:
prompt_tokens_cost_usd_dollar = (
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens)
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
)
completion_tokens_cost_usd_dollar = (
model_cost_ref[model]["output_cost_per_token"] * completion_tokens)
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
else:
# calculate average input cost
@ -538,9 +534,8 @@ def completion_cost(model="gpt-3.5-turbo", prompt="", completion=""):
prompt_tokens = token_counter(model=model, text=prompt)
completion_tokens = token_counter(model=model, text=completion)
prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_token(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens)
model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
)
return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
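A usage sketch for the cost helpers above, importing from litellm.utils where they are defined: completion_cost() tokenizes the prompt and completion with token_counter() and prices them using litellm.model_cost's per-token input and output rates.

from litellm.utils import completion_cost

usd = completion_cost(
    model="gpt-3.5-turbo",
    prompt="What's the weather in SF?",
    completion="I don't have access to real-time data.",
)
print(f"estimated cost: ${usd:.6f}")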
@ -558,7 +553,7 @@ def get_litellm_params(
custom_llm_provider=None,
custom_api_base=None,
litellm_call_id=None,
model_alias_map=None
model_alias_map=None,
):
litellm_params = {
"return_async": return_async,
@ -569,13 +564,13 @@ def get_litellm_params(
"custom_llm_provider": custom_llm_provider,
"custom_api_base": custom_api_base,
"litellm_call_id": litellm_call_id,
"model_alias_map": model_alias_map
"model_alias_map": model_alias_map,
}
return litellm_params
def get_optional_params( # use the openai defaults
def get_optional_params( # use the openai defaults
# 12 optional params
functions=[],
function_call="",
@ -588,7 +583,7 @@ def get_optional_params( # use the openai defaults
presence_penalty=0,
frequency_penalty=0,
logit_bias={},
num_beams=1,
num_beams=1,
user="",
deployment_id=None,
model=None,
@ -635,8 +630,9 @@ def get_optional_params( # use the openai defaults
optional_params["max_tokens"] = max_tokens
if frequency_penalty != 0:
optional_params["frequency_penalty"] = frequency_penalty
elif (model == "chat-bison"
): # chat-bison has diff args from chat-bison@001 ty Google
elif (
model == "chat-bison"
): # chat-bison has diff args from chat-bison@001 ty Google
if temperature != 1:
optional_params["temperature"] = temperature
if top_p != 1:
@ -702,10 +698,7 @@ def load_test_model(
test_prompt = prompt
if num_calls:
test_calls = num_calls
messages = [[{
"role": "user",
"content": test_prompt
}] for _ in range(test_calls)]
messages = [[{"role": "user", "content": test_prompt}] for _ in range(test_calls)]
start_time = time.time()
try:
litellm.batch_completion(
@ -743,15 +736,17 @@ def set_callbacks(callback_list):
try:
import sentry_sdk
except ImportError:
print_verbose(
"Package 'sentry_sdk' is missing. Installing it...")
print_verbose("Package 'sentry_sdk' is missing. Installing it...")
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "sentry_sdk"])
[sys.executable, "-m", "pip", "install", "sentry_sdk"]
)
import sentry_sdk
sentry_sdk_instance = sentry_sdk
sentry_trace_rate = (os.environ.get("SENTRY_API_TRACE_RATE")
if "SENTRY_API_TRACE_RATE" in os.environ
else "1.0")
sentry_trace_rate = (
os.environ.get("SENTRY_API_TRACE_RATE")
if "SENTRY_API_TRACE_RATE" in os.environ
else "1.0"
)
sentry_sdk_instance.init(
dsn=os.environ.get("SENTRY_API_URL"),
traces_sample_rate=float(sentry_trace_rate),
@ -762,10 +757,10 @@ def set_callbacks(callback_list):
try:
from posthog import Posthog
except ImportError:
print_verbose(
"Package 'posthog' is missing. Installing it...")
print_verbose("Package 'posthog' is missing. Installing it...")
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "posthog"])
[sys.executable, "-m", "pip", "install", "posthog"]
)
from posthog import Posthog
posthog = Posthog(
project_api_key=os.environ.get("POSTHOG_API_KEY"),
@ -775,10 +770,10 @@ def set_callbacks(callback_list):
try:
from slack_bolt import App
except ImportError:
print_verbose(
"Package 'slack_bolt' is missing. Installing it...")
print_verbose("Package 'slack_bolt' is missing. Installing it...")
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "slack_bolt"])
[sys.executable, "-m", "pip", "install", "slack_bolt"]
)
from slack_bolt import App
slack_app = App(
token=os.environ.get("SLACK_API_TOKEN"),
@ -809,8 +804,7 @@ def set_callbacks(callback_list):
raise e
def handle_failure(exception, traceback_exception, start_time, end_time, args,
kwargs):
def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs):
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient
try:
# print_verbose(f"handle_failure args: {args}")
@ -820,7 +814,8 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
failure_handler = additional_details.pop("failure_handler", None)
additional_details["Event_Name"] = additional_details.pop(
"failed_event_name", "litellm.failed_query")
"failed_event_name", "litellm.failed_query"
)
print_verbose(f"self.failure_callback: {litellm.failure_callback}")
for callback in litellm.failure_callback:
try:
@ -835,8 +830,9 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
for detail in additional_details:
slack_msg += f"{detail}: {additional_details[detail]}\n"
slack_msg += f"Traceback: {traceback_exception}"
slack_app.client.chat_postMessage(channel=alerts_channel,
text=slack_msg)
slack_app.client.chat_postMessage(
channel=alerts_channel, text=slack_msg
)
elif callback == "sentry":
capture_exception(exception)
elif callback == "posthog":
@ -855,8 +851,9 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
print_verbose(f"ph_obj: {ph_obj}")
print_verbose(f"PostHog Event Name: {event_name}")
if "user_id" in additional_details:
posthog.capture(additional_details["user_id"],
event_name, ph_obj)
posthog.capture(
additional_details["user_id"], event_name, ph_obj
)
else: # PostHog calls require a unique id to identify a user - https://posthog.com/docs/libraries/python
unique_id = str(uuid.uuid4())
posthog.capture(unique_id, event_name)
@ -870,10 +867,10 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
"created": time.time(),
"error": traceback_exception,
"usage": {
"prompt_tokens":
prompt_token_calculator(model, messages=messages),
"completion_tokens":
0,
"prompt_tokens": prompt_token_calculator(
model, messages=messages
),
"completion_tokens": 0,
},
}
berrispendLogger.log_event(
@ -892,10 +889,10 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
"model": model,
"created": time.time(),
"usage": {
"prompt_tokens":
prompt_token_calculator(model, messages=messages),
"completion_tokens":
0,
"prompt_tokens": prompt_token_calculator(
model, messages=messages
),
"completion_tokens": 0,
},
}
aispendLogger.log_event(
@ -910,10 +907,13 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
model = args[0] if len(args) > 0 else kwargs["model"]
input = args[1] if len(args) > 1 else kwargs.get(
"messages", kwargs.get("input", None))
input = (
args[1]
if len(args) > 1
else kwargs.get("messages", kwargs.get("input", None))
)
type = 'embed' if 'input' in kwargs else 'llm'
type = "embed" if "input" in kwargs else "llm"
llmonitorLogger.log_event(
type=type,
@ -937,10 +937,10 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
"created": time.time(),
"error": traceback_exception,
"usage": {
"prompt_tokens":
prompt_token_calculator(model, messages=messages),
"completion_tokens":
0,
"prompt_tokens": prompt_token_calculator(
model, messages=messages
),
"completion_tokens": 0,
},
}
supabaseClient.log_event(
@ -957,16 +957,28 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
print_verbose("reaches lite_debugger for logging!")
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
model = args[0] if len(args) > 0 else kwargs["model"]
messages = args[1] if len(args) > 1 else kwargs.get("messages", [{"role": "user", "content": ' '.join(kwargs.get("input", ""))}])
messages = (
args[1]
if len(args) > 1
else kwargs.get(
"messages",
[
{
"role": "user",
"content": " ".join(kwargs.get("input", "")),
}
],
)
)
result = {
"model": model,
"created": time.time(),
"error": traceback_exception,
"usage": {
"prompt_tokens":
prompt_token_calculator(model, messages=messages),
"completion_tokens":
0,
"prompt_tokens": prompt_token_calculator(
model, messages=messages
),
"completion_tokens": 0,
},
}
liteDebuggerClient.log_event(
@ -1002,11 +1014,16 @@ def handle_success(args, kwargs, result, start_time, end_time):
global heliconeLogger, aispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger
try:
model = args[0] if len(args) > 0 else kwargs["model"]
input = args[1] if len(args) > 1 else kwargs.get("messages", kwargs.get("input", None))
input = (
args[1]
if len(args) > 1
else kwargs.get("messages", kwargs.get("input", None))
)
success_handler = additional_details.pop("success_handler", None)
failure_handler = additional_details.pop("failure_handler", None)
additional_details["Event_Name"] = additional_details.pop(
"successful_event_name", "litellm.succes_query")
"successful_event_name", "litellm.succes_query"
)
for callback in litellm.success_callback:
try:
if callback == "posthog":
@ -1015,8 +1032,9 @@ def handle_success(args, kwargs, result, start_time, end_time):
ph_obj[detail] = additional_details[detail]
event_name = additional_details["Event_Name"]
if "user_id" in additional_details:
posthog.capture(additional_details["user_id"],
event_name, ph_obj)
posthog.capture(
additional_details["user_id"], event_name, ph_obj
)
else: # PostHog calls require a unique id to identify a user - https://posthog.com/docs/libraries/python
unique_id = str(uuid.uuid4())
posthog.capture(unique_id, event_name, ph_obj)
@ -1025,8 +1043,9 @@ def handle_success(args, kwargs, result, start_time, end_time):
slack_msg = ""
for detail in additional_details:
slack_msg += f"{detail}: {additional_details[detail]}\n"
slack_app.client.chat_postMessage(channel=alerts_channel,
text=slack_msg)
slack_app.client.chat_postMessage(
channel=alerts_channel, text=slack_msg
)
elif callback == "helicone":
print_verbose("reaches helicone for logging!")
model = args[0] if len(args) > 0 else kwargs["model"]
@ -1043,11 +1062,14 @@ def handle_success(args, kwargs, result, start_time, end_time):
print_verbose("reaches llmonitor for logging!")
model = args[0] if len(args) > 0 else kwargs["model"]
input = args[1] if len(args) > 1 else kwargs.get(
"messages", kwargs.get("input", None))
input = (
args[1]
if len(args) > 1
else kwargs.get("messages", kwargs.get("input", None))
)
#if contains input, it's 'embedding', otherwise 'llm'
type = 'embed' if 'input' in kwargs else 'llm'
# if contains input, it's 'embedding', otherwise 'llm'
type = "embed" if "input" in kwargs else "llm"
llmonitorLogger.log_event(
type=type,
@ -1069,7 +1091,6 @@ def handle_success(args, kwargs, result, start_time, end_time):
start_time=start_time,
end_time=end_time,
print_verbose=print_verbose,
)
elif callback == "aispend":
print_verbose("reaches aispend for logging!")
@ -1084,7 +1105,11 @@ def handle_success(args, kwargs, result, start_time, end_time):
elif callback == "supabase":
print_verbose("reaches supabase for logging!")
model = args[0] if len(args) > 0 else kwargs["model"]
messages = args[1] if len(args) > 1 else kwargs.get("messages", {"role": "user", "content": ""})
messages = (
args[1]
if len(args) > 1
else kwargs.get("messages", {"role": "user", "content": ""})
)
print(f"supabaseClient: {supabaseClient}")
supabaseClient.log_event(
model=model,
@ -1099,7 +1124,19 @@ def handle_success(args, kwargs, result, start_time, end_time):
elif callback == "lite_debugger":
print_verbose("reaches lite_debugger for logging!")
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
messages = args[1] if len(args) > 1 else kwargs.get("messages", [{"role": "user", "content": ' '.join(kwargs.get("input", ""))}])
messages = (
args[1]
if len(args) > 1
else kwargs.get(
"messages",
[
{
"role": "user",
"content": " ".join(kwargs.get("input", "")),
}
],
)
)
liteDebuggerClient.log_event(
model=model,
messages=messages,
@ -1129,6 +1166,7 @@ def handle_success(args, kwargs, result, start_time, end_time):
)
pass
def acreate(*args, **kwargs): ## Thin client to handle the acreate langchain call
return litellm.acompletion(*args, **kwargs)
@ -1170,28 +1208,43 @@ def modify_integration(integration_name, integration_params):
if "table_name" in integration_params:
Supabase.supabase_table_name = integration_params["table_name"]
####### [BETA] HOSTED PRODUCT ################ - https://docs.litellm.ai/docs/debugging/hosted_debugging
def get_all_keys(llm_provider=None):
try:
global last_fetched_at_keys
# if user is using hosted product -> instantiate their env with their hosted api keys - refresh every 5 minutes
print_verbose(f"Reaches get all keys, llm_provider: {llm_provider}")
user_email = os.getenv("LITELLM_EMAIL") or litellm.email or litellm.token or os.getenv("LITELLM_TOKEN")
user_email = (
os.getenv("LITELLM_EMAIL")
or litellm.email
or litellm.token
or os.getenv("LITELLM_TOKEN")
)
if user_email:
time_delta = 0
if last_fetched_at_keys != None:
current_time = time.time()
time_delta = current_time - last_fetched_at_keys
if time_delta > 300 or last_fetched_at_keys == None or llm_provider: # if the llm provider is passed in, assume this is happening due to an AuthError for that provider
if (
time_delta > 300 or last_fetched_at_keys == None or llm_provider
): # if the llm provider is passed in, assume this is happening due to an AuthError for that provider
# make the api call
last_fetched_at = time.time()
print_verbose(f"last_fetched_at: {last_fetched_at}")
response = requests.post(url="http://api.litellm.ai/get_all_keys", headers={"content-type": "application/json"}, data=json.dumps({"user_email": user_email}))
response = requests.post(
url="http://api.litellm.ai/get_all_keys",
headers={"content-type": "application/json"},
data=json.dumps({"user_email": user_email}),
)
print_verbose(f"get model key response: {response.text}")
data = response.json()
# update model list
for key, value in data["model_keys"].items(): # follows the LITELLM API KEY format - <UPPERCASE_PROVIDER_NAME>_API_KEY - e.g. HUGGINGFACE_API_KEY
for key, value in data[
"model_keys"
].items(): # follows the LITELLM API KEY format - <UPPERCASE_PROVIDER_NAME>_API_KEY - e.g. HUGGINGFACE_API_KEY
os.environ[key] = value
# set model alias map
for model_alias, value in data["model_alias_map"].items():
@ -1200,19 +1253,31 @@ def get_all_keys(llm_provider=None):
return None
return None
except:
print_verbose(f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}")
print_verbose(
f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}"
)
pass
def get_model_list():
global last_fetched_at
try:
# if user is using hosted product -> get their updated model list
user_email = os.getenv("LITELLM_EMAIL") or litellm.email or litellm.token or os.getenv("LITELLM_TOKEN")
user_email = (
os.getenv("LITELLM_EMAIL")
or litellm.email
or litellm.token
or os.getenv("LITELLM_TOKEN")
)
if user_email:
# make the api call
last_fetched_at = time.time()
print(f"last_fetched_at: {last_fetched_at}")
response = requests.post(url="http://api.litellm.ai/get_model_list", headers={"content-type": "application/json"}, data=json.dumps({"user_email": user_email}))
response = requests.post(
url="http://api.litellm.ai/get_model_list",
headers={"content-type": "application/json"},
data=json.dumps({"user_email": user_email}),
)
print_verbose(f"get_model_list response: {response.text}")
data = response.json()
# update model list
@ -1224,12 +1289,14 @@ def get_model_list():
if f"{item.upper()}_API_KEY" not in os.environ:
missing_llm_provider = item
break
# update environment - if required
# update environment - if required
threading.Thread(target=get_all_keys, args=(missing_llm_provider)).start()
return model_list
return [] # return empty list by default
return [] # return empty list by default
except:
print_verbose(f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}")
print_verbose(
f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}"
)
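A sketch of the hosted-product helpers above: with litellm.email or LITELLM_EMAIL set, get_model_list() fetches the account's model list from api.litellm.ai and starts get_all_keys() in a background thread to fill in any missing <PROVIDER>_API_KEY environment variables; otherwise it returns an empty list.

from litellm.utils import get_model_list

models = get_model_list()
print(models)  # [] when no hosted account is configured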
####### EXCEPTION MAPPING ################
@ -1253,36 +1320,33 @@ def exception_type(model, original_exception, custom_llm_provider):
exception_type = ""
if "claude" in model: # one of the anthropics
if hasattr(original_exception, "status_code"):
print_verbose(
f"status_code: {original_exception.status_code}")
print_verbose(f"status_code: {original_exception.status_code}")
if original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(
message=
f"AnthropicException - {original_exception.message}",
message=f"AnthropicException - {original_exception.message}",
llm_provider="anthropic",
)
elif original_exception.status_code == 400:
exception_mapping_worked = True
raise InvalidRequestError(
message=
f"AnthropicException - {original_exception.message}",
message=f"AnthropicException - {original_exception.message}",
model=model,
llm_provider="anthropic",
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
raise RateLimitError(
message=
f"AnthropicException - {original_exception.message}",
message=f"AnthropicException - {original_exception.message}",
llm_provider="anthropic",
)
elif ("Could not resolve authentication method. Expected either api_key or auth_token to be set."
in error_str):
elif (
"Could not resolve authentication method. Expected either api_key or auth_token to be set."
in error_str
):
exception_mapping_worked = True
raise AuthenticationError(
message=
f"AnthropicException - {original_exception.message}",
message=f"AnthropicException - {original_exception.message}",
llm_provider="anthropic",
)
elif "replicate" in model:
@ -1306,36 +1370,35 @@ def exception_type(model, original_exception, custom_llm_provider):
llm_provider="replicate",
)
elif (
exception_type == "ReplicateError"
exception_type == "ReplicateError"
): # ReplicateError implies an error on Replicate server side, not user side
raise ServiceUnavailableError(
message=f"ReplicateException - {error_str}",
llm_provider="replicate",
)
elif model == "command-nightly": # Cohere
if ("invalid api token" in error_str
or "No API key provided." in error_str):
if (
"invalid api token" in error_str
or "No API key provided." in error_str
):
exception_mapping_worked = True
raise AuthenticationError(
message=
f"CohereException - {original_exception.message}",
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
)
elif "too many tokens" in error_str:
exception_mapping_worked = True
raise InvalidRequestError(
message=
f"CohereException - {original_exception.message}",
message=f"CohereException - {original_exception.message}",
model=model,
llm_provider="cohere",
)
elif (
"CohereConnectionError" in exception_type
"CohereConnectionError" in exception_type
): # cohere seems to fire these errors when we load test it (1k+ messages / min)
exception_mapping_worked = True
raise RateLimitError(
message=
f"CohereException - {original_exception.message}",
message=f"CohereException - {original_exception.message}",
llm_provider="cohere",
)
elif custom_llm_provider == "huggingface":
@ -1343,23 +1406,20 @@ def exception_type(model, original_exception, custom_llm_provider):
if original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(
message=
f"HuggingfaceException - {original_exception.message}",
message=f"HuggingfaceException - {original_exception.message}",
llm_provider="huggingface",
)
elif original_exception.status_code == 400:
exception_mapping_worked = True
raise InvalidRequestError(
message=
f"HuggingfaceException - {original_exception.message}",
message=f"HuggingfaceException - {original_exception.message}",
model=model,
llm_provider="huggingface",
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
raise RateLimitError(
message=
f"HuggingfaceException - {original_exception.message}",
message=f"HuggingfaceException - {original_exception.message}",
llm_provider="huggingface",
)
raise original_exception # base case - return the original exception
@ -1375,8 +1435,10 @@ def exception_type(model, original_exception, custom_llm_provider):
},
exception=e,
)
## AUTH ERROR
if isinstance(e, AuthenticationError) and (litellm.email or "LITELLM_EMAIL" in os.environ):
## AUTH ERROR
if isinstance(e, AuthenticationError) and (
litellm.email or "LITELLM_EMAIL" in os.environ
):
threading.Thread(target=get_all_keys, args=(e.llm_provider,)).start()
if exception_mapping_worked:
raise e
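A consumer-side sketch of the mapping above, assuming the exception classes raised here are importable from litellm.exceptions (the import path is an assumption):

from litellm import completion
from litellm.exceptions import AuthenticationError, InvalidRequestError, RateLimitError

try:
    completion(model="claude-2", messages=[{"role": "user", "content": "hi"}])
except AuthenticationError:
    print("missing or invalid Anthropic credentials")
except RateLimitError:
    print("rate limited, back off and retry")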
@ -1391,7 +1453,8 @@ def safe_crash_reporting(model=None, exception=None, custom_llm_provider=None):
"exception": str(exception),
"custom_llm_provider": custom_llm_provider,
}
threading.Thread(target=litellm_telemetry, args=(data, )).start()
threading.Thread(target=litellm_telemetry, args=(data,)).start()
def get_or_generate_uuid():
uuid_file = "litellm_uuid.txt"
@ -1445,8 +1508,7 @@ def get_secret(secret_name):
# TODO: check which secret manager is being used
# currently only supports Infisical
try:
secret = litellm.secret_manager_client.get_secret(
secret_name).secret_value
secret = litellm.secret_manager_client.get_secret(secret_name).secret_value
except:
secret = None
return secret
@ -1460,7 +1522,6 @@ def get_secret(secret_name):
# wraps the completion stream to return the correct format for the model
# replicate/anthropic/cohere
class CustomStreamWrapper:
def __init__(self, completion_stream, model, custom_llm_provider=None):
self.model = model
self.custom_llm_provider = custom_llm_provider
@ -1509,8 +1570,9 @@ class CustomStreamWrapper:
elif self.model == "replicate":
chunk = next(self.completion_stream)
completion_obj["content"] = chunk
elif (self.custom_llm_provider and self.custom_llm_provider == "together_ai") or ("togethercomputer"
in self.model):
elif (
self.custom_llm_provider and self.custom_llm_provider == "together_ai"
) or ("togethercomputer" in self.model):
chunk = next(self.completion_stream)
text_data = self.handle_together_ai_chunk(chunk)
if text_data == "":
@ -1545,9 +1607,9 @@ def read_config_args(config_path):
########## ollama implementation ############################
async def get_ollama_response_stream(api_base="http://localhost:11434",
model="llama2",
prompt="Why is the sky blue?"):
async def get_ollama_response_stream(
api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"
):
session = aiohttp.ClientSession()
url = f"{api_base}/api/generate"
data = {
@ -1570,11 +1632,7 @@ async def get_ollama_response_stream(api_base="http://localhost:11434",
"content": "",
}
completion_obj["content"] = j["response"]
yield {
"choices": [{
"delta": completion_obj
}]
}
yield {"choices": [{"delta": completion_obj}]}
# self.responses.append(j["response"])
# yield "blank"
except Exception as e:
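A usage sketch for get_ollama_response_stream above (assumes a local Ollama server on the default port); the async generator yields OpenAI-style delta chunks:

import asyncio

async def main():
    async for chunk in get_ollama_response_stream(
        api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"
    ):
        print(chunk["choices"][0]["delta"]["content"], end="", flush=True)

asyncio.run(main())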