forked from phoenix/litellm-mirror

formatting improvements

commit b713acb0a4 (parent 70b323e0f5)
17 changed files with 464 additions and 323 deletions
@@ -5,8 +5,12 @@ input_callback: List[str] = []
success_callback: List[str] = []
failure_callback: List[str] = []
set_verbose = False
email: Optional[str] = None # for hosted dashboard. Learn more - https://docs.litellm.ai/docs/debugging/hosted_debugging
token: Optional[str] = None # for hosted dashboard. Learn more - https://docs.litellm.ai/docs/debugging/hosted_debugging
email: Optional[
    str
] = None # for hosted dashboard. Learn more - https://docs.litellm.ai/docs/debugging/hosted_debugging
token: Optional[
    str
] = None # for hosted dashboard. Learn more - https://docs.litellm.ai/docs/debugging/hosted_debugging
telemetry = True
max_tokens = 256 # OpenAI Defaults
retry = True
@@ -1,4 +1,5 @@
import importlib_metadata

try:
version = importlib_metadata.version("litellm")
except:
@@ -1,20 +1,21 @@

###### LiteLLM Integration with GPT Cache #########
import gptcache

# openai.ChatCompletion._llm_handler = litellm.completion
from gptcache.adapter import openai
import litellm


class LiteLLMChatCompletion(gptcache.adapter.openai.ChatCompletion):
@classmethod
def _llm_handler(cls, *llm_args, **llm_kwargs):
return litellm.completion(*llm_args, **llm_kwargs)



completion = LiteLLMChatCompletion.create
###### End of LiteLLM Integration with GPT Cache #########



# ####### Example usage ###############
# from gptcache import cache
# completion = LiteLLMChatCompletion.create
@@ -23,9 +24,3 @@ completion = LiteLLMChatCompletion.create
# cache.set_openai_key()
# result = completion(model="claude-2", messages=[{"role": "user", "content": "cto of litellm"}])
# print(result)
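For reference alongside the GPT Cache adapter above, a minimal usage sketch: it assumes gptcache is installed and OPENAI_API_KEY is set, mirrors the commented example at the end of that file, and uses an illustrative model name rather than anything mandated by the code.

from gptcache import cache
from litellm.cache import completion  # LiteLLMChatCompletion.create from the adapter above

cache.init()  # initialize the GPT Cache store
cache.set_openai_key()  # gptcache reads OPENAI_API_KEY from the environment

result = completion(
    model="gpt-3.5-turbo",  # illustrative model name
    messages=[{"role": "user", "content": "cto of litellm"}],
)
print(result["choices"][0]["message"]["content"])
# repeating the identical call should now be answered from the cache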
@ -1,5 +1,6 @@
|
|||
import requests, traceback, json, os
|
||||
|
||||
|
||||
class LiteDebugger:
|
||||
user_email = None
|
||||
dashboard_url = None
|
||||
|
@ -15,7 +16,9 @@ class LiteDebugger:
|
|||
self.user_email = os.getenv("LITELLM_EMAIL") or email
|
||||
self.dashboard_url = "https://admin.litellm.ai/" + self.user_email
|
||||
try:
|
||||
print(f"\033[92mHere's your LiteLLM Dashboard 👉 \033[94m\033[4m{self.dashboard_url}\033[0m")
|
||||
print(
|
||||
f"\033[92mHere's your LiteLLM Dashboard 👉 \033[94m\033[4m{self.dashboard_url}\033[0m"
|
||||
)
|
||||
except:
|
||||
print(f"Here's your LiteLLM Dashboard 👉 {self.dashboard_url}")
|
||||
if self.user_email == None:
|
||||
|
@ -28,17 +31,25 @@ class LiteDebugger:
|
|||
)
|
||||
|
||||
def input_log_event(
|
||||
self, model, messages, end_user, litellm_call_id, print_verbose, litellm_params, optional_params
|
||||
self,
|
||||
model,
|
||||
messages,
|
||||
end_user,
|
||||
litellm_call_id,
|
||||
print_verbose,
|
||||
litellm_params,
|
||||
optional_params,
|
||||
):
|
||||
try:
|
||||
print_verbose(
|
||||
f"LiteLLMDebugger: Logging - Enters input logging function for model {model}"
|
||||
)
|
||||
|
||||
def remove_key_value(dictionary, key):
|
||||
new_dict = dictionary.copy() # Create a copy of the original dictionary
|
||||
new_dict.pop(key) # Remove the specified key-value pair from the copy
|
||||
return new_dict
|
||||
|
||||
|
||||
updated_litellm_params = remove_key_value(litellm_params, "logger_fn")
|
||||
|
||||
litellm_data_obj = {
|
||||
|
@ -49,7 +60,7 @@ class LiteDebugger:
|
|||
"litellm_call_id": litellm_call_id,
|
||||
"user_email": self.user_email,
|
||||
"litellm_params": updated_litellm_params,
|
||||
"optional_params": optional_params
|
||||
"optional_params": optional_params,
|
||||
}
|
||||
print_verbose(
|
||||
f"LiteLLMDebugger: Logging - logged data obj {litellm_data_obj}"
|
||||
|
@ -65,10 +76,8 @@ class LiteDebugger:
|
|||
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
|
||||
)
|
||||
pass
|
||||
|
||||
def post_call_log_event(
|
||||
self, original_response, litellm_call_id, print_verbose
|
||||
):
|
||||
|
||||
def post_call_log_event(self, original_response, litellm_call_id, print_verbose):
|
||||
try:
|
||||
litellm_data_obj = {
|
||||
"status": "received",
|
||||
|
@ -110,7 +119,7 @@ class LiteDebugger:
|
|||
"model": response_obj["model"],
|
||||
"total_cost": total_cost,
|
||||
"messages": messages,
|
||||
"response": response['choices'][0]['message']['content'],
|
||||
"response": response["choices"][0]["message"]["content"],
|
||||
"end_user": end_user,
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"status": "success",
|
||||
|
@ -124,7 +133,12 @@ class LiteDebugger:
|
|||
headers={"content-type": "application/json"},
|
||||
data=json.dumps(litellm_data_obj),
|
||||
)
|
||||
elif "data" in response_obj and isinstance(response_obj["data"], list) and len(response_obj["data"]) > 0 and "embedding" in response_obj["data"][0]:
|
||||
elif (
|
||||
"data" in response_obj
|
||||
and isinstance(response_obj["data"], list)
|
||||
and len(response_obj["data"]) > 0
|
||||
and "embedding" in response_obj["data"][0]
|
||||
):
|
||||
print(f"messages: {messages}")
|
||||
litellm_data_obj = {
|
||||
"response_time": response_time,
|
||||
|
@ -145,7 +159,10 @@ class LiteDebugger:
|
|||
headers={"content-type": "application/json"},
|
||||
data=json.dumps(litellm_data_obj),
|
||||
)
|
||||
elif isinstance(response_obj, object) and response_obj.__class__.__name__ == "CustomStreamWrapper":
|
||||
elif (
|
||||
isinstance(response_obj, object)
|
||||
and response_obj.__class__.__name__ == "CustomStreamWrapper"
|
||||
):
|
||||
litellm_data_obj = {
|
||||
"response_time": response_time,
|
||||
"total_cost": total_cost,
|
||||
|
|
|
@ -12,20 +12,17 @@ dotenv.load_dotenv() # Loading env variables using dotenv
|
|||
# convert to {completion: xx, tokens: xx}
|
||||
def parse_usage(usage):
|
||||
return {
|
||||
"completion":
|
||||
usage["completion_tokens"] if "completion_tokens" in usage else 0,
|
||||
"prompt":
|
||||
usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
|
||||
"completion": usage["completion_tokens"] if "completion_tokens" in usage else 0,
|
||||
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
|
||||
}
|
||||
|
||||
|
||||
def parse_messages(input):
|
||||
|
||||
if input is None:
|
||||
return None
|
||||
|
||||
def clean_message(message):
|
||||
#if is strin, return as is
|
||||
# if is strin, return as is
|
||||
if isinstance(message, str):
|
||||
return message
|
||||
|
||||
|
@ -50,75 +47,78 @@ class LLMonitorLogger:
|
|||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
# Instance variables
|
||||
self.api_url = os.getenv(
|
||||
"LLMONITOR_API_URL") or "https://app.llmonitor.com"
|
||||
self.api_url = os.getenv("LLMONITOR_API_URL") or "https://app.llmonitor.com"
|
||||
self.app_id = os.getenv("LLMONITOR_APP_ID")
|
||||
|
||||
def log_event(
|
||||
self,
|
||||
type,
|
||||
event,
|
||||
run_id,
|
||||
model,
|
||||
print_verbose,
|
||||
input=None,
|
||||
user_id=None,
|
||||
response_obj=None,
|
||||
start_time=datetime.datetime.now(),
|
||||
end_time=datetime.datetime.now(),
|
||||
error=None,
|
||||
self,
|
||||
type,
|
||||
event,
|
||||
run_id,
|
||||
model,
|
||||
print_verbose,
|
||||
input=None,
|
||||
user_id=None,
|
||||
response_obj=None,
|
||||
start_time=datetime.datetime.now(),
|
||||
end_time=datetime.datetime.now(),
|
||||
error=None,
|
||||
):
|
||||
# Method definition
|
||||
try:
|
||||
print_verbose(
|
||||
f"LLMonitor Logging - Logging request for model {model}")
|
||||
print_verbose(f"LLMonitor Logging - Logging request for model {model}")
|
||||
|
||||
if response_obj:
|
||||
usage = parse_usage(
|
||||
response_obj['usage']) if 'usage' in response_obj else None
|
||||
output = response_obj[
|
||||
'choices'] if 'choices' in response_obj else None
|
||||
usage = (
|
||||
parse_usage(response_obj["usage"])
|
||||
if "usage" in response_obj
|
||||
else None
|
||||
)
|
||||
output = response_obj["choices"] if "choices" in response_obj else None
|
||||
else:
|
||||
usage = None
|
||||
output = None
|
||||
|
||||
if error:
|
||||
error_obj = {'stack': error}
|
||||
error_obj = {"stack": error}
|
||||
|
||||
else:
|
||||
error_obj = None
|
||||
|
||||
data = [{
|
||||
"type": type,
|
||||
"name": model,
|
||||
"runId": run_id,
|
||||
"app": self.app_id,
|
||||
'event': 'start',
|
||||
"timestamp": start_time.isoformat(),
|
||||
"userId": user_id,
|
||||
"input": parse_messages(input),
|
||||
}, {
|
||||
"type": type,
|
||||
"runId": run_id,
|
||||
"app": self.app_id,
|
||||
"event": event,
|
||||
"error": error_obj,
|
||||
"timestamp": end_time.isoformat(),
|
||||
"userId": user_id,
|
||||
"output": parse_messages(output),
|
||||
"tokensUsage": usage,
|
||||
}]
|
||||
data = [
|
||||
{
|
||||
"type": type,
|
||||
"name": model,
|
||||
"runId": run_id,
|
||||
"app": self.app_id,
|
||||
"event": "start",
|
||||
"timestamp": start_time.isoformat(),
|
||||
"userId": user_id,
|
||||
"input": parse_messages(input),
|
||||
},
|
||||
{
|
||||
"type": type,
|
||||
"runId": run_id,
|
||||
"app": self.app_id,
|
||||
"event": event,
|
||||
"error": error_obj,
|
||||
"timestamp": end_time.isoformat(),
|
||||
"userId": user_id,
|
||||
"output": parse_messages(output),
|
||||
"tokensUsage": usage,
|
||||
},
|
||||
]
|
||||
|
||||
# print_verbose(f"LLMonitor Logging - final data object: {data}")
|
||||
|
||||
response = requests.post(
|
||||
self.api_url + '/api/report',
|
||||
headers={'Content-Type': 'application/json'},
|
||||
json={'events': data})
|
||||
self.api_url + "/api/report",
|
||||
headers={"Content-Type": "application/json"},
|
||||
json={"events": data},
|
||||
)
|
||||
|
||||
print_verbose(f"LLMonitor Logging - response: {response}")
|
||||
except:
|
||||
# traceback.print_exc()
|
||||
print_verbose(
|
||||
f"LLMonitor Logging Error - {traceback.format_exc()}")
|
||||
print_verbose(f"LLMonitor Logging Error - {traceback.format_exc()}")
|
||||
pass
|
||||
@@ -7,6 +7,7 @@ import requests
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback


class PromptLayerLogger:
# Class variables or attributes
def __init__(self):
@@ -26,7 +27,9 @@ class PromptLayerLogger:
"function_name": "openai.ChatCompletion.create",
"kwargs": kwargs,
"tags": ["hello", "world"],
"request_response": dict(response_obj), # TODO: Check if we need a dict
"request_response": dict(
response_obj
), # TODO: Check if we need a dict
"request_start_time": int(start_time.timestamp()),
"request_end_time": int(end_time.timestamp()),
"api_key": self.key,
@@ -34,11 +37,12 @@ class PromptLayerLogger:
# "prompt_id": "<PROMPT ID>",
# "prompt_input_variables": "<Dictionary of variables for prompt>",
# "prompt_version":1,

},
)

print_verbose(f"Prompt Layer Logging - final response object: {request_response}")
print_verbose(
f"Prompt Layer Logging - final response object: {request_response}"
)
except:
# traceback.print_exc()
print_verbose(f"Prompt Layer Error - {traceback.format_exc()}")
@@ -94,7 +94,10 @@ class AnthropicLLM:
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
response = requests.post(
self.completion_url, headers=self.headers, data=json.dumps(data), stream=optional_params["stream"]
self.completion_url,
headers=self.headers,
data=json.dumps(data),
stream=optional_params["stream"],
)
return response.iter_lines()
else:
@@ -142,4 +145,3 @@
self,
): # logic for parsing in - calling - parsing out model embedding calls
pass
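The hunk above has AnthropicLLM return response.iter_lines() when streaming is requested; downstream, litellm exposes that stream as OpenAI-style delta chunks (the CustomStreamWrapper referenced in the debugger integration earlier in this commit). A hedged consumption sketch, following the pattern used in the streaming tests later in this commit; the model name is illustrative.

import litellm

response = litellm.completion(
    model="claude-2",  # routed to AnthropicLLM because "claude" is in the model name
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
)
complete_response = ""
for chunk in response:
    # each chunk carries an incremental delta, as in the OpenAI streaming format
    if len(chunk["choices"][0]["delta"].keys()) > 0:
        complete_response += chunk["choices"][0]["delta"]["content"]
print(complete_response)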
@ -5,6 +5,7 @@ import time
|
|||
from typing import Callable
|
||||
from litellm.utils import ModelResponse
|
||||
|
||||
|
||||
class BasetenError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
|
@ -15,9 +16,7 @@ class BasetenError(Exception):
|
|||
|
||||
|
||||
class BasetenLLM:
|
||||
def __init__(
|
||||
self, encoding, logging_obj, api_key=None
|
||||
):
|
||||
def __init__(self, encoding, logging_obj, api_key=None):
|
||||
self.encoding = encoding
|
||||
self.completion_url_fragment_1 = "https://app.baseten.co/models/"
|
||||
self.completion_url_fragment_2 = "/predict"
|
||||
|
@ -55,13 +54,9 @@ class BasetenLLM:
|
|||
for message in messages:
|
||||
if "role" in message:
|
||||
if message["role"] == "user":
|
||||
prompt += (
|
||||
f"{message['content']}"
|
||||
)
|
||||
prompt += f"{message['content']}"
|
||||
else:
|
||||
prompt += (
|
||||
f"{message['content']}"
|
||||
)
|
||||
prompt += f"{message['content']}"
|
||||
else:
|
||||
prompt += f"{message['content']}"
|
||||
data = {
|
||||
|
@ -78,7 +73,9 @@ class BasetenLLM:
|
|||
)
|
||||
## COMPLETION CALL
|
||||
response = requests.post(
|
||||
self.completion_url_fragment_1 + model + self.completion_url_fragment_2, headers=self.headers, data=json.dumps(data)
|
||||
self.completion_url_fragment_1 + model + self.completion_url_fragment_2,
|
||||
headers=self.headers,
|
||||
data=json.dumps(data),
|
||||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
return response.iter_lines()
|
||||
|
@ -100,19 +97,33 @@ class BasetenLLM:
|
|||
)
|
||||
else:
|
||||
if "model_output" in completion_response:
|
||||
if isinstance(completion_response["model_output"], dict) and "data" in completion_response["model_output"] and isinstance(completion_response["model_output"]["data"], list):
|
||||
model_response["choices"][0]["message"]["content"] = completion_response["model_output"]["data"][0]
|
||||
if (
|
||||
isinstance(completion_response["model_output"], dict)
|
||||
and "data" in completion_response["model_output"]
|
||||
and isinstance(
|
||||
completion_response["model_output"]["data"], list
|
||||
)
|
||||
):
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = completion_response["model_output"]["data"][0]
|
||||
elif isinstance(completion_response["model_output"], str):
|
||||
model_response["choices"][0]["message"]["content"] = completion_response["model_output"]
|
||||
elif "completion" in completion_response and isinstance(completion_response["completion"], str):
|
||||
model_response["choices"][0]["message"]["content"] = completion_response["completion"]
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = completion_response["model_output"]
|
||||
elif "completion" in completion_response and isinstance(
|
||||
completion_response["completion"], str
|
||||
):
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = completion_response["completion"]
|
||||
else:
|
||||
raise ValueError(f"Unable to parse response. Original response: {response.text}")
|
||||
raise ValueError(
|
||||
f"Unable to parse response. Original response: {response.text}"
|
||||
)
|
||||
|
||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||
prompt_tokens = len(
|
||||
self.encoding.encode(prompt)
|
||||
)
|
||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||
prompt_tokens = len(self.encoding.encode(prompt))
|
||||
completion_tokens = len(
|
||||
self.encoding.encode(model_response["choices"][0]["message"]["content"])
|
||||
)
|
||||
|
|
|
@ -103,7 +103,9 @@ def completion(
|
|||
return completion_with_fallbacks(**args)
|
||||
if litellm.model_alias_map and model in litellm.model_alias_map:
|
||||
args["model_alias_map"] = litellm.model_alias_map
|
||||
model = litellm.model_alias_map[model] # update the model to the actual value if an alias has been passed in
|
||||
model = litellm.model_alias_map[
|
||||
model
|
||||
] # update the model to the actual value if an alias has been passed in
|
||||
model_response = ModelResponse()
|
||||
if azure: # this flag is deprecated, remove once notebooks are also updated.
|
||||
custom_llm_provider = "azure"
|
||||
|
@ -146,7 +148,7 @@ def completion(
|
|||
custom_llm_provider=custom_llm_provider,
|
||||
custom_api_base=custom_api_base,
|
||||
litellm_call_id=litellm_call_id,
|
||||
model_alias_map=litellm.model_alias_map
|
||||
model_alias_map=litellm.model_alias_map,
|
||||
)
|
||||
logging = Logging(
|
||||
model=model,
|
||||
|
@ -216,7 +218,10 @@ def completion(
|
|||
# note: if a user sets a custom base - we should ensure this works
|
||||
# allow for the setting of dynamic and stateful api-bases
|
||||
api_base = (
|
||||
custom_api_base or litellm.api_base or get_secret("OPENAI_API_BASE") or "https://api.openai.com/v1"
|
||||
custom_api_base
|
||||
or litellm.api_base
|
||||
or get_secret("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
openai.api_base = api_base
|
||||
openai.api_version = None
|
||||
|
@ -255,9 +260,11 @@ def completion(
|
|||
original_response=response,
|
||||
additional_args={"headers": litellm.headers},
|
||||
)
|
||||
elif (model in litellm.open_ai_text_completion_models or
|
||||
"ft:babbage-002" in model or # support for finetuned completion models
|
||||
"ft:davinci-002" in model):
|
||||
elif (
|
||||
model in litellm.open_ai_text_completion_models
|
||||
or "ft:babbage-002" in model
|
||||
or "ft:davinci-002" in model # support for finetuned completion models
|
||||
):
|
||||
openai.api_type = "openai"
|
||||
openai.api_base = (
|
||||
litellm.api_base
|
||||
|
@ -544,7 +551,10 @@ def completion(
|
|||
logging.pre_call(input=prompt, api_key=TOGETHER_AI_TOKEN)
|
||||
|
||||
print(f"TOGETHER_AI_TOKEN: {TOGETHER_AI_TOKEN}")
|
||||
if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
|
||||
if (
|
||||
"stream_tokens" in optional_params
|
||||
and optional_params["stream_tokens"] == True
|
||||
):
|
||||
res = requests.post(
|
||||
endpoint,
|
||||
json={
|
||||
|
@ -698,9 +708,7 @@ def completion(
|
|||
):
|
||||
custom_llm_provider = "baseten"
|
||||
baseten_key = (
|
||||
api_key
|
||||
or litellm.baseten_key
|
||||
or os.environ.get("BASETEN_API_KEY")
|
||||
api_key or litellm.baseten_key or os.environ.get("BASETEN_API_KEY")
|
||||
)
|
||||
baseten_client = BasetenLLM(
|
||||
encoding=encoding, api_key=baseten_key, logging_obj=logging
|
||||
|
@ -767,11 +775,14 @@ def completion(
|
|||
model=model, custom_llm_provider=custom_llm_provider, original_exception=e
|
||||
)
|
||||
|
||||
|
||||
def completion_with_retries(*args, **kwargs):
|
||||
import tenacity
|
||||
|
||||
retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(3), reraise=True)
|
||||
return retryer(completion, *args, **kwargs)
|
||||
|
||||
|
||||
def batch_completion(*args, **kwargs):
|
||||
batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
|
||||
completions = []
|
||||
|
@ -865,14 +876,16 @@ def embedding(
|
|||
custom_llm_provider="azure" if azure == True else None,
|
||||
)
|
||||
|
||||
|
||||
###### Text Completion ################
|
||||
def text_completion(*args, **kwargs):
|
||||
if 'prompt' in kwargs:
|
||||
messages = [{'role': 'system', 'content': kwargs['prompt']}]
|
||||
kwargs['messages'] = messages
|
||||
kwargs.pop('prompt')
|
||||
if "prompt" in kwargs:
|
||||
messages = [{"role": "system", "content": kwargs["prompt"]}]
|
||||
kwargs["messages"] = messages
|
||||
kwargs.pop("prompt")
|
||||
return completion(*args, **kwargs)
|
||||
|
||||
|
||||
####### HELPER FUNCTIONS ################
|
||||
## Set verbose to true -> ```litellm.set_verbose = True```
|
||||
def print_verbose(print_statement):
|
||||
@@ -14,6 +14,7 @@ from litellm import embedding, completion

messages = [{"role": "user", "content": "who is ishaan Github? "}]


# test if response cached
def test_caching():
try:
@@ -50,14 +51,16 @@ def test_caching_with_models():
print(f"response1: {response1}")
print(f"response2: {response2}")
pytest.fail(f"Error occurred:")
# test_caching_with_models()


# test_caching_with_models()


def test_gpt_cache():
# INIT GPT Cache #
from gptcache import cache
from litellm.cache import completion

cache.init()
cache.set_openai_key()

@@ -67,10 +70,11 @@ def test_gpt_cache():
print(f"response2: {response2}")
print(f"response3: {response3}")

if response3['choices'] != response2['choices']:
if response3["choices"] != response2["choices"]:
# if models are different, it should not return cached response
print(f"response2: {response2}")
print(f"response3: {response3}")
pytest.fail(f"Error occurred:")
# test_gpt_cache()


# test_gpt_cache()
@@ -142,9 +142,12 @@ def test_completion_openai():
except Exception as e:
pytest.fail(f"Error occurred: {e}")


def test_completion_openai_prompt():
try:
response = text_completion(model="gpt-3.5-turbo", prompt="What's the weather in SF?")
response = text_completion(
model="gpt-3.5-turbo", prompt="What's the weather in SF?"
)
response_str = response["choices"][0]["message"]["content"]
response_str_2 = response.choices[0].message.content
print(response)
@@ -154,6 +157,7 @@ def test_completion_openai_prompt():
except Exception as e:
pytest.fail(f"Error occurred: {e}")


def test_completion_text_openai():
try:
response = completion(model="text-davinci-003", messages=messages)
@ -27,7 +27,7 @@
|
|||
# # print(f"user_model_dict: {user_model_dict}")
|
||||
# pass
|
||||
|
||||
# # normal call
|
||||
# # normal call
|
||||
# def test_completion_custom_provider_model_name():
|
||||
# try:
|
||||
# response = completion_with_retries(
|
||||
|
@ -41,7 +41,7 @@
|
|||
# pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# # bad call
|
||||
# # bad call
|
||||
# # def test_completion_custom_provider_model_name():
|
||||
# # try:
|
||||
# # response = completion_with_retries(
|
||||
|
@ -54,7 +54,7 @@
|
|||
# # except Exception as e:
|
||||
# # pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# # impact on exception mapping
|
||||
# # impact on exception mapping
|
||||
# def test_context_window():
|
||||
# sample_text = "how does a court case get to the Supreme Court?" * 5000
|
||||
# messages = [{"content": sample_text, "role": "user"}]
|
||||
|
@ -83,4 +83,4 @@
|
|||
|
||||
# test_context_window()
|
||||
|
||||
# test_completion_custom_provider_model_name()
|
||||
# test_completion_custom_provider_model_name()
|
||||
@@ -22,4 +22,6 @@ def test_openai_embedding():
# print(f"response: {str(response)}")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
test_openai_embedding()


test_openai_embedding()
@ -4,7 +4,7 @@
|
|||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.abspath('../..'))
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
|
||||
from litellm import completion, embedding
|
||||
import litellm
|
||||
|
@ -17,11 +17,10 @@ litellm.set_verbose = True
|
|||
|
||||
def test_chat_openai():
|
||||
try:
|
||||
response = completion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}])
|
||||
response = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
||||
|
@ -31,7 +30,7 @@ def test_chat_openai():
|
|||
|
||||
def test_embedding_openai():
|
||||
try:
|
||||
response = embedding(model="text-embedding-ada-002", input=['test'])
|
||||
response = embedding(model="text-embedding-ada-002", input=["test"])
|
||||
# Add any assertions here to check the response
|
||||
print(f"response: {str(response)[:50]}")
|
||||
except Exception as e:
|
||||
|
@ -39,4 +38,4 @@ def test_embedding_openai():
|
|||
|
||||
|
||||
test_chat_openai()
|
||||
test_embedding_openai()
|
||||
test_embedding_openai()
|
||||
@@ -13,5 +13,19 @@ from litellm import embedding, completion
litellm.set_verbose = True

# Test: Check if the alias created via LiteDebugger is mapped correctly
{"top_p": 0.75, "prompt": "What's the meaning of life?", "num_beams": 4, "temperature": 0.1}
print(completion("llama2", messages=[{"role": "user", "content": "Hey, how's it going?"}], top_p=0.1, temperature=0, num_beams=4, max_tokens=60))
{
"top_p": 0.75,
"prompt": "What's the meaning of life?",
"num_beams": 4,
"temperature": 0.1,
}
print(
completion(
"llama2",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
top_p=0.1,
temperature=0,
num_beams=4,
max_tokens=60,
)
)
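The test above exercises litellm.model_alias_map, which the main.py hunk in this commit resolves to the real model name before dispatching. A short sketch with a hypothetical alias mapping; in the hosted flow the map is populated for the user by get_all_keys/get_model_list.

import litellm
from litellm import completion

# hypothetical alias -> provider model mapping; completion() looks the alias up
# in litellm.model_alias_map and swaps in the actual model name before the call
litellm.model_alias_map = {"llama2": "replicate/llama-2-70b-chat"}

response = completion(
    "llama2",  # resolved via the alias map inside completion()
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    max_tokens=60,
)
print(response["choices"][0]["message"]["content"])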
@ -3,12 +3,14 @@
|
|||
|
||||
import sys, os
|
||||
import traceback
|
||||
import time
|
||||
import time
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
from litellm import completion
|
||||
|
||||
litellm.logging = False
|
||||
litellm.set_verbose = False
|
||||
|
||||
|
@ -31,11 +33,11 @@ messages = [{"content": user_message, "role": "user"}]
|
|||
# complete_response = ""
|
||||
# start_time = time.time()
|
||||
# for chunk in response:
|
||||
# chunk_time = time.time()
|
||||
# chunk_time = time.time()
|
||||
# print(f"time since initial request: {chunk_time - start_time:.5f}")
|
||||
# print(chunk["choices"][0]["delta"])
|
||||
# complete_response += chunk["choices"][0]["delta"]["content"]
|
||||
# if complete_response == "":
|
||||
# if complete_response == "":
|
||||
# raise Exception("Empty response received")
|
||||
# except:
|
||||
# print(f"error occurred: {traceback.format_exc()}")
|
||||
|
@ -50,11 +52,11 @@ messages = [{"content": user_message, "role": "user"}]
|
|||
# response = ""
|
||||
# start_time = time.time()
|
||||
# for chunk in response:
|
||||
# chunk_time = time.time()
|
||||
# chunk_time = time.time()
|
||||
# print(f"time since initial request: {chunk_time - start_time:.2f}")
|
||||
# print(chunk["choices"][0]["delta"])
|
||||
# response += chunk["choices"][0]["delta"]
|
||||
# if response == "":
|
||||
# if response == "":
|
||||
# raise Exception("Empty response received")
|
||||
# except:
|
||||
# print(f"error occurred: {traceback.format_exc()}")
|
||||
|
@ -73,7 +75,7 @@ try:
|
|||
print(f"time since initial request: {chunk_time - start_time:.5f}")
|
||||
print(chunk["choices"][0]["delta"])
|
||||
complete_response += chunk["choices"][0]["delta"]["content"]
|
||||
if complete_response == "":
|
||||
if complete_response == "":
|
||||
raise Exception("Empty response received")
|
||||
except:
|
||||
print(f"error occurred: {traceback.format_exc()}")
|
||||
|
@ -88,11 +90,11 @@ except:
|
|||
# )
|
||||
# complete_response = ""
|
||||
# for chunk in response:
|
||||
# chunk_time = time.time()
|
||||
# chunk_time = time.time()
|
||||
# print(f"time since initial request: {chunk_time - start_time:.2f}")
|
||||
# print(chunk["choices"][0]["delta"])
|
||||
# complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else ""
|
||||
# if complete_response == "":
|
||||
# complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else ""
|
||||
# if complete_response == "":
|
||||
# raise Exception("Empty response received")
|
||||
# except:
|
||||
# print(f"error occurred: {traceback.format_exc()}")
|
||||
|
@ -102,16 +104,20 @@ except:
|
|||
try:
|
||||
start_time = time.time()
|
||||
response = completion(
|
||||
model="Replit-Code-3B", messages=messages, logger_fn=logger_fn, stream= True
|
||||
model="Replit-Code-3B", messages=messages, logger_fn=logger_fn, stream=True
|
||||
)
|
||||
complete_response = ""
|
||||
print(f"returned response object: {response}")
|
||||
for chunk in response:
|
||||
chunk_time = time.time()
|
||||
chunk_time = time.time()
|
||||
print(f"time since initial request: {chunk_time - start_time:.2f}")
|
||||
print(chunk["choices"][0]["delta"])
|
||||
complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else ""
|
||||
if complete_response == "":
|
||||
complete_response += (
|
||||
chunk["choices"][0]["delta"]["content"]
|
||||
if len(chunk["choices"][0]["delta"].keys()) > 0
|
||||
else ""
|
||||
)
|
||||
if complete_response == "":
|
||||
raise Exception("Empty response received")
|
||||
except:
|
||||
print(f"error occurred: {traceback.format_exc()}")
|
||||
|
@ -121,16 +127,23 @@ except:
|
|||
try:
|
||||
start_time = time.time()
|
||||
response = completion(
|
||||
model="together_ai/bigcode/starcoder", messages=messages, logger_fn=logger_fn, stream= True
|
||||
model="together_ai/bigcode/starcoder",
|
||||
messages=messages,
|
||||
logger_fn=logger_fn,
|
||||
stream=True,
|
||||
)
|
||||
complete_response = ""
|
||||
print(f"returned response object: {response}")
|
||||
for chunk in response:
|
||||
chunk_time = time.time()
|
||||
complete_response += chunk["choices"][0]["delta"]["content"] if len(chunk["choices"][0]["delta"].keys()) > 0 else ""
|
||||
chunk_time = time.time()
|
||||
complete_response += (
|
||||
chunk["choices"][0]["delta"]["content"]
|
||||
if len(chunk["choices"][0]["delta"].keys()) > 0
|
||||
else ""
|
||||
)
|
||||
if len(complete_response) > 0:
|
||||
print(complete_response)
|
||||
if complete_response == "":
|
||||
if complete_response == "":
|
||||
raise Exception("Empty response received")
|
||||
except:
|
||||
print(f"error occurred: {traceback.format_exc()}")
|
||||
|
@ -144,11 +157,11 @@ except:
|
|||
# )
|
||||
# response = ""
|
||||
# for chunk in response:
|
||||
# chunk_time = time.time()
|
||||
# chunk_time = time.time()
|
||||
# print(f"time since initial request: {chunk_time - start_time:.2f}")
|
||||
# print(chunk["choices"][0]["delta"])
|
||||
# response += chunk["choices"][0]["delta"]
|
||||
# if response == "":
|
||||
# if response == "":
|
||||
# raise Exception("Empty response received")
|
||||
# except:
|
||||
# print(f"error occurred: {traceback.format_exc()}")
|
||||
|
@ -162,11 +175,11 @@ except:
|
|||
# )
|
||||
# response = ""
|
||||
# for chunk in response:
|
||||
# chunk_time = time.time()
|
||||
# chunk_time = time.time()
|
||||
# print(f"time since initial request: {chunk_time - start_time:.2f}")
|
||||
# print(chunk["choices"][0]["delta"])
|
||||
# response += chunk["choices"][0]["delta"]
|
||||
# if response == "":
|
||||
# if response == "":
|
||||
# raise Exception("Empty response received")
|
||||
# except:
|
||||
# print(f"error occurred: {traceback.format_exc()}")
|
||||
398 litellm/utils.py
@ -69,7 +69,6 @@ last_fetched_at_keys = None
|
|||
|
||||
|
||||
class Message(OpenAIObject):
|
||||
|
||||
def __init__(self, content="default", role="assistant", **params):
|
||||
super(Message, self).__init__(**params)
|
||||
self.content = content
|
||||
|
@ -77,12 +76,7 @@ class Message(OpenAIObject):
|
|||
|
||||
|
||||
class Choices(OpenAIObject):
|
||||
|
||||
def __init__(self,
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=Message(),
|
||||
**params):
|
||||
def __init__(self, finish_reason="stop", index=0, message=Message(), **params):
|
||||
super(Choices, self).__init__(**params)
|
||||
self.finish_reason = finish_reason
|
||||
self.index = index
|
||||
|
@ -90,22 +84,20 @@ class Choices(OpenAIObject):
|
|||
|
||||
|
||||
class ModelResponse(OpenAIObject):
|
||||
|
||||
def __init__(self,
|
||||
choices=None,
|
||||
created=None,
|
||||
model=None,
|
||||
usage=None,
|
||||
**params):
|
||||
def __init__(self, choices=None, created=None, model=None, usage=None, **params):
|
||||
super(ModelResponse, self).__init__(**params)
|
||||
self.choices = choices if choices else [Choices()]
|
||||
self.created = created
|
||||
self.model = model
|
||||
self.usage = (usage if usage else {
|
||||
"prompt_tokens": None,
|
||||
"completion_tokens": None,
|
||||
"total_tokens": None,
|
||||
})
|
||||
self.usage = (
|
||||
usage
|
||||
if usage
|
||||
else {
|
||||
"prompt_tokens": None,
|
||||
"completion_tokens": None,
|
||||
"total_tokens": None,
|
||||
}
|
||||
)
|
||||
|
||||
def to_dict_recursive(self):
|
||||
d = super().to_dict_recursive()
|
||||
|
@ -173,7 +165,9 @@ class Logging:
|
|||
self.model_call_details["api_key"] = api_key
|
||||
self.model_call_details["additional_args"] = additional_args
|
||||
|
||||
if model: # if model name was changes pre-call, overwrite the initial model call name with the new one
|
||||
if (
|
||||
model
|
||||
): # if model name was changes pre-call, overwrite the initial model call name with the new one
|
||||
self.model_call_details["model"] = model
|
||||
|
||||
# User Logging -> if you pass in a custom logging function
|
||||
|
@ -203,8 +197,7 @@ class Logging:
|
|||
model=model,
|
||||
messages=messages,
|
||||
end_user=litellm._thread_context.user,
|
||||
litellm_call_id=self.
|
||||
litellm_params["litellm_call_id"],
|
||||
litellm_call_id=self.litellm_params["litellm_call_id"],
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
|
||||
|
@ -217,8 +210,7 @@ class Logging:
|
|||
model=model,
|
||||
messages=messages,
|
||||
end_user=litellm._thread_context.user,
|
||||
litellm_call_id=self.
|
||||
litellm_params["litellm_call_id"],
|
||||
litellm_call_id=self.litellm_params["litellm_call_id"],
|
||||
litellm_params=self.model_call_details["litellm_params"],
|
||||
optional_params=self.model_call_details["optional_params"],
|
||||
print_verbose=print_verbose,
|
||||
|
@ -263,7 +255,7 @@ class Logging:
|
|||
print_verbose(
|
||||
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
|
||||
)
|
||||
|
||||
|
||||
# Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
|
||||
for callback in litellm.input_callback:
|
||||
try:
|
||||
|
@ -274,8 +266,7 @@ class Logging:
|
|||
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
|
||||
liteDebuggerClient.post_call_log_event(
|
||||
original_response=original_response,
|
||||
litellm_call_id=self.
|
||||
litellm_params["litellm_call_id"],
|
||||
litellm_call_id=self.litellm_params["litellm_call_id"],
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
except:
|
||||
|
@ -295,6 +286,7 @@ class Logging:
|
|||
|
||||
# Add more methods as needed
|
||||
|
||||
|
||||
def exception_logging(
|
||||
additional_args={},
|
||||
logger_fn=None,
|
||||
|
@ -329,13 +321,18 @@ def exception_logging(
|
|||
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
|
||||
def client(original_function):
|
||||
global liteDebuggerClient, get_all_keys
|
||||
|
||||
|
||||
def function_setup(
|
||||
*args, **kwargs
|
||||
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
|
||||
try:
|
||||
global callback_list, add_breadcrumb, user_logger_fn
|
||||
if litellm.email is not None or os.getenv("LITELLM_EMAIL", None) is not None or litellm.token is not None or os.getenv("LITELLM_TOKEN", None): # add to input, success and failure callbacks if user is using hosted product
|
||||
if (
|
||||
litellm.email is not None
|
||||
or os.getenv("LITELLM_EMAIL", None) is not None
|
||||
or litellm.token is not None
|
||||
or os.getenv("LITELLM_TOKEN", None)
|
||||
): # add to input, success and failure callbacks if user is using hosted product
|
||||
get_all_keys()
|
||||
if "lite_debugger" not in callback_list and litellm.logging:
|
||||
litellm.input_callback.append("lite_debugger")
|
||||
|
@ -381,11 +378,12 @@ def client(original_function):
|
|||
if litellm.telemetry:
|
||||
try:
|
||||
model = args[0] if len(args) > 0 else kwargs["model"]
|
||||
exception = kwargs[
|
||||
"exception"] if "exception" in kwargs else None
|
||||
custom_llm_provider = (kwargs["custom_llm_provider"]
|
||||
if "custom_llm_provider" in kwargs else
|
||||
None)
|
||||
exception = kwargs["exception"] if "exception" in kwargs else None
|
||||
custom_llm_provider = (
|
||||
kwargs["custom_llm_provider"]
|
||||
if "custom_llm_provider" in kwargs
|
||||
else None
|
||||
)
|
||||
safe_crash_reporting(
|
||||
model=model,
|
||||
exception=exception,
|
||||
|
@ -410,10 +408,10 @@ def client(original_function):
|
|||
def check_cache(*args, **kwargs):
|
||||
try: # never block execution
|
||||
prompt = get_prompt(*args, **kwargs)
|
||||
if (prompt != None): # check if messages / prompt exists
|
||||
if prompt != None: # check if messages / prompt exists
|
||||
if litellm.caching_with_models:
|
||||
# if caching with model names is enabled, key is prompt + model name
|
||||
if ("model" in kwargs):
|
||||
if "model" in kwargs:
|
||||
cache_key = prompt + kwargs["model"]
|
||||
if cache_key in local_cache:
|
||||
return local_cache[cache_key]
|
||||
|
@ -423,7 +421,7 @@ def client(original_function):
|
|||
return result
|
||||
else:
|
||||
return None
|
||||
return None # default to return None
|
||||
return None # default to return None
|
||||
except:
|
||||
return None
|
||||
|
||||
|
@ -431,7 +429,7 @@ def client(original_function):
|
|||
try: # never block execution
|
||||
prompt = get_prompt(*args, **kwargs)
|
||||
if litellm.caching_with_models: # caching with model + prompt
|
||||
if ("model" in kwargs):
|
||||
if "model" in kwargs:
|
||||
cache_key = prompt + kwargs["model"]
|
||||
local_cache[cache_key] = result
|
||||
else: # caching based only on prompts
|
||||
|
@ -449,7 +447,8 @@ def client(original_function):
|
|||
start_time = datetime.datetime.now()
|
||||
# [OPTIONAL] CHECK CACHE
|
||||
if (litellm.caching or litellm.caching_with_models) and (
|
||||
cached_result := check_cache(*args, **kwargs)) is not None:
|
||||
cached_result := check_cache(*args, **kwargs)
|
||||
) is not None:
|
||||
result = cached_result
|
||||
return result
|
||||
# MODEL CALL
|
||||
|
@ -458,25 +457,22 @@ def client(original_function):
|
|||
return result
|
||||
end_time = datetime.datetime.now()
|
||||
# [OPTIONAL] ADD TO CACHE
|
||||
if (litellm.caching or litellm.caching_with_models):
|
||||
if litellm.caching or litellm.caching_with_models:
|
||||
add_cache(result, *args, **kwargs)
|
||||
# LOG SUCCESS
|
||||
my_thread = threading.Thread(
|
||||
target=handle_success,
|
||||
args=(args, kwargs, result, start_time,
|
||||
end_time)) # don't interrupt execution of main thread
|
||||
target=handle_success, args=(args, kwargs, result, start_time, end_time)
|
||||
) # don't interrupt execution of main thread
|
||||
my_thread.start()
|
||||
# RETURN RESULT
|
||||
return result
|
||||
except Exception as e:
|
||||
|
||||
traceback_exception = traceback.format_exc()
|
||||
crash_reporting(*args, **kwargs, exception=traceback_exception)
|
||||
end_time = datetime.datetime.now()
|
||||
my_thread = threading.Thread(
|
||||
target=handle_failure,
|
||||
args=(e, traceback_exception, start_time, end_time, args,
|
||||
kwargs),
|
||||
args=(e, traceback_exception, start_time, end_time, args, kwargs),
|
||||
) # don't interrupt execution of main thread
|
||||
my_thread.start()
|
||||
if hasattr(e, "message"):
|
||||
|
@ -506,18 +502,18 @@ def token_counter(model, text):
|
|||
return num_tokens
|
||||
|
||||
|
||||
def cost_per_token(model="gpt-3.5-turbo",
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0):
|
||||
def cost_per_token(model="gpt-3.5-turbo", prompt_tokens=0, completion_tokens=0):
|
||||
# given
|
||||
prompt_tokens_cost_usd_dollar = 0
|
||||
completion_tokens_cost_usd_dollar = 0
|
||||
model_cost_ref = litellm.model_cost
|
||||
if model in model_cost_ref:
|
||||
prompt_tokens_cost_usd_dollar = (
|
||||
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens)
|
||||
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
|
||||
)
|
||||
completion_tokens_cost_usd_dollar = (
|
||||
model_cost_ref[model]["output_cost_per_token"] * completion_tokens)
|
||||
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
|
||||
)
|
||||
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
||||
else:
|
||||
# calculate average input cost
|
||||
|
@ -538,9 +534,8 @@ def completion_cost(model="gpt-3.5-turbo", prompt="", completion=""):
|
|||
prompt_tokens = token_counter(model=model, text=prompt)
|
||||
completion_tokens = token_counter(model=model, text=completion)
|
||||
prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_token(
|
||||
model=model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens)
|
||||
model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
|
||||
)
|
||||
return prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
|
||||
|
||||
|
||||
|
@ -558,7 +553,7 @@ def get_litellm_params(
|
|||
custom_llm_provider=None,
|
||||
custom_api_base=None,
|
||||
litellm_call_id=None,
|
||||
model_alias_map=None
|
||||
model_alias_map=None,
|
||||
):
|
||||
litellm_params = {
|
||||
"return_async": return_async,
|
||||
|
@ -569,13 +564,13 @@ def get_litellm_params(
|
|||
"custom_llm_provider": custom_llm_provider,
|
||||
"custom_api_base": custom_api_base,
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"model_alias_map": model_alias_map
|
||||
"model_alias_map": model_alias_map,
|
||||
}
|
||||
|
||||
return litellm_params
|
||||
|
||||
|
||||
def get_optional_params( # use the openai defaults
|
||||
def get_optional_params( # use the openai defaults
|
||||
# 12 optional params
|
||||
functions=[],
|
||||
function_call="",
|
||||
|
@ -588,7 +583,7 @@ def get_optional_params( # use the openai defaults
|
|||
presence_penalty=0,
|
||||
frequency_penalty=0,
|
||||
logit_bias={},
|
||||
num_beams=1,
|
||||
num_beams=1,
|
||||
user="",
|
||||
deployment_id=None,
|
||||
model=None,
|
||||
|
@ -635,8 +630,9 @@ def get_optional_params( # use the openai defaults
|
|||
optional_params["max_tokens"] = max_tokens
|
||||
if frequency_penalty != 0:
|
||||
optional_params["frequency_penalty"] = frequency_penalty
|
||||
elif (model == "chat-bison"
|
||||
): # chat-bison has diff args from chat-bison@001 ty Google
|
||||
elif (
|
||||
model == "chat-bison"
|
||||
): # chat-bison has diff args from chat-bison@001 ty Google
|
||||
if temperature != 1:
|
||||
optional_params["temperature"] = temperature
|
||||
if top_p != 1:
|
||||
|
@ -702,10 +698,7 @@ def load_test_model(
|
|||
test_prompt = prompt
|
||||
if num_calls:
|
||||
test_calls = num_calls
|
||||
messages = [[{
|
||||
"role": "user",
|
||||
"content": test_prompt
|
||||
}] for _ in range(test_calls)]
|
||||
messages = [[{"role": "user", "content": test_prompt}] for _ in range(test_calls)]
|
||||
start_time = time.time()
|
||||
try:
|
||||
litellm.batch_completion(
|
||||
|
@ -743,15 +736,17 @@ def set_callbacks(callback_list):
|
|||
try:
|
||||
import sentry_sdk
|
||||
except ImportError:
|
||||
print_verbose(
|
||||
"Package 'sentry_sdk' is missing. Installing it...")
|
||||
print_verbose("Package 'sentry_sdk' is missing. Installing it...")
|
||||
subprocess.check_call(
|
||||
[sys.executable, "-m", "pip", "install", "sentry_sdk"])
|
||||
[sys.executable, "-m", "pip", "install", "sentry_sdk"]
|
||||
)
|
||||
import sentry_sdk
|
||||
sentry_sdk_instance = sentry_sdk
|
||||
sentry_trace_rate = (os.environ.get("SENTRY_API_TRACE_RATE")
|
||||
if "SENTRY_API_TRACE_RATE" in os.environ
|
||||
else "1.0")
|
||||
sentry_trace_rate = (
|
||||
os.environ.get("SENTRY_API_TRACE_RATE")
|
||||
if "SENTRY_API_TRACE_RATE" in os.environ
|
||||
else "1.0"
|
||||
)
|
||||
sentry_sdk_instance.init(
|
||||
dsn=os.environ.get("SENTRY_API_URL"),
|
||||
traces_sample_rate=float(sentry_trace_rate),
|
||||
|
@ -762,10 +757,10 @@ def set_callbacks(callback_list):
|
|||
try:
|
||||
from posthog import Posthog
|
||||
except ImportError:
|
||||
print_verbose(
|
||||
"Package 'posthog' is missing. Installing it...")
|
||||
print_verbose("Package 'posthog' is missing. Installing it...")
|
||||
subprocess.check_call(
|
||||
[sys.executable, "-m", "pip", "install", "posthog"])
|
||||
[sys.executable, "-m", "pip", "install", "posthog"]
|
||||
)
|
||||
from posthog import Posthog
|
||||
posthog = Posthog(
|
||||
project_api_key=os.environ.get("POSTHOG_API_KEY"),
|
||||
|
@ -775,10 +770,10 @@ def set_callbacks(callback_list):
|
|||
try:
|
||||
from slack_bolt import App
|
||||
except ImportError:
|
||||
print_verbose(
|
||||
"Package 'slack_bolt' is missing. Installing it...")
|
||||
print_verbose("Package 'slack_bolt' is missing. Installing it...")
|
||||
subprocess.check_call(
|
||||
[sys.executable, "-m", "pip", "install", "slack_bolt"])
|
||||
[sys.executable, "-m", "pip", "install", "slack_bolt"]
|
||||
)
|
||||
from slack_bolt import App
|
||||
slack_app = App(
|
||||
token=os.environ.get("SLACK_API_TOKEN"),
|
||||
|
@ -809,8 +804,7 @@ def set_callbacks(callback_list):
|
|||
raise e
|
||||
|
||||
|
||||
def handle_failure(exception, traceback_exception, start_time, end_time, args,
|
||||
kwargs):
|
||||
def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs):
|
||||
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient
|
||||
try:
|
||||
# print_verbose(f"handle_failure args: {args}")
|
||||
|
@ -820,7 +814,8 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
|
|||
failure_handler = additional_details.pop("failure_handler", None)
|
||||
|
||||
additional_details["Event_Name"] = additional_details.pop(
|
||||
"failed_event_name", "litellm.failed_query")
|
||||
"failed_event_name", "litellm.failed_query"
|
||||
)
|
||||
print_verbose(f"self.failure_callback: {litellm.failure_callback}")
|
||||
for callback in litellm.failure_callback:
|
||||
try:
|
||||
|
@ -835,8 +830,9 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
|
|||
for detail in additional_details:
|
||||
slack_msg += f"{detail}: {additional_details[detail]}\n"
|
||||
slack_msg += f"Traceback: {traceback_exception}"
|
||||
slack_app.client.chat_postMessage(channel=alerts_channel,
|
||||
text=slack_msg)
|
||||
slack_app.client.chat_postMessage(
|
||||
channel=alerts_channel, text=slack_msg
|
||||
)
|
||||
elif callback == "sentry":
|
||||
capture_exception(exception)
|
||||
elif callback == "posthog":
|
||||
|
@ -855,8 +851,9 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
|
|||
print_verbose(f"ph_obj: {ph_obj}")
|
||||
print_verbose(f"PostHog Event Name: {event_name}")
|
||||
if "user_id" in additional_details:
|
||||
posthog.capture(additional_details["user_id"],
|
||||
event_name, ph_obj)
|
||||
posthog.capture(
|
||||
additional_details["user_id"], event_name, ph_obj
|
||||
)
|
||||
else: # PostHog calls require a unique id to identify a user - https://posthog.com/docs/libraries/python
|
||||
unique_id = str(uuid.uuid4())
|
||||
posthog.capture(unique_id, event_name)
|
||||
|
@ -870,10 +867,10 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
|
|||
"created": time.time(),
|
||||
"error": traceback_exception,
|
||||
"usage": {
|
||||
"prompt_tokens":
|
||||
prompt_token_calculator(model, messages=messages),
|
||||
"completion_tokens":
|
||||
0,
|
||||
"prompt_tokens": prompt_token_calculator(
|
||||
model, messages=messages
|
||||
),
|
||||
"completion_tokens": 0,
|
||||
},
|
||||
}
|
||||
berrispendLogger.log_event(
|
||||
|
@ -892,10 +889,10 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
|
|||
"model": model,
|
||||
"created": time.time(),
|
||||
"usage": {
|
||||
"prompt_tokens":
|
||||
prompt_token_calculator(model, messages=messages),
|
||||
"completion_tokens":
|
||||
0,
|
||||
"prompt_tokens": prompt_token_calculator(
|
||||
model, messages=messages
|
||||
),
|
||||
"completion_tokens": 0,
|
||||
},
|
||||
}
|
||||
aispendLogger.log_event(
|
||||
|
@ -910,10 +907,13 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
|
|||
|
||||
model = args[0] if len(args) > 0 else kwargs["model"]
|
||||
|
||||
input = args[1] if len(args) > 1 else kwargs.get(
|
||||
"messages", kwargs.get("input", None))
|
||||
input = (
|
||||
args[1]
|
||||
if len(args) > 1
|
||||
else kwargs.get("messages", kwargs.get("input", None))
|
||||
)
|
||||
|
||||
type = 'embed' if 'input' in kwargs else 'llm'
|
||||
type = "embed" if "input" in kwargs else "llm"
|
||||
|
||||
llmonitorLogger.log_event(
|
||||
type=type,
|
||||
|
@ -937,10 +937,10 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
|
|||
"created": time.time(),
|
||||
"error": traceback_exception,
|
||||
"usage": {
|
||||
"prompt_tokens":
|
||||
prompt_token_calculator(model, messages=messages),
|
||||
"completion_tokens":
|
||||
0,
|
||||
"prompt_tokens": prompt_token_calculator(
|
||||
model, messages=messages
|
||||
),
|
||||
"completion_tokens": 0,
|
||||
},
|
||||
}
|
||||
supabaseClient.log_event(
|
||||
|
@ -957,16 +957,28 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args,
|
|||
print_verbose("reaches lite_debugger for logging!")
|
||||
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
|
||||
model = args[0] if len(args) > 0 else kwargs["model"]
|
||||
messages = args[1] if len(args) > 1 else kwargs.get("messages", [{"role": "user", "content": ' '.join(kwargs.get("input", ""))}])
|
||||
messages = (
|
||||
args[1]
|
||||
if len(args) > 1
|
||||
else kwargs.get(
|
||||
"messages",
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": " ".join(kwargs.get("input", "")),
|
||||
}
|
||||
],
|
||||
)
|
||||
)
|
||||
result = {
|
||||
"model": model,
|
||||
"created": time.time(),
|
||||
"error": traceback_exception,
|
||||
"usage": {
|
||||
"prompt_tokens":
|
||||
prompt_token_calculator(model, messages=messages),
|
||||
"completion_tokens":
|
||||
0,
|
||||
"prompt_tokens": prompt_token_calculator(
|
||||
model, messages=messages
|
||||
),
|
||||
"completion_tokens": 0,
|
||||
},
|
||||
}
|
||||
liteDebuggerClient.log_event(
|
||||
|
@ -1002,11 +1014,16 @@ def handle_success(args, kwargs, result, start_time, end_time):
|
|||
global heliconeLogger, aispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger
|
||||
try:
|
||||
model = args[0] if len(args) > 0 else kwargs["model"]
|
||||
input = args[1] if len(args) > 1 else kwargs.get("messages", kwargs.get("input", None))
|
||||
input = (
|
||||
args[1]
|
||||
if len(args) > 1
|
||||
else kwargs.get("messages", kwargs.get("input", None))
|
||||
)
|
||||
success_handler = additional_details.pop("success_handler", None)
|
||||
failure_handler = additional_details.pop("failure_handler", None)
|
||||
additional_details["Event_Name"] = additional_details.pop(
|
||||
"successful_event_name", "litellm.succes_query")
|
||||
"successful_event_name", "litellm.succes_query"
|
||||
)
|
||||
for callback in litellm.success_callback:
|
||||
try:
|
||||
if callback == "posthog":
|
||||
|
@ -1015,8 +1032,9 @@ def handle_success(args, kwargs, result, start_time, end_time):
|
|||
ph_obj[detail] = additional_details[detail]
|
||||
event_name = additional_details["Event_Name"]
|
||||
if "user_id" in additional_details:
|
||||
posthog.capture(additional_details["user_id"],
|
||||
event_name, ph_obj)
|
||||
posthog.capture(
|
||||
additional_details["user_id"], event_name, ph_obj
|
||||
)
|
||||
else: # PostHog calls require a unique id to identify a user - https://posthog.com/docs/libraries/python
|
||||
unique_id = str(uuid.uuid4())
|
||||
posthog.capture(unique_id, event_name, ph_obj)
|
||||
|
@ -1025,8 +1043,9 @@ def handle_success(args, kwargs, result, start_time, end_time):
|
|||
slack_msg = ""
|
||||
for detail in additional_details:
|
||||
slack_msg += f"{detail}: {additional_details[detail]}\n"
|
||||
slack_app.client.chat_postMessage(channel=alerts_channel,
|
||||
text=slack_msg)
|
||||
slack_app.client.chat_postMessage(
|
||||
channel=alerts_channel, text=slack_msg
|
||||
)
|
||||
elif callback == "helicone":
|
||||
print_verbose("reaches helicone for logging!")
|
||||
model = args[0] if len(args) > 0 else kwargs["model"]
|
||||
|
@ -1043,11 +1062,14 @@ def handle_success(args, kwargs, result, start_time, end_time):
|
|||
print_verbose("reaches llmonitor for logging!")
|
||||
model = args[0] if len(args) > 0 else kwargs["model"]
|
||||
|
||||
input = args[1] if len(args) > 1 else kwargs.get(
|
||||
"messages", kwargs.get("input", None))
|
||||
input = (
|
||||
args[1]
|
||||
if len(args) > 1
|
||||
else kwargs.get("messages", kwargs.get("input", None))
|
||||
)
|
||||
|
||||
#if contains input, it's 'embedding', otherwise 'llm'
|
||||
type = 'embed' if 'input' in kwargs else 'llm'
|
||||
# if contains input, it's 'embedding', otherwise 'llm'
|
||||
type = "embed" if "input" in kwargs else "llm"
|
||||
|
||||
llmonitorLogger.log_event(
|
||||
type=type,
|
||||
|
@ -1069,7 +1091,6 @@ def handle_success(args, kwargs, result, start_time, end_time):
|
|||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
print_verbose=print_verbose,
|
||||
|
||||
)
|
||||
elif callback == "aispend":
|
||||
print_verbose("reaches aispend for logging!")
|
||||
|
@ -1084,7 +1105,11 @@ def handle_success(args, kwargs, result, start_time, end_time):
|
|||
elif callback == "supabase":
|
||||
print_verbose("reaches supabase for logging!")
|
||||
model = args[0] if len(args) > 0 else kwargs["model"]
|
||||
messages = args[1] if len(args) > 1 else kwargs.get("messages", {"role": "user", "content": ""})
|
||||
messages = (
|
||||
args[1]
|
||||
if len(args) > 1
|
||||
else kwargs.get("messages", {"role": "user", "content": ""})
|
||||
)
|
||||
print(f"supabaseClient: {supabaseClient}")
|
||||
supabaseClient.log_event(
|
||||
model=model,
|
||||
|
@ -1099,7 +1124,19 @@ def handle_success(args, kwargs, result, start_time, end_time):
|
|||
elif callback == "lite_debugger":
|
||||
print_verbose("reaches lite_debugger for logging!")
|
||||
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
|
||||
messages = args[1] if len(args) > 1 else kwargs.get("messages", [{"role": "user", "content": ' '.join(kwargs.get("input", ""))}])
|
||||
messages = (
|
||||
args[1]
|
||||
if len(args) > 1
|
||||
else kwargs.get(
|
||||
"messages",
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": " ".join(kwargs.get("input", "")),
|
||||
}
|
||||
],
|
||||
)
|
||||
)
|
||||
liteDebuggerClient.log_event(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
@ -1129,6 +1166,7 @@ def handle_success(args, kwargs, result, start_time, end_time):
|
|||
)
|
||||
pass
|
||||
|
||||
|
||||
def acreate(*args, **kwargs): ## Thin client to handle the acreate langchain call
|
||||
return litellm.acompletion(*args, **kwargs)
|
||||
|
||||
|
@@ -1170,28 +1208,43 @@ def modify_integration(integration_name, integration_params):
        if "table_name" in integration_params:
            Supabase.supabase_table_name = integration_params["table_name"]


####### [BETA] HOSTED PRODUCT ################ - https://docs.litellm.ai/docs/debugging/hosted_debugging


def get_all_keys(llm_provider=None):
    try:
        global last_fetched_at_keys
        # if user is using hosted product -> instantiate their env with their hosted api keys - refresh every 5 minutes
        print_verbose(f"Reaches get all keys, llm_provider: {llm_provider}")
        user_email = os.getenv("LITELLM_EMAIL") or litellm.email or litellm.token or os.getenv("LITELLM_TOKEN")
        user_email = (
            os.getenv("LITELLM_EMAIL")
            or litellm.email
            or litellm.token
            or os.getenv("LITELLM_TOKEN")
        )
        if user_email:
            time_delta = 0
            if last_fetched_at_keys != None:
                current_time = time.time()
                time_delta = current_time - last_fetched_at_keys
            if time_delta > 300 or last_fetched_at_keys == None or llm_provider: # if the llm provider is passed in , assume this happening due to an AuthError for that provider
            if (
                time_delta > 300 or last_fetched_at_keys == None or llm_provider
            ): # if the llm provider is passed in , assume this happening due to an AuthError for that provider
                # make the api call
                last_fetched_at = time.time()
                print_verbose(f"last_fetched_at: {last_fetched_at}")
                response = requests.post(url="http://api.litellm.ai/get_all_keys", headers={"content-type": "application/json"}, data=json.dumps({"user_email": user_email}))
                response = requests.post(
                    url="http://api.litellm.ai/get_all_keys",
                    headers={"content-type": "application/json"},
                    data=json.dumps({"user_email": user_email}),
                )
                print_verbose(f"get model key response: {response.text}")
                data = response.json()
                # update model list
                for key, value in data["model_keys"].items(): # follows the LITELLM API KEY format - <UPPERCASE_PROVIDER_NAME>_API_KEY - e.g. HUGGINGFACE_API_KEY
                for key, value in data[
                    "model_keys"
                ].items(): # follows the LITELLM API KEY format - <UPPERCASE_PROVIDER_NAME>_API_KEY - e.g. HUGGINGFACE_API_KEY
                    os.environ[key] = value
                # set model alias map
                for model_alias, value in data["model_alias_map"].items():
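The loop above applies the <UPPERCASE_PROVIDER_NAME>_API_KEY convention called out in the inline comment. A small illustration of what that looks like once a model_keys payload comes back (the payload contents here are invented for illustration only):

import os

# hypothetical shape of data["model_keys"] returned by the hosted endpoint
model_keys = {
    "OPENAI_API_KEY": "sk-...",
    "HUGGINGFACE_API_KEY": "hf_...",
}

for key, value in model_keys.items():
    # same convention as get_all_keys: <UPPERCASE_PROVIDER_NAME>_API_KEY
    os.environ[key] = value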
@@ -1200,19 +1253,31 @@ def get_all_keys(llm_provider=None):
            return None
        return None
    except:
        print_verbose(f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}")
        print_verbose(
            f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}"
        )
        pass


def get_model_list():
    global last_fetched_at
    try:
        # if user is using hosted product -> get their updated model list
        user_email = os.getenv("LITELLM_EMAIL") or litellm.email or litellm.token or os.getenv("LITELLM_TOKEN")
        user_email = (
            os.getenv("LITELLM_EMAIL")
            or litellm.email
            or litellm.token
            or os.getenv("LITELLM_TOKEN")
        )
        if user_email:
            # make the api call
            last_fetched_at = time.time()
            print(f"last_fetched_at: {last_fetched_at}")
            response = requests.post(url="http://api.litellm.ai/get_model_list", headers={"content-type": "application/json"}, data=json.dumps({"user_email": user_email}))
            response = requests.post(
                url="http://api.litellm.ai/get_model_list",
                headers={"content-type": "application/json"},
                data=json.dumps({"user_email": user_email}),
            )
            print_verbose(f"get_model_list response: {response.text}")
            data = response.json()
            # update model list
@@ -1224,12 +1289,14 @@ def get_model_list():
                if f"{item.upper()}_API_KEY" not in os.environ:
                    missing_llm_provider = item
                    break
            # update environment - if required
            threading.Thread(target=get_all_keys, args=(missing_llm_provider)).start()
            return model_list
        return [] # return empty list by default
        return []  # return empty list by default
    except:
        print_verbose(f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}")
        print_verbose(
            f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}"
        )
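A defensive caller-side sketch for get_model_list, given that it returns [] when no hosted account is configured (the import path assumes the function stays in litellm.utils; the fallback model is a placeholder chosen for illustration):

from litellm.utils import get_model_list

model_list = get_model_list()  # [] unless LITELLM_EMAIL / LITELLM_TOKEN is configured
if not model_list:
    model_list = ["gpt-3.5-turbo"]  # placeholder fallback for illustration
print(model_list)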
####### EXCEPTION MAPPING ################


@@ -1253,36 +1320,33 @@ def exception_type(model, original_exception, custom_llm_provider):
                exception_type = ""
            if "claude" in model: # one of the anthropics
                if hasattr(original_exception, "status_code"):
                    print_verbose(
                        f"status_code: {original_exception.status_code}")
                    print_verbose(f"status_code: {original_exception.status_code}")
                    if original_exception.status_code == 401:
                        exception_mapping_worked = True
                        raise AuthenticationError(
                            message=
                            f"AnthropicException - {original_exception.message}",
                            message=f"AnthropicException - {original_exception.message}",
                            llm_provider="anthropic",
                        )
                    elif original_exception.status_code == 400:
                        exception_mapping_worked = True
                        raise InvalidRequestError(
                            message=
                            f"AnthropicException - {original_exception.message}",
                            message=f"AnthropicException - {original_exception.message}",
                            model=model,
                            llm_provider="anthropic",
                        )
                    elif original_exception.status_code == 429:
                        exception_mapping_worked = True
                        raise RateLimitError(
                            message=
                            f"AnthropicException - {original_exception.message}",
                            message=f"AnthropicException - {original_exception.message}",
                            llm_provider="anthropic",
                        )
                elif ("Could not resolve authentication method. Expected either api_key or auth_token to be set."
                      in error_str):
                elif (
                    "Could not resolve authentication method. Expected either api_key or auth_token to be set."
                    in error_str
                ):
                    exception_mapping_worked = True
                    raise AuthenticationError(
                        message=
                        f"AnthropicException - {original_exception.message}",
                        message=f"AnthropicException - {original_exception.message}",
                        llm_provider="anthropic",
                    )
            elif "replicate" in model:
@@ -1306,36 +1370,35 @@ def exception_type(model, original_exception, custom_llm_provider):
                        llm_provider="replicate",
                    )
                elif (
                    exception_type == "ReplicateError"
                ): # ReplicateError implies an error on Replicate server side, not user side
                    raise ServiceUnavailableError(
                        message=f"ReplicateException - {error_str}",
                        llm_provider="replicate",
                    )
            elif model == "command-nightly": # Cohere
                if ("invalid api token" in error_str
                        or "No API key provided." in error_str):
                if (
                    "invalid api token" in error_str
                    or "No API key provided." in error_str
                ):
                    exception_mapping_worked = True
                    raise AuthenticationError(
                        message=
                        f"CohereException - {original_exception.message}",
                        message=f"CohereException - {original_exception.message}",
                        llm_provider="cohere",
                    )
                elif "too many tokens" in error_str:
                    exception_mapping_worked = True
                    raise InvalidRequestError(
                        message=
                        f"CohereException - {original_exception.message}",
                        message=f"CohereException - {original_exception.message}",
                        model=model,
                        llm_provider="cohere",
                    )
                elif (
                    "CohereConnectionError" in exception_type
                ): # cohere seems to fire these errors when we load test it (1k+ messages / min)
                    exception_mapping_worked = True
                    raise RateLimitError(
                        message=
                        f"CohereException - {original_exception.message}",
                        message=f"CohereException - {original_exception.message}",
                        llm_provider="cohere",
                    )
            elif custom_llm_provider == "huggingface":
@@ -1343,23 +1406,20 @@ def exception_type(model, original_exception, custom_llm_provider):
                    if original_exception.status_code == 401:
                        exception_mapping_worked = True
                        raise AuthenticationError(
                            message=
                            f"HuggingfaceException - {original_exception.message}",
                            message=f"HuggingfaceException - {original_exception.message}",
                            llm_provider="huggingface",
                        )
                    elif original_exception.status_code == 400:
                        exception_mapping_worked = True
                        raise InvalidRequestError(
                            message=
                            f"HuggingfaceException - {original_exception.message}",
                            message=f"HuggingfaceException - {original_exception.message}",
                            model=model,
                            llm_provider="huggingface",
                        )
                    elif original_exception.status_code == 429:
                        exception_mapping_worked = True
                        raise RateLimitError(
                            message=
                            f"HuggingfaceException - {original_exception.message}",
                            message=f"HuggingfaceException - {original_exception.message}",
                            llm_provider="huggingface",
                        )
        raise original_exception # base case - return the original exception
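The point of the mapping above is that callers can catch provider-agnostic exception classes instead of each SDK's own errors. A hedged sketch of that calling pattern (assuming these exception classes are exported at the package level; the model and prompt are placeholders):

import litellm
from litellm import (
    AuthenticationError,
    InvalidRequestError,
    RateLimitError,
    ServiceUnavailableError,
)

try:
    response = litellm.completion(
        model="claude-2",
        messages=[{"role": "user", "content": "ping"}],
    )
except AuthenticationError:
    print("bad or missing provider key")  # e.g. mapped from a 401 / auth error string
except RateLimitError:
    print("provider rate limit hit, back off and retry")  # e.g. mapped from a 429
except (InvalidRequestError, ServiceUnavailableError) as err:
    print(f"request failed: {err}")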
@@ -1375,8 +1435,10 @@ def exception_type(model, original_exception, custom_llm_provider):
            },
            exception=e,
        )
        ## AUTH ERROR
        if isinstance(e, AuthenticationError) and (litellm.email or "LITELLM_EMAIL" in os.environ):
        ## AUTH ERROR
        if isinstance(e, AuthenticationError) and (
            litellm.email or "LITELLM_EMAIL" in os.environ
        ):
            threading.Thread(target=get_all_keys, args=(e.llm_provider,)).start()
        if exception_mapping_worked:
            raise e
@@ -1391,7 +1453,8 @@ def safe_crash_reporting(model=None, exception=None, custom_llm_provider=None):
        "exception": str(exception),
        "custom_llm_provider": custom_llm_provider,
    }
    threading.Thread(target=litellm_telemetry, args=(data, )).start()
    threading.Thread(target=litellm_telemetry, args=(data,)).start()


def get_or_generate_uuid():
    uuid_file = "litellm_uuid.txt"
@@ -1445,8 +1508,7 @@ def get_secret(secret_name):
        # TODO: check which secret manager is being used
        # currently only supports Infisical
        try:
            secret = litellm.secret_manager_client.get_secret(
                secret_name).secret_value
            secret = litellm.secret_manager_client.get_secret(secret_name).secret_value
        except:
            secret = None
        return secret
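For context, get_secret only does anything once a secret-manager client has been attached. A sketch of that wiring, assuming the Infisical Python SDK (the import, constructor argument, and token value are assumptions based on the "only supports Infisical" comment, not something this diff guarantees):

import litellm
from infisical import InfisicalClient  # assumed SDK; verify against the Infisical docs

from litellm.utils import get_secret  # import path per this file

# get_secret() consults this client and falls back to None on any error
litellm.secret_manager_client = InfisicalClient(token="st.xxxx")  # placeholder token

openai_key = get_secret("OPENAI_API_KEY")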
@@ -1460,7 +1522,6 @@ def get_secret(secret_name):
# wraps the completion stream to return the correct format for the model
# replicate/anthropic/cohere
class CustomStreamWrapper:

    def __init__(self, completion_stream, model, custom_llm_provider=None):
        self.model = model
        self.custom_llm_provider = custom_llm_provider
@@ -1509,8 +1570,9 @@ class CustomStreamWrapper:
        elif self.model == "replicate":
            chunk = next(self.completion_stream)
            completion_obj["content"] = chunk
        elif (self.custom_llm_provider and self.custom_llm_provider == "together_ai") or ("togethercomputer
                                                                                          in self.model):
        elif (
            self.custom_llm_provider and self.custom_llm_provider == "together_ai"
        ) or ("togethercomputer" in self.model):
            chunk = next(self.completion_stream)
            text_data = self.handle_together_ai_chunk(chunk)
            if text_data == "":
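From the caller's side, this wrapper is what a streaming completion iterates over. A minimal sketch (assuming stream=True returns a CustomStreamWrapper-style iterator of delta chunks; the Together AI model name is a placeholder and needs the matching API key in the environment):

import litellm

response = litellm.completion(
    model="togethercomputer/llama-2-70b-chat",
    messages=[{"role": "user", "content": "tell me a short joke"}],
    stream=True,
)

# each chunk follows the {"choices": [{"delta": {"content": ...}}]} shape produced by the wrapper
for chunk in response:
    delta = chunk["choices"][0]["delta"]
    print(delta.get("content", ""), end="")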
@@ -1545,9 +1607,9 @@ def read_config_args(config_path):
########## ollama implementation ############################


async def get_ollama_response_stream(api_base="http://localhost:11434",
                                     model="llama2",
                                     prompt="Why is the sky blue?"):
async def get_ollama_response_stream(
    api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"
):
    session = aiohttp.ClientSession()
    url = f"{api_base}/api/generate"
    data = {
@@ -1570,11 +1632,7 @@ async def get_ollama_response_stream(api_base="http://localhost:11434",
                                    "content": "",
                                }
                                completion_obj["content"] = j["response"]
                                yield {
                                    "choices": [{
                                        "delta": completion_obj
                                    }]
                                }
                                yield {"choices": [{"delta": completion_obj}]}
                                # self.responses.append(j["response"])
                                # yield "blank"
                except Exception as e:
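A hedged sketch of consuming this async generator (assumes a local Ollama server on the default port, the aiohttp dependency used above, and that the generator stays importable from litellm.utils; the model tag and prompt are placeholders):

import asyncio

from litellm.utils import get_ollama_response_stream  # adjust if re-exported elsewhere

async def main():
    async for chunk in get_ollama_response_stream(
        api_base="http://localhost:11434",
        model="llama2",
        prompt="Why is the sky blue?",
    ):
        # chunks mirror the {"choices": [{"delta": {...}}]} shape yielded above
        print(chunk["choices"][0]["delta"]["content"], end="", flush=True)

asyncio.run(main())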