forked from phoenix/litellm-mirror
replacing individual provider flags with 'custom_llm_provider'
parent bc767cc42a
commit 72c1b5dcfc
3 changed files with 16 additions and 20 deletions
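The change is API-wide: the boolean `azure` flag (one flag per provider) is replaced by a single `custom_llm_provider` string that is threaded through `completion()`, the logging helper, and crash reporting. As a caller-facing illustration, here is a minimal sketch based on the test code in this diff; the import path and the "chatgpt-test" model alias are taken from the tests, anything beyond that would be an assumption:

    from litellm import completion

    messages = [{"content": "Hello, how are you?", "role": "user"}]

    # Before this commit: routing to Azure was signalled with a boolean flag.
    #   response = completion(model="chatgpt-test", messages=messages, azure=True)

    # After this commit: one string argument names the provider instead.
    response = completion(model="chatgpt-test", messages=messages, custom_llm_provider="azure")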
@@ -285,7 +285,7 @@ def completion(
 
 completion_response = response[0].text
 ## LOGGING
-logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
 prompt_tokens = len(encoding.encode(prompt))
 completion_tokens = len(encoding.encode(completion_response))
 ## RESPONSE OBJECT
@@ -306,11 +306,11 @@ def completion(
 
 prompt = " ".join([message["content"] for message in messages])
 ## LOGGING
-logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
+logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
 input_payload = {"inputs": prompt}
 response = requests.post(API_URL, headers=headers, json=input_payload)
 ## LOGGING
-logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": response.text}, logger_fn=logger_fn)
+logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": response.text}, logger_fn=logger_fn)
 completion_response = response.json()[0]['generated_text']
 prompt_tokens = len(encoding.encode(prompt))
 completion_tokens = len(encoding.encode(completion_response))
@@ -332,7 +332,7 @@ def completion(
 prompt = " ".join([message["content"] for message in messages]) # TODO: Add chat support for together AI
 
 ## LOGGING
-logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
+logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
 res = requests.post(endpoint, json={
 "model": model,
 "prompt": prompt,
@@ -342,7 +342,7 @@ def completion(
 headers=headers
 )
 ## LOGGING
-logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": res.text}, logger_fn=logger_fn)
+logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": res.text}, logger_fn=logger_fn)
 if stream == True:
 response = CustomStreamWrapper(res, "together_ai")
 return response
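The four hunks above are the `completion()` provider branches (the last two are the Together AI branch); the only behavioural change is that each `logging(...)` call now receives `custom_llm_provider` instead of `azure`. The surrounding context also shows how usage is counted with `encoding.encode`. A minimal self-contained sketch of that token accounting, assuming `encoding` is a tiktoken encoding (its construction is not part of this diff):

    # Sketch only: `encoding` in the diff is assumed to be a tiktoken encoding.
    import tiktoken

    encoding = tiktoken.get_encoding("cl100k_base")

    prompt = "Hello, how are you?"
    completion_response = "I'm doing well, thank you."

    prompt_tokens = len(encoding.encode(prompt))
    completion_tokens = len(encoding.encode(completion_response))
    print(prompt_tokens, completion_tokens)  # token counts used for the usage fields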
@@ -37,7 +37,7 @@ def test_context_window(model):
 try:
 azure = model == "chatgpt-test"
 print(f"model: {model}")
-response = completion(model=model, messages=messages, azure=azure, logger_fn=logging_fn)
+response = completion(model=model, messages=messages, custom_llm_provider=custom_llm_provider, logger_fn=logging_fn)
 print(f"response: {response}")
 except InvalidRequestError:
 print("InvalidRequestError")
@@ -59,14 +59,14 @@ def invalid_auth(model): # set the model key to an invalid key, depending on the
 messages = [{ "content": "Hello, how are you?","role": "user"}]
 temporary_key = None
 try:
-azure = False
+custom_llm_provider = None
 if model == "gpt-3.5-turbo":
 temporary_key = os.environ["OPENAI_API_KEY"]
 os.environ["OPENAI_API_KEY"] = "bad-key"
 elif model == "chatgpt-test":
 temporary_key = os.environ["AZURE_API_KEY"]
 os.environ["AZURE_API_KEY"] = "bad-key"
-azure = True
+custom_llm_provider = "azure"
 elif model == "claude-instant-1":
 temporary_key = os.environ["ANTHROPIC_API_KEY"]
 os.environ["ANTHROPIC_API_KEY"] = "bad-key"
@@ -77,7 +77,7 @@ def invalid_auth(model): # set the model key to an invalid key, depending on the
 temporary_key = os.environ["REPLICATE_API_KEY"]
 os.environ["REPLICATE_API_KEY"] = "bad-key"
 print(f"model: {model}")
-response = completion(model=model, messages=messages, azure=azure)
+response = completion(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
 print(f"response: {response}")
 except AuthenticationError as e:
 print(f"AuthenticationError Caught Exception - {e}")
@@ -107,11 +107,11 @@ invalid_auth("command-nightly")
 # try:
 # sample_text = "how does a court case get to the Supreme Court?" * 50000
 # messages = [{ "content": sample_text,"role": "user"}]
-# azure = False
+# custom_llm_provider = None
 # if model == "chatgpt-test":
-# azure = True
+# custom_llm_provider = "azure"
 # print(f"model: {model}")
-# response = completion(model=model, messages=messages, azure=azure)
+# response = completion(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
 # except RateLimitError:
 # return True
 # except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
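The hunks above are from the exception tests: `invalid_auth` temporarily swaps the provider's API key for a bad value, sets `custom_llm_provider` where needed, calls `completion`, and expects an `AuthenticationError`. A condensed sketch of that pattern; the import paths and the key-restore step are assumptions (the restore is implied by `temporary_key` but sits outside these hunks):

    import os
    from litellm import completion
    # Assumption: the test module uses the pre-1.0 openai exception classes.
    from openai.error import AuthenticationError

    def invalid_auth_sketch(model="gpt-3.5-turbo"):
        messages = [{"content": "Hello, how are you?", "role": "user"}]
        temporary_key = None
        custom_llm_provider = None
        key_name = None
        try:
            if model == "gpt-3.5-turbo":
                key_name = "OPENAI_API_KEY"
            elif model == "chatgpt-test":
                key_name = "AZURE_API_KEY"
                custom_llm_provider = "azure"
            if key_name:
                temporary_key = os.environ[key_name]
                os.environ[key_name] = "bad-key"
            completion(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
        except AuthenticationError as e:
            print(f"AuthenticationError Caught Exception - {e}")
        finally:
            # Restore the real key so later tests keep working (assumed behaviour).
            if temporary_key is not None and key_name:
                os.environ[key_name] = temporary_key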
@@ -75,10 +75,6 @@ def logging(model=None, input=None, custom_llm_provider=None, azure=False, addit
 model_call_details["custom_llm_provider"] = custom_llm_provider
 if exception:
 model_call_details["exception"] = exception
-
-# if litellm.telemetry:
-# safe_crash_reporting(model=model, exception=exception, azure=azure) # log usage-crash details. Do not log any user details. If you want to turn this off, set `litellm.telemetry=False`.
-
 if input:
 model_call_details["input"] = input
 
@@ -134,8 +130,8 @@ def client(original_function):
 try:
 model = args[0] if len(args) > 0 else kwargs["model"]
 exception = kwargs["exception"] if "exception" in kwargs else None
-azure = kwargs["azure"] if "azure" in kwargs else None
-safe_crash_reporting(model=model, exception=exception, azure=azure) # log usage-crash details. Do not log any user details. If you want to turn this off, set `litellm.telemetry=False`.
+custom_llm_provider = kwargs["custom_llm_provider"] if "custom_llm_provider" in kwargs else None
+safe_crash_reporting(model=model, exception=exception, custom_llm_provider=custom_llm_provider) # log usage-crash details. Do not log any user details. If you want to turn this off, set `litellm.telemetry=False`.
 except:
 #[Non-Blocking Error]
 pass
@@ -647,11 +643,11 @@ def exception_type(model, original_exception):
 else: # don't let an error with mapping interrupt the user from receiving an error from the llm api calls
 raise original_exception
 
-def safe_crash_reporting(model=None, exception=None, azure=None):
+def safe_crash_reporting(model=None, exception=None, custom_llm_provider=None):
 data = {
 "model": model,
 "exception": str(exception),
-"azure": azure
+"custom_llm_provider": custom_llm_provider
 }
 threading.Thread(target=litellm_telemetry, args=(data,)).start()
 
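The utils hunks carry the rename through the crash-reporting path: the `client` decorator pulls `custom_llm_provider` out of kwargs and hands it to `safe_crash_reporting`, which ships a small payload on a background thread. A minimal sketch of that flow, with `litellm_telemetry` stubbed as a hypothetical no-op because its body is not part of this diff:

    import threading

    def litellm_telemetry(data):
        # Hypothetical stand-in; the real function (not in this diff) sends the payload.
        pass

    def safe_crash_reporting(model=None, exception=None, custom_llm_provider=None):
        data = {
            "model": model,
            "exception": str(exception),
            "custom_llm_provider": custom_llm_provider,
        }
        # Fire-and-forget so telemetry never blocks the calling request.
        threading.Thread(target=litellm_telemetry, args=(data,)).start()

    # Example: report a failed Azure call without blocking.
    safe_crash_reporting(model="chatgpt-test", exception=RuntimeError("bad key"), custom_llm_provider="azure")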