allowing user to set model keys dynamically

This commit is contained in:
Krrish Dholakia 2023-08-03 15:05:29 -07:00
parent 123de53475
commit 499d626c76
3 changed files with 41 additions and 7 deletions

View file

@ -4,6 +4,11 @@ set_verbose=False
telemetry=True telemetry=True
max_tokens = 256 # OpenAI Defaults max_tokens = 256 # OpenAI Defaults
retry = True # control tenacity retries. retry = True # control tenacity retries.
openai_key = None
azure_key = None
anthropic_key = None
replicate_key = None
cohere_key = None
####### PROXY PARAMS ################### configurable params if you use proxy models like Helicone ####### PROXY PARAMS ################### configurable params if you use proxy models like Helicone
api_base = None api_base = None
headers = None headers = None
@ -35,7 +40,11 @@ anthropic_models = [
"claude-instant-1" "claude-instant-1"
] ]
model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models replicate_models = [
"replicate/"
] # placeholder, to make sure we accept any replicate model in our model_list
model_list = open_ai_chat_completion_models + open_ai_text_completion_models + cohere_models + anthropic_models + replicate_models
####### EMBEDDING MODELS ################### ####### EMBEDDING MODELS ###################
open_ai_embedding_models = [ open_ai_embedding_models = [

View file

@ -6,7 +6,7 @@ from functools import partial
import dotenv import dotenv
import traceback import traceback
import litellm import litellm
from litellm import client, logging, exception_type, timeout, success_callback, failure_callback from litellm import client, logging, exception_type, timeout
import random import random
import asyncio import asyncio
from tenacity import ( from tenacity import (
@ -97,7 +97,12 @@ def completion(
openai.api_type = "azure" openai.api_type = "azure"
openai.api_base = litellm.api_base if litellm.api_base is not None else os.environ.get("AZURE_API_BASE") openai.api_base = litellm.api_base if litellm.api_base is not None else os.environ.get("AZURE_API_BASE")
openai.api_version = os.environ.get("AZURE_API_VERSION") openai.api_version = os.environ.get("AZURE_API_VERSION")
openai.api_key = api_key if api_key is not None else os.environ.get("AZURE_API_KEY") if api_key:
openai.api_key = api_key
elif litellm.azure_key:
openai.api_key = litellm.azure_key
else:
openai.api_key = os.environ.get("AZURE_API_KEY")
## LOGGING ## LOGGING
logging(model=model, input=messages, azure=azure, logger_fn=logger_fn) logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
## COMPLETION CALL ## COMPLETION CALL
@ -118,7 +123,12 @@ def completion(
openai.api_type = "openai" openai.api_type = "openai"
openai.api_base = litellm.api_base if litellm.api_base is not None else "https://api.openai.com/v1" openai.api_base = litellm.api_base if litellm.api_base is not None else "https://api.openai.com/v1"
openai.api_version = None openai.api_version = None
openai.api_key = api_key if api_key is not None else os.environ.get("OPENAI_API_KEY") if api_key:
openai.api_key = api_key
elif litellm.openai_key:
openai.api_key = litellm.openai_key
else:
openai.api_key = os.environ.get("OPENAI_API_KEY")
## LOGGING ## LOGGING
logging(model=model, input=messages, azure=azure, logger_fn=logger_fn) logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
## COMPLETION CALL ## COMPLETION CALL
@ -139,7 +149,12 @@ def completion(
openai.api_type = "openai" openai.api_type = "openai"
openai.api_base = litellm.api_base if litellm.api_base is not None else "https://api.openai.com/v1" openai.api_base = litellm.api_base if litellm.api_base is not None else "https://api.openai.com/v1"
openai.api_version = None openai.api_version = None
openai.api_key = api_key if api_key is not None else os.environ.get("OPENAI_API_KEY") if api_key:
openai.api_key = api_key
elif litellm.openai_key:
openai.api_key = litellm.openai_key
else:
openai.api_key = os.environ.get("OPENAI_API_KEY")
prompt = " ".join([message["content"] for message in messages]) prompt = " ".join([message["content"] for message in messages])
## LOGGING ## LOGGING
logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn) logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
@ -163,6 +178,9 @@ def completion(
os.environ["REPLICATE_API_TOKEN"] = replicate_api_token os.environ["REPLICATE_API_TOKEN"] = replicate_api_token
elif api_key: elif api_key:
os.environ["REPLICATE_API_TOKEN"] = api_key os.environ["REPLICATE_API_TOKEN"] = api_key
elif litellm.replicate_key:
os.environ["REPLICATE_API_TOKEN"] = litellm.replicate_key
prompt = " ".join([message["content"] for message in messages]) prompt = " ".join([message["content"] for message in messages])
input = {"prompt": prompt} input = {"prompt": prompt}
if max_tokens != float('inf'): if max_tokens != float('inf'):
@ -194,6 +212,8 @@ def completion(
#anthropic defaults to os.environ.get("ANTHROPIC_API_KEY") #anthropic defaults to os.environ.get("ANTHROPIC_API_KEY")
if api_key: if api_key:
os.environ["ANTHROPIC_API_KEY"] = api_key os.environ["ANTHROPIC_API_KEY"] = api_key
elif litellm.anthropic_key:
os.environ["ANTHROPIC_API_TOKEN"] = litellm.anthropic_key
prompt = f"{HUMAN_PROMPT}" prompt = f"{HUMAN_PROMPT}"
for message in messages: for message in messages:
if "role" in message: if "role" in message:
@ -233,7 +253,12 @@ def completion(
print_verbose(f"new response: {new_response}") print_verbose(f"new response: {new_response}")
response = new_response response = new_response
elif model in litellm.cohere_models: elif model in litellm.cohere_models:
cohere_key = api_key if api_key is not None else os.environ.get("COHERE_API_KEY") if api_key:
cohere_key = api_key
elif litellm.api_key:
cohere_key = litellm.api_key
else:
cohere_key = os.environ.get("COHERE_API_KEY")
co = cohere.Client(cohere_key) co = cohere.Client(cohere_key)
prompt = " ".join([message["content"] for message in messages]) prompt = " ".join([message["content"] for message in messages])
## LOGGING ## LOGGING

View file

@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup( setup(
name='litellm', name='litellm',
version='0.1.2291', version='0.1.230',
description='Library to easily interface with LLM API providers', description='Library to easily interface with LLM API providers',
author='BerriAI', author='BerriAI',
packages=[ packages=[