Merge branch 'BerriAI:main' into main

commit 2309f2f407
Zakhar Kogan authored on 2023-08-20 11:21:33 +03:00, committed by GitHub
22 changed files with 246 additions and 317 deletions


@@ -35,7 +35,7 @@ messages = [{ "content": "Hello, how are you?","role": "user"}]
 response = completion(model="gpt-3.5-turbo", messages=messages)
 # cohere call
-response = completion("command-nightly", messages)
+response = completion(model="command-nightly", messages=messages)
 ```
 Code Sample: [Getting Started Notebook](https://colab.research.google.com/drive/1gR3pY-JzDZahzpVdbGBtrNGDBmzUNJaJ?usp=sharing)


@@ -1,4 +1,4 @@
-# *🚅 litellm*
+# litellm
 [![PyPI Version](https://img.shields.io/pypi/v/litellm.svg)](https://pypi.org/project/litellm/)
 [![PyPI Version](https://img.shields.io/badge/stable%20version-v0.1.345-blue?color=green&link=https://pypi.org/project/litellm/0.1.1/)](https://pypi.org/project/litellm/0.1.1/)
 [![CircleCI](https://dl.circleci.com/status-badge/img/gh/BerriAI/litellm/tree/main.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/gh/BerriAI/litellm/tree/main)


@@ -22,11 +22,13 @@ create table
     messages json null default '{}'::json,
     response json null default '{}'::json,
     end_user text null default ''::text,
+    status text null default ''::text,
     error json null default '{}'::json,
     response_time real null default '0'::real,
     total_cost real null,
     additional_details json null default '{}'::json,
-    constraint request_logs_pkey primary key (id)
+    litellm_call_id text unique,
+    primary key (id)
 ) tablespace pg_default;
 ```
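The two new columns give each request a stable key (`litellm_call_id`, unique) and a lifecycle (`status`): one row is inserted when a call starts and updated in place when it finishes. A minimal sketch of that flow with supabase-py — the table name, env vars, and payloads here are illustrative, not taken from this commit:

```python
import os
import uuid

from supabase import create_client  # pip install supabase

supabase = create_client(os.environ["SUPABASE_URL"], os.environ["SUPABASE_KEY"])
call_id = str(uuid.uuid4())

# At input time: insert a row marked "initiated", keyed by litellm_call_id.
supabase.table("request_logs").insert(
    {"litellm_call_id": call_id, "status": "initiated",
     "messages": [{"role": "user", "content": "Hi"}]}
).execute()

# At completion time: upserting on the unique litellm_call_id column updates the
# same row to "success" (or "failure") instead of appending a duplicate entry.
supabase.table("request_logs").upsert(
    {"litellm_call_id": call_id, "status": "success",
     "response": {"content": "Hello!"}},
    on_conflict="litellm_call_id",
).execute()
```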


@@ -8,7 +8,7 @@ const darkCodeTheme = require('prism-react-renderer/themes/dracula');
 const config = {
   title: 'liteLLM',
   tagline: 'Simplify LLM API Calls',
-  favicon: 'static/img/favicon.ico',
+  favicon: '/img/favicon.ico',
 
   // Set the production url of your site here
   url: 'https://litellm.vercel.app/',


@@ -1,6 +1,6 @@
 import threading
 from typing import Callable, List, Optional
+input_callback: List[str] = []
 success_callback: List[str] = []
 failure_callback: List[str] = []
 set_verbose = False
@@ -216,7 +216,6 @@ from .timeout import timeout
 from .testing import *
 from .utils import (
     client,
-    logging,
     exception_type,
     get_optional_params,
     modify_integration,
@@ -224,6 +223,7 @@ from .utils import (
     cost_per_token,
     completion_cost,
     get_litellm_params,
+    Logging
 )
 from .main import *  # type: ignore
 from .integrations import *


@@ -144,6 +144,28 @@ class Supabase:
         )
         return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
 
+    def input_log_event(self, model, messages, end_user, litellm_call_id, print_verbose):
+        try:
+            print_verbose(
+                f"Supabase Logging - Enters input logging function for model {model}"
+            )
+            supabase_data_obj = {
+                "model": model,
+                "messages": messages,
+                "end_user": end_user,
+                "status": "initiated",
+                "litellm_call_id": litellm_call_id
+            }
+            data, count = (
+                self.supabase_client.table(self.supabase_table_name)
+                .insert(supabase_data_obj)
+                .execute()
+            )
+            print(f"data: {data}")
+            pass
+        except:
+            pass
+
     def log_event(
         self,
         model,
@@ -152,6 +174,7 @@ class Supabase:
         response_obj,
         start_time,
         end_time,
+        litellm_call_id,
         print_verbose,
     ):
         try:
@@ -176,16 +199,20 @@
                     "messages": messages,
                     "response": response_obj["choices"][0]["message"]["content"],
                     "end_user": end_user,
+                    "litellm_call_id": litellm_call_id,
+                    "status": "success"
                 }
                 print_verbose(
                     f"Supabase Logging - final data object: {supabase_data_obj}"
                 )
                 data, count = (
                     self.supabase_client.table(self.supabase_table_name)
-                    .insert(supabase_data_obj)
+                    .upsert(supabase_data_obj)
                     .execute()
                 )
             elif "error" in response_obj:
+                if "Unable to map your input to a model." in response_obj["error"]:
+                    total_cost = 0
                 supabase_data_obj = {
                     "response_time": response_time,
                     "model": response_obj["model"],
@@ -193,13 +220,15 @@
                     "messages": messages,
                     "error": response_obj["error"],
                     "end_user": end_user,
+                    "litellm_call_id": litellm_call_id,
+                    "status": "failure"
                 }
                 print_verbose(
                     f"Supabase Logging - final data object: {supabase_data_obj}"
                 )
                 data, count = (
                     self.supabase_client.table(self.supabase_table_name)
-                    .insert(supabase_data_obj)
+                    .upsert(supabase_data_obj)
                     .execute()
                 )
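A usage sketch for wiring this integration in, mirroring the (commented-out) test file later in this diff — the table name and credentials here are illustrative:

```python
import os

import litellm
from litellm import completion

os.environ["SUPABASE_URL"] = "https://xyz.supabase.co"  # illustrative
os.environ["SUPABASE_KEY"] = "public-anon-key"          # illustrative

litellm.input_callback = ["supabase"]    # writes the "initiated" row via input_log_event
litellm.success_callback = ["supabase"]  # upserts the same row with status "success"
litellm.failure_callback = ["supabase"]  # upserts the same row with status "failure"
litellm.modify_integration("supabase", {"table_name": "request_logs"})

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
```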


@@ -1,7 +1,6 @@
 import os, json
 from enum import Enum
 import requests
-from litellm import logging
 import time
 from typing import Callable
 from litellm.utils import ModelResponse
@@ -22,11 +21,12 @@ class AnthropicError(Exception):
 class AnthropicLLM:
-    def __init__(self, encoding, default_max_tokens_to_sample, api_key=None):
+    def __init__(self, encoding, default_max_tokens_to_sample, logging_obj, api_key=None):
         self.encoding = encoding
         self.default_max_tokens_to_sample = default_max_tokens_to_sample
         self.completion_url = "https://api.anthropic.com/v1/complete"
         self.api_key = api_key
+        self.logging_obj = logging_obj
         self.validate_environment(api_key=api_key)
 
     def validate_environment(
@@ -84,15 +84,7 @@
         }
 
         ## LOGGING
-        logging(
-            model=model,
-            input=prompt,
-            additional_args={
-                "litellm_params": litellm_params,
-                "optional_params": optional_params,
-            },
-            logger_fn=logger_fn,
-        )
+        self.logging_obj.pre_call(input=prompt, api_key=self.api_key, additional_args={"complete_input_dict": data})
         ## COMPLETION CALL
         response = requests.post(
             self.completion_url, headers=self.headers, data=json.dumps(data)
@@ -101,16 +93,7 @@
             return response.iter_lines()
         else:
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                additional_args={
-                    "litellm_params": litellm_params,
-                    "optional_params": optional_params,
-                    "original_response": response.text,
-                },
-                logger_fn=logger_fn,
-            )
+            self.logging_obj.post_call(input=prompt, api_key=self.api_key, original_response=response.text, additional_args={"complete_input_dict": data})
             print_verbose(f"raw model_response: {response.text}")
             ## RESPONSE OBJECT
             completion_response = response.json()
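The provider classes now receive a `logging_obj` instead of importing a module-level `logging` function; anything exposing `pre_call`/`post_call` with these signatures satisfies the contract. A duck-typed stand-in (illustrative, not the library's `Logging` class) makes the hook points visible:

```python
class StubLogging:
    """Minimal stand-in for litellm's Logging object (illustrative only)."""

    def pre_call(self, input, api_key, additional_args={}):
        # receives the rendered prompt and the exact request payload ("complete_input_dict")
        print(f"pre_call: input={input!r}, args={additional_args}")

    def post_call(self, input, api_key, original_response, additional_args={}):
        # receives the raw provider response text before it is parsed
        print(f"post_call: response={original_response!r}")


# e.g. AnthropicLLM(encoding=enc, default_max_tokens_to_sample=256,
#                   logging_obj=StubLogging(), api_key=os.environ.get("ANTHROPIC_API_KEY"))
```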


@@ -2,7 +2,6 @@
 import os, json
 from enum import Enum
 import requests
-from litellm import logging
 import time
 from typing import Callable
 from litellm.utils import ModelResponse
@@ -19,8 +18,9 @@ class HuggingfaceError(Exception):
 class HuggingfaceRestAPILLM:
-    def __init__(self, encoding, api_key=None) -> None:
+    def __init__(self, encoding, logging_obj, api_key=None) -> None:
         self.encoding = encoding
+        self.logging_obj = logging_obj
         self.validate_environment(api_key=api_key)
 
     def validate_environment(
@@ -74,18 +74,10 @@
                 optional_params["max_new_tokens"] = value
         data = {
             "inputs": prompt,
-            # "parameters": optional_params
+            "parameters": optional_params
         }
         ## LOGGING
-        logging(
-            model=model,
-            input=prompt,
-            additional_args={
-                "litellm_params": litellm_params,
-                "optional_params": optional_params,
-            },
-            logger_fn=logger_fn,
-        )
+        self.logging_obj.pre_call(input=prompt, api_key=self.api_key, additional_args={"complete_input_dict": data})
         ## COMPLETION CALL
         response = requests.post(
             completion_url, headers=self.headers, data=json.dumps(data)
@@ -94,17 +86,7 @@
             return response.iter_lines()
         else:
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                additional_args={
-                    "litellm_params": litellm_params,
-                    "optional_params": optional_params,
-                    "original_response": response.text,
-                },
-                logger_fn=logger_fn,
-            )
-            print_verbose(f"raw model_response: {response.text}")
+            self.logging_obj.post_call(input=prompt, api_key=self.api_key, original_response=response.text, additional_args={"complete_input_dict": data})
             ## RESPONSE OBJECT
             completion_response = response.json()
             print_verbose(f"response: {completion_response}")


@@ -6,11 +6,11 @@ from copy import deepcopy
 import litellm
 from litellm import (  # type: ignore
     client,
-    logging,
     exception_type,
     timeout,
     get_optional_params,
     get_litellm_params,
+    Logging
 )
 from litellm.utils import (
     get_secret,
@@ -85,6 +85,7 @@ def completion(
     azure=False,
     custom_llm_provider=None,
     custom_api_base=None,
+    litellm_call_id=None,
     # model specific optional params
     # used by text-bison only
     top_k=40,
@@ -94,6 +95,11 @@
         model_response = ModelResponse()
         if azure:  # this flag is deprecated, remove once notebooks are also updated.
             custom_llm_provider = "azure"
+        elif model.split("/", 1)[0] in litellm.provider_list:  # allow custom provider to be passed in via the model name "azure/chatgpt-test"
+            custom_llm_provider = model.split("/", 1)[0]
+            model = model.split("/", 1)[1]
+            if "replicate" == custom_llm_provider and "/" not in model:  # handle the "replicate/llama2..." edge-case
+                model = custom_llm_provider + "/" + model
         args = locals()
         # check if user passed in any of the OpenAI optional params
         optional_params = get_optional_params(
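The new `elif` branch above splits a provider-prefixed model string once on `/`. A standalone illustration of the three cases, including the replicate re-join (the `provider_list` contents here are assumed):

```python
def split_provider(model, provider_list=("azure", "replicate", "together_ai")):
    # Mirrors the hunk above: "azure/chatgpt-test" -> ("azure", "chatgpt-test")
    custom_llm_provider = None
    if model.split("/", 1)[0] in provider_list:
        custom_llm_provider, model = model.split("/", 1)
        if custom_llm_provider == "replicate" and "/" not in model:
            # replicate model ids contain "/" themselves, so restore the prefix
            model = custom_llm_provider + "/" + model
    return custom_llm_provider, model

assert split_provider("azure/chatgpt-test") == ("azure", "chatgpt-test")
assert split_provider("together_ai/togethercomputer/llama-2-70b-chat") == (
    "together_ai", "togethercomputer/llama-2-70b-chat")
assert split_provider("replicate/llama2") == ("replicate", "replicate/llama2")
```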
@@ -124,8 +130,9 @@
             verbose=verbose,
             custom_llm_provider=custom_llm_provider,
             custom_api_base=custom_api_base,
+            litellm_call_id=litellm_call_id
         )
+        logging = Logging(model=model, messages=messages, optional_params=optional_params, litellm_params=litellm_params)
         if custom_llm_provider == "azure":
             # azure configs
             openai.api_type = "azure"
@@ -139,16 +146,14 @@
                 if litellm.api_version is not None
                 else get_secret("AZURE_API_VERSION")
             )
+            if not api_key and litellm.azure_key:
+                api_key = litellm.azure_key
+            elif not api_key and get_secret("AZURE_API_KEY"):
+                api_key = get_secret("AZURE_API_KEY")
             # set key
-            openai.api_key = api_key or litellm.azure_key or get_secret("AZURE_API_KEY")
+            openai.api_key = api_key
             ## LOGGING
-            logging(
-                model=model,
-                input=messages,
-                additional_args=optional_params,
-                custom_llm_provider=custom_llm_provider,
-                logger_fn=logger_fn,
-            )
+            logging.pre_call(input=messages, api_key=openai.api_key, additional_args={"headers": litellm.headers, "api_version": openai.api_version, "api_base": openai.api_base})
             ## COMPLETION CALL
             if litellm.headers:
                 response = openai.ChatCompletion.create(
@@ -161,6 +166,8 @@
                 response = openai.ChatCompletion.create(
                     model=model, messages=messages, **optional_params
                 )
+            ## LOGGING
+            logging.post_call(input=messages, api_key=openai.api_key, original_response=response, additional_args={"headers": litellm.headers, "api_version": openai.api_version, "api_base": openai.api_base})
         elif (
             model in litellm.open_ai_chat_completion_models
             or custom_llm_provider == "custom_openai"
@@ -177,18 +184,15 @@
             if litellm.organization:
                 openai.organization = litellm.organization
             # set API KEY
-            openai.api_key = (
-                api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
-            )
+            if not api_key and litellm.openai_key:
+                api_key = litellm.openai_key
+            elif not api_key and get_secret("AZURE_API_KEY"):
+                api_key = get_secret("OPENAI_API_KEY")
+            openai.api_key = api_key
             ## LOGGING
-            logging(
-                model=model,
-                input=messages,
-                additional_args=args,
-                custom_llm_provider=custom_llm_provider,
-                logger_fn=logger_fn,
-            )
+            logging.pre_call(input=messages, api_key=api_key, additional_args={"headers": litellm.headers, "api_base": api_base})
             ## COMPLETION CALL
             if litellm.headers:
                 response = openai.ChatCompletion.create(
@@ -201,6 +205,8 @@
                 response = openai.ChatCompletion.create(
                     model=model, messages=messages, **optional_params
                 )
+            ## LOGGING
+            logging.post_call(input=messages, api_key=api_key, original_response=response, additional_args={"headers": litellm.headers})
         elif model in litellm.open_ai_text_completion_models:
             openai.api_type = "openai"
             openai.api_base = (
@@ -209,20 +215,19 @@
                 else "https://api.openai.com/v1"
             )
             openai.api_version = None
-            openai.api_key = (
-                api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
-            )
+            # set API KEY
+            if not api_key and litellm.openai_key:
+                api_key = litellm.openai_key
+            elif not api_key and get_secret("AZURE_API_KEY"):
+                api_key = get_secret("OPENAI_API_KEY")
+            openai.api_key = api_key
             if litellm.organization:
                 openai.organization = litellm.organization
             prompt = " ".join([message["content"] for message in messages])
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                additional_args=optional_params,
-                custom_llm_provider=custom_llm_provider,
-                logger_fn=logger_fn,
-            )
+            logging.pre_call(input=prompt, api_key=api_key, additional_args={"openai_organization": litellm.organization, "headers": litellm.headers, "api_base": openai.api_base, "api_type": openai.api_type})
             ## COMPLETION CALL
             if litellm.headers:
                 response = openai.Completion.create(
@@ -232,19 +237,10 @@
                 )
             else:
                 response = openai.Completion.create(model=model, prompt=prompt)
-            completion_response = response["choices"][0]["text"]
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                custom_llm_provider=custom_llm_provider,
-                additional_args={
-                    "max_tokens": max_tokens,
-                    "original_response": completion_response,
-                },
-                logger_fn=logger_fn,
-            )
+            logging.post_call(input=prompt, api_key=api_key, original_response=response, additional_args={"openai_organization": litellm.organization, "headers": litellm.headers, "api_base": openai.api_base, "api_type": openai.api_type})
             ## RESPONSE OBJECT
+            completion_response = response["choices"][0]["text"]
             model_response["choices"][0]["message"]["content"] = completion_response
             model_response["created"] = response["created"]
             model_response["model"] = model
@@ -273,13 +269,7 @@
                 input["max_length"] = max_tokens  # for t5 models
                 input["max_new_tokens"] = max_tokens  # for llama2 models
             ## LOGGING
-            logging(
-                model=model,
-                input=input,
-                custom_llm_provider=custom_llm_provider,
-                additional_args={"max_tokens": max_tokens},
-                logger_fn=logger_fn,
-            )
+            logging.pre_call(input=prompt, api_key=replicate_key, additional_args={"complete_input_dict": input, "max_tokens": max_tokens})
             ## COMPLETION CALL
             output = replicate.run(model, input=input)
             if "stream" in optional_params and optional_params["stream"] == True:
@@ -292,16 +282,8 @@
                     response += item
                 completion_response = response
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                custom_llm_provider=custom_llm_provider,
-                additional_args={
-                    "max_tokens": max_tokens,
-                    "original_response": completion_response,
-                },
-                logger_fn=logger_fn,
-            )
+            logging.post_call(input=prompt, api_key=replicate_key, original_response=completion_response, additional_args={"complete_input_dict": input, "max_tokens": max_tokens})
+            ## USAGE
             prompt_tokens = len(encoding.encode(prompt))
             completion_tokens = len(encoding.encode(completion_response))
             ## RESPONSE OBJECT
@@ -322,6 +304,7 @@
                 encoding=encoding,
                 default_max_tokens_to_sample=litellm.max_tokens,
                 api_key=anthropic_key,
+                logging_obj=logging,  # model call logging done inside the class, as we may need to modify I/O to fit anthropic's requirements
             )
             model_response = anthropic_client.completion(
                 model=model,
@@ -357,13 +340,7 @@
                 "OR_API_KEY"
             )
             ## LOGGING
-            logging(
-                model=model,
-                input=messages,
-                additional_args=optional_params,
-                custom_llm_provider=custom_llm_provider,
-                logger_fn=logger_fn,
-            )
+            logging.pre_call(input=messages, api_key=openai.api_key)
             ## COMPLETION CALL
             if litellm.headers:
                 response = openai.ChatCompletion.create(
@@ -390,6 +367,8 @@
                     },
                     **optional_params,
                 )
+            ## LOGGING
+            logging.post_call(input=messages, api_key=openai.api_key, original_response=response)
         elif model in litellm.cohere_models:
             # import cohere/if it fails then pip install cohere
             install_and_import("cohere")
@@ -404,31 +383,17 @@
             co = cohere.Client(cohere_key)
             prompt = " ".join([message["content"] for message in messages])
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                custom_llm_provider=custom_llm_provider,
-                logger_fn=logger_fn,
-            )
+            logging.pre_call(input=prompt, api_key=cohere_key)
             ## COMPLETION CALL
             response = co.generate(model=model, prompt=prompt, **optional_params)
             if "stream" in optional_params and optional_params["stream"] == True:
                 # don't try to access stream object,
                 response = CustomStreamWrapper(response, model)
                 return response
-            completion_response = response[0].text
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                custom_llm_provider=custom_llm_provider,
-                additional_args={
-                    "max_tokens": max_tokens,
-                    "original_response": completion_response,
-                },
-                logger_fn=logger_fn,
-            )
+            logging.post_call(input=prompt, api_key=cohere_key, original_response=response)
+            ## USAGE
+            completion_response = response[0].text
             prompt_tokens = len(encoding.encode(prompt))
             completion_tokens = len(encoding.encode(completion_response))
             ## RESPONSE OBJECT
@@ -452,7 +417,7 @@
                 or os.environ.get("HUGGINGFACE_API_KEY")
             )
             huggingface_client = HuggingfaceRestAPILLM(
-                encoding=encoding, api_key=huggingface_key
+                encoding=encoding, api_key=huggingface_key, logging_obj=logging
             )
             model_response = huggingface_client.completion(
                 model=model,
@@ -487,12 +452,7 @@
             )  # TODO: Add chat support for together AI
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                custom_llm_provider=custom_llm_provider,
-                logger_fn=logger_fn,
-            )
+            logging.pre_call(input=prompt, api_key=TOGETHER_AI_TOKEN)
             if stream == True:
                 return together_ai_completion_streaming(
                     {
@@ -514,17 +474,7 @@
                 headers=headers,
             )
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                custom_llm_provider=custom_llm_provider,
-                additional_args={
-                    "max_tokens": max_tokens,
-                    "original_response": res.text,
-                },
-                logger_fn=logger_fn,
-            )
+            logging.post_call(input=prompt, api_key=TOGETHER_AI_TOKEN, original_response=res.text)
             # make this safe for reading, if output does not exist raise an error
             json_response = res.json()
             if "output" not in json_response:
@@ -557,16 +507,7 @@
             prompt = " ".join([message["content"] for message in messages])
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                custom_llm_provider=custom_llm_provider,
-                additional_args={
-                    "litellm_params": litellm_params,
-                    "optional_params": optional_params,
-                },
-                logger_fn=logger_fn,
-            )
+            logging.pre_call(input=prompt, api_key=None)
 
             chat_model = ChatModel.from_pretrained(model)
@@ -574,16 +515,7 @@
             completion_response = chat.send_message(prompt, **optional_params)
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                custom_llm_provider=custom_llm_provider,
-                additional_args={
-                    "max_tokens": max_tokens,
-                    "original_response": completion_response,
-                },
-                logger_fn=logger_fn,
-            )
+            logging.post_call(input=prompt, api_key=None, original_response=completion_response)
 
             ## RESPONSE OBJECT
             model_response["choices"][0]["message"]["content"] = completion_response
@@ -602,27 +534,13 @@
             prompt = " ".join([message["content"] for message in messages])
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                custom_llm_provider=custom_llm_provider,
-                logger_fn=logger_fn,
-            )
+            logging.pre_call(input=prompt, api_key=None)
             vertex_model = TextGenerationModel.from_pretrained(model)
             completion_response = vertex_model.predict(prompt, **optional_params)
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                custom_llm_provider=custom_llm_provider,
-                additional_args={
-                    "max_tokens": max_tokens,
-                    "original_response": completion_response,
-                },
-                logger_fn=logger_fn,
-            )
+            logging.post_call(input=prompt, api_key=None, original_response=completion_response)
             ## RESPONSE OBJECT
             model_response["choices"][0]["message"]["content"] = completion_response
             model_response["created"] = time.time()
@@ -636,12 +554,7 @@
             prompt = " ".join([message["content"] for message in messages])
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                custom_llm_provider=custom_llm_provider,
-                logger_fn=logger_fn,
-            )
+            logging.pre_call(input=prompt, api_key=ai21.api_key)
 
             ai21_response = ai21.Completion.execute(
                 model=model,
@@ -650,16 +563,7 @@
             completion_response = ai21_response["completions"][0]["data"]["text"]
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                custom_llm_provider=custom_llm_provider,
-                additional_args={
-                    "max_tokens": max_tokens,
-                    "original_response": completion_response,
-                },
-                logger_fn=logger_fn,
-            )
+            logging.post_call(input=prompt, api_key=ai21.api_key, original_response=completion_response)
 
             ## RESPONSE OBJECT
             model_response["choices"][0]["message"]["content"] = completion_response
@@ -673,7 +577,8 @@
             prompt = " ".join([message["content"] for message in messages])
             ## LOGGING
-            logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
+            logging.pre_call(input=prompt, api_key=None, additional_args={"endpoint": endpoint})
+
             generator = get_ollama_response_stream(endpoint, model, prompt)
             # assume all responses are streamed
             return generator
@@ -688,12 +593,7 @@
             prompt = " ".join([message["content"] for message in messages])
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                custom_llm_provider=custom_llm_provider,
-                logger_fn=logger_fn,
-            )
+            logging.pre_call(input=prompt, api_key=base_ten_key)
 
             base_ten__model = baseten.deployed_model_version_id(model)
@@ -703,16 +603,8 @@
             if type(completion_response) == dict:
                 completion_response = completion_response["generated_text"]
 
-            logging(
-                model=model,
-                input=prompt,
-                custom_llm_provider=custom_llm_provider,
-                additional_args={
-                    "max_tokens": max_tokens,
-                    "original_response": completion_response,
-                },
-                logger_fn=logger_fn,
-            )
+            ## LOGGING
+            logging.post_call(input=prompt, api_key=base_ten_key, original_response=completion_response)
 
             ## RESPONSE OBJECT
             model_response["choices"][0]["message"]["content"] = completion_response
@@ -729,26 +621,14 @@
             prompt = " ".join([message["content"] for message in messages])
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                custom_llm_provider=custom_llm_provider,
-                logger_fn=logger_fn,
-            )
+            logging.pre_call(input=prompt, api_key=None, additional_args={"url": url, "max_new_tokens": 100})
             response = requests.post(
                 url, data={"inputs": prompt, "max_new_tokens": 100, "model": model}
             )
             ## LOGGING
-            logging(
-                model=model,
-                input=prompt,
-                custom_llm_provider=custom_llm_provider,
-                additional_args={
-                    "max_tokens": max_tokens,
-                    "original_response": response,
-                },
-                logger_fn=logger_fn,
-            )
+            logging.post_call(input=prompt, api_key=None, original_response=response.text, additional_args={"url": url, "max_new_tokens": 100})
             completion_response = response.json()["outputs"]
 
             # RESPONSE OBJECT
@@ -757,13 +637,6 @@
             model_response["model"] = model
             response = model_response
         else:
-            ## LOGGING
-            logging(
-                model=model,
-                input=messages,
-                custom_llm_provider=custom_llm_provider,
-                logger_fn=logger_fn,
-            )
             args = locals()
             raise ValueError(
                 f"Unable to map your input to a model. Check your input - {args}"
@@ -771,14 +644,7 @@
         return response
     except Exception as e:
         ## LOGGING
-        logging(
-            model=model,
-            input=messages,
-            custom_llm_provider=custom_llm_provider,
-            additional_args={"max_tokens": max_tokens},
-            logger_fn=logger_fn,
-            exception=e,
-        )
+        logging.post_call(input=messages, api_key=api_key, original_response=e)
         ## Map to OpenAI Exception
         raise exception_type(
             model=model, custom_llm_provider=custom_llm_provider, original_exception=e
@@ -810,9 +676,10 @@ def batch_completion(*args, **kwargs):
 @timeout(  # type: ignore
     60
 )  ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
-def embedding(model, input=[], azure=False, force_timeout=60, logger_fn=None):
+def embedding(model, input=[], azure=False, force_timeout=60, litellm_call_id=None, logger_fn=None):
     try:
         response = None
+        logging = Logging(model=model, messages=input, optional_params={}, litellm_params={"azure": azure, "force_timeout": force_timeout, "logger_fn": logger_fn, "litellm_call_id": litellm_call_id})
         if azure == True:
             # azure configs
             openai.api_type = "azure"
@@ -820,7 +687,7 @@ def embedding(model, input=[], azure=False, force_timeout=60, logger_fn=None):
             openai.api_version = get_secret("AZURE_API_VERSION")
             openai.api_key = get_secret("AZURE_API_KEY")
             ## LOGGING
-            logging(model=model, input=input, azure=azure, logger_fn=logger_fn)
+            logging.pre_call(input=input, api_key=openai.api_key, additional_args={"api_type": openai.api_type, "api_base": openai.api_base, "api_version": openai.api_version})
             ## EMBEDDING CALL
             response = openai.Embedding.create(input=input, engine=model)
             print_verbose(f"response_value: {str(response)[:50]}")
@@ -830,19 +697,16 @@ def embedding(model, input=[], azure=False, force_timeout=60, logger_fn=None):
             openai.api_version = None
             openai.api_key = get_secret("OPENAI_API_KEY")
             ## LOGGING
-            logging(model=model, input=input, azure=azure, logger_fn=logger_fn)
+            logging.pre_call(input=input, api_key=openai.api_key, additional_args={"api_type": openai.api_type, "api_base": openai.api_base, "api_version": openai.api_version})
             ## EMBEDDING CALL
             response = openai.Embedding.create(input=input, model=model)
             print_verbose(f"response_value: {str(response)[:50]}")
         else:
-            logging(model=model, input=input, azure=azure, logger_fn=logger_fn)
            args = locals()
             raise ValueError(f"No valid embedding model args passed in - {args}")
         return response
     except Exception as e:
-        # log the original exception
-        logging(model=model, input=input, azure=azure, logger_fn=logger_fn, exception=e)
         ## Map to OpenAI Exception
         raise exception_type(model=model, original_exception=e, custom_llm_provider="azure" if azure==True else None)
         raise e
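The key handling in the OpenAI and Azure branches above now resolves in three steps: an explicit `api_key` argument, then the module-level `litellm.openai_key`/`litellm.azure_key`, then an environment secret (note that the OpenAI branches in this diff guard on `AZURE_API_KEY` while reading `OPENAI_API_KEY`). A compact, illustrative restatement of the intended precedence:

```python
import os

def resolve_key(api_key=None, module_key=None, env_var="OPENAI_API_KEY"):
    # Illustrative only: explicit argument > litellm.<provider>_key > environment.
    if not api_key and module_key:
        api_key = module_key
    elif not api_key and os.environ.get(env_var):
        api_key = os.environ.get(env_var)
    return api_key

assert resolve_key("sk-explicit", "sk-module") == "sk-explicit"
assert resolve_key(None, "sk-module") == "sk-module"
```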


@@ -25,6 +25,18 @@ def logger_fn(user_model_dict):
     print(f"user_model_dict: {user_model_dict}")
 
+def test_completion_custom_provider_model_name():
+    try:
+        response = completion(
+            model="together_ai/togethercomputer/llama-2-70b-chat", messages=messages, logger_fn=logger_fn
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+test_completion_custom_provider_model_name()
+
 def test_completion_claude():
     try:
         response = completion(
@@ -77,7 +89,7 @@ def test_completion_claude_stream():
 def test_completion_cohere():
     try:
         response = completion(
-            model="command-nightly", messages=messages, max_tokens=100
+            model="command-nightly", messages=messages, max_tokens=100, logit_bias={40: 10}
         )
         # Add any assertions here to check the response
         print(response)
@@ -91,7 +103,6 @@ def test_completion_cohere():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-
 def test_completion_cohere_stream():
     try:
         messages = [


@@ -9,10 +9,11 @@
 # import litellm
 # from litellm import embedding, completion
 
+# litellm.input_callback = ["supabase"]
 # litellm.success_callback = ["supabase"]
 # litellm.failure_callback = ["supabase"]
 
-# litellm.modify_integration("supabase",{"table_name": "litellm_logs"})
+# litellm.modify_integration("supabase",{"table_name": "test_table"})
 # litellm.set_verbose = True


@@ -135,48 +135,105 @@ def install_and_import(package: str):
 ####### LOGGING ###################
 # Logging function -> log the exact model details + what's being sent | Non-Blocking
-def logging(
-    model=None,
-    input=None,
-    custom_llm_provider=None,
-    azure=False,
+class Logging:
+    def __init__(self, model, messages, optional_params, litellm_params):
+        self.model = model
+        self.messages = messages
+        self.optional_params = optional_params
+        self.litellm_params = litellm_params
+        self.logger_fn = litellm_params["logger_fn"]
+        self.model_call_details = {
+            "model": model,
+            "messages": messages,
+            "optional_params": self.optional_params,
+            "litellm_params": self.litellm_params,
+        }
+
+    def pre_call(self, input, api_key, additional_args={}):
+        try:
+            print(f"logging pre call for model: {self.model}")
+            self.model_call_details["input"] = input
+            self.model_call_details["api_key"] = api_key
+            self.model_call_details["additional_args"] = additional_args
+
+            ## User Logging -> if you pass in a custom logging function
+            print_verbose(
+                f"Logging Details: logger_fn - {self.logger_fn} | callable(logger_fn) - {callable(self.logger_fn)}"
+            )
+            if self.logger_fn and callable(self.logger_fn):
+                try:
+                    self.logger_fn(
+                        self.model_call_details
+                    )  # Expectation: any logger function passed in by the user should accept a dict object
+                except Exception as e:
+                    print_verbose(
+                        f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
+                    )
+
+            ## Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
+            for callback in litellm.input_callback:
+                try:
+                    if callback == "supabase":
+                        print_verbose("reaches supabase for logging!")
+                        model = self.model
+                        messages = self.messages
+                        print(f"litellm._thread_context: {litellm._thread_context}")
+                        supabaseClient.input_log_event(
+                            model=model,
+                            messages=messages,
+                            end_user=litellm._thread_context.user,
+                            litellm_call_id=self.litellm_params["litellm_call_id"],
+                            print_verbose=print_verbose,
+                        )
+                    pass
+                except:
+                    pass
+        except:
+            print_verbose(
+                f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
+            )
+            pass
+
+    def post_call(self, input, api_key, original_response, additional_args={}):
+        # Do something here
+        try:
+            self.model_call_details["input"] = input
+            self.model_call_details["api_key"] = api_key
+            self.model_call_details["original_response"] = original_response
+            self.model_call_details["additional_args"] = additional_args
+
+            ## User Logging -> if you pass in a custom logging function
+            print_verbose(
+                f"Logging Details: logger_fn - {self.logger_fn} | callable(logger_fn) - {callable(self.logger_fn)}"
+            )
+            if self.logger_fn and callable(self.logger_fn):
+                try:
+                    self.logger_fn(
+                        self.model_call_details
+                    )  # Expectation: any logger function passed in by the user should accept a dict object
+                except Exception as e:
+                    print_verbose(
+                        f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
+                    )
+        except:
+            print_verbose(
+                f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
+            )
+            pass
+
+    # Add more methods as needed
+
+def exception_logging(
     additional_args={},
     logger_fn=None,
     exception=None,
 ):
     try:
         model_call_details = {}
-        if model:
-            model_call_details["model"] = model
-        if azure:
-            model_call_details["azure"] = azure
-        if custom_llm_provider:
-            model_call_details["custom_llm_provider"] = custom_llm_provider
         if exception:
             model_call_details["exception"] = exception
-        if input:
-            model_call_details["input"] = input
-        if len(additional_args):
-            model_call_details["additional_args"] = additional_args
-        # log additional call details -> api key, etc.
-        if model:
-            if (
-                azure == True
-                or model in litellm.open_ai_chat_completion_models
-                or model in litellm.open_ai_chat_completion_models
-                or model in litellm.open_ai_embedding_models
-            ):
-                model_call_details["api_type"] = openai.api_type
-                model_call_details["api_base"] = openai.api_base
-                model_call_details["api_version"] = openai.api_version
-                model_call_details["api_key"] = openai.api_key
-            elif "replicate" in model:
-                model_call_details["api_key"] = os.environ.get("REPLICATE_API_TOKEN")
-            elif model in litellm.anthropic_models:
-                model_call_details["api_key"] = os.environ.get("ANTHROPIC_API_KEY")
-            elif model in litellm.cohere_models:
-                model_call_details["api_key"] = os.environ.get("COHERE_API_KEY")
+        model_call_details["additional_args"] = additional_args
         ## User Logging -> if you pass in a custom logging function or want to use sentry breadcrumbs
         print_verbose(
             f"Logging Details: logger_fn - {logger_fn} | callable(logger_fn) - {callable(logger_fn)}"
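`main.py` (earlier in this diff) constructs one `Logging` instance per call and hands it to the provider hooks. A usage sketch against the class above — all values illustrative; note that `__init__` requires a `logger_fn` key, and the Supabase input path requires `litellm_call_id`, inside `litellm_params`:

```python
# assumes the Logging class from the hunk above is in scope
logging_obj = Logging(
    model="claude-instant-1",
    messages=[{"role": "user", "content": "Hi"}],
    optional_params={"max_tokens": 10},
    litellm_params={"logger_fn": print, "litellm_call_id": "1234-abcd"},
)
logging_obj.pre_call(input="Hi", api_key="sk-...", additional_args={"complete_input_dict": {}})
logging_obj.post_call(input="Hi", api_key="sk-...", original_response='{"completion": "Hello!"}')
```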
@@ -206,10 +263,10 @@ def client(original_function):
     try:
         global callback_list, add_breadcrumb, user_logger_fn
         if (
-            len(litellm.success_callback) > 0 or len(litellm.failure_callback) > 0
+            len(litellm.input_callback) > 0 or len(litellm.success_callback) > 0 or len(litellm.failure_callback) > 0
         ) and len(callback_list) == 0:
             callback_list = list(
-                set(litellm.success_callback + litellm.failure_callback)
+                set(litellm.input_callback + litellm.success_callback + litellm.failure_callback)
             )
             set_callbacks(
                 callback_list=callback_list,
@@ -299,13 +356,16 @@
         result = None
         try:
             function_setup(*args, **kwargs)
-            ## MODEL CALL
+            litellm_call_id = str(uuid.uuid4())
+            kwargs["litellm_call_id"] = litellm_call_id
+            ## [OPTIONAL] CHECK CACHE
             start_time = datetime.datetime.now()
             if (litellm.caching or litellm.caching_with_models) and (
                 cached_result := check_cache(*args, **kwargs)
             ) is not None:
                 result = cached_result
             else:
+                ## MODEL CALL
                 result = original_function(*args, **kwargs)
             end_time = datetime.datetime.now()
             ## Add response to CACHE
@@ -399,6 +459,7 @@ def get_litellm_params(
     together_ai=False,
     custom_llm_provider=None,
     custom_api_base=None,
+    litellm_call_id=None,
 ):
     litellm_params = {
         "return_async": return_async,
@@ -408,6 +469,7 @@
         "verbose": verbose,
         "custom_llm_provider": custom_llm_provider,
         "custom_api_base": custom_api_base,
+        "litellm_call_id": litellm_call_id
     }
     return litellm_params
@@ -452,6 +514,8 @@ def get_optional_params(
             optional_params["temperature"] = temperature
         if max_tokens != float("inf"):
             optional_params["max_tokens"] = max_tokens
+        if logit_bias != {}:
+            optional_params["logit_bias"] = logit_bias
         return optional_params
     elif custom_llm_provider == "replicate":
         # any replicate models
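With the two added lines above, a caller-supplied `logit_bias` now flows through `get_optional_params` into the provider call; the cohere test earlier in this diff exercises exactly this path:

```python
from litellm import completion

response = completion(
    model="command-nightly",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    max_tokens=100,
    logit_bias={40: 10},  # token id -> bias, forwarded to the provider untouched
)
```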
@@ -565,7 +629,8 @@ def set_callbacks(callback_list):
     global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient
     try:
         for callback in callback_list:
-            if callback == "sentry" or "SENTRY_API_URL" in os.environ:
+            print(f"callback: {callback}")
+            if callback == "sentry":
                 try:
                     import sentry_sdk
                 except ImportError:
@@ -621,6 +686,7 @@
             elif callback == "berrispend":
                 berrispendLogger = BerriSpendLogger()
             elif callback == "supabase":
+                print(f"instantiating supabase")
                 supabaseClient = Supabase()
     except Exception as e:
         raise e
@@ -741,7 +807,6 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs):
                         "completion_tokens": 0,
                     },
                 }
-                print(f"litellm._thread_context: {litellm._thread_context}")
                 supabaseClient.log_event(
                     model=model,
                     messages=messages,
@@ -749,9 +814,9 @@
                     response_obj=result,
                     start_time=start_time,
                    end_time=end_time,
+                    litellm_call_id=kwargs["litellm_call_id"],
                     print_verbose=print_verbose,
                 )
-
         except:
             print_verbose(
                 f"Error Occurred while logging failure: {traceback.format_exc()}"
@@ -767,7 +832,7 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, kwargs):
         pass
     except Exception as e:
         ## LOGGING
-        logging(logger_fn=user_logger_fn, exception=e)
+        exception_logging(logger_fn=user_logger_fn, exception=e)
         pass
@@ -847,11 +912,12 @@ def handle_success(args, kwargs, result, start_time, end_time):
                 response_obj=result,
                 start_time=start_time,
                 end_time=end_time,
+                litellm_call_id=kwargs["litellm_call_id"],
                 print_verbose=print_verbose,
             )
         except Exception as e:
             ## LOGGING
-            logging(logger_fn=user_logger_fn, exception=e)
+            exception_logging(logger_fn=user_logger_fn, exception=e)
             print_verbose(
                 f"[Non-Blocking] Success Callback Error - {traceback.format_exc()}"
             )
@@ -862,7 +928,7 @@ def handle_success(args, kwargs, result, start_time, end_time):
         pass
     except Exception as e:
         ## LOGGING
-        logging(logger_fn=user_logger_fn, exception=e)
+        exception_logging(logger_fn=user_logger_fn, exception=e)
         print_verbose(
             f"[Non-Blocking] Success Callback Error - {traceback.format_exc()}"
         )
@@ -910,15 +976,6 @@ def exception_type(model, original_exception, custom_llm_provider):
             exception_type = type(original_exception).__name__
         else:
             exception_type = ""
-        logging(
-            model=model,
-            additional_args={
-                "error_str": error_str,
-                "exception_type": exception_type,
-                "original_exception": original_exception,
-            },
-            logger_fn=user_logger_fn,
-        )
         if "claude" in model:  # one of the anthropics
             if hasattr(original_exception, "status_code"):
                 print_verbose(f"status_code: {original_exception.status_code}")
@@ -1028,7 +1085,7 @@
             raise original_exception
     except Exception as e:
         ## LOGGING
-        logging(
+        exception_logging(
             logger_fn=user_logger_fn,
             additional_args={
                 "exception_mapping_worked": exception_mapping_worked,

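Taken together: the `client` wrapper mints one `litellm_call_id` per call, stores it in `kwargs`, and the success/failure handlers forward it to `supabaseClient.log_event`, which is what lets the "initiated" row and the final row land on the same record. A condensed, self-contained sketch of that lifecycle (simplified from the diff; the handler bodies are stand-ins):

```python
import datetime
import functools
import uuid

def handle_success(kwargs, result, start, end):
    # in litellm this upserts the Supabase row to status="success"
    print(f"success {kwargs['litellm_call_id']} in {(end - start).total_seconds():.2f}s")

def handle_failure(kwargs, exc, start, end):
    # in litellm this upserts the same row to status="failure"
    print(f"failure {kwargs['litellm_call_id']}: {exc}")

def client(original_function):
    @functools.wraps(original_function)
    def wrapper(*args, **kwargs):
        kwargs["litellm_call_id"] = str(uuid.uuid4())  # one id for the whole call
        start = datetime.datetime.now()
        try:
            result = original_function(*args, **kwargs)
            handle_success(kwargs, result, start, datetime.datetime.now())
            return result
        except Exception as e:
            handle_failure(kwargs, e, start, datetime.datetime.now())
            raise
    return wrapper

@client
def fake_completion(model, messages, litellm_call_id=None):
    return {"choices": [{"message": {"content": "hi"}}]}

fake_completion(model="gpt-3.5-turbo", messages=[])
```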

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.426"
+version = "0.1.431"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"