add support for litellm proxy calls

Krrish Dholakia 2023-09-18 12:15:19 -07:00
parent 0f88b82c4f
commit 9067ec3b43
7 changed files with 110 additions and 16 deletions

litellm/main.py

@@ -18,6 +18,7 @@ from litellm.utils import (
     read_config_args,
     completion_with_fallbacks,
     get_llm_provider,
+    get_api_key,
     mock_completion_streaming_obj
 )
 from .llms import anthropic
@@ -233,6 +234,10 @@ def completion(
             custom_llm_provider = model.split("/", 1)[0]
             model = model.split("/", 1)[1]
         model, custom_llm_provider = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider)
+        model_api_key = get_api_key(llm_provider=custom_llm_provider, dynamic_api_key=api_key) # get the api key from the environment if required for the model
+        if model_api_key and "sk-litellm" in model_api_key:
+            api_base = "https://proxy.litellm.ai"
+            custom_llm_provider = "openai"
         # check if user passed in any of the OpenAI optional params
         optional_params = get_optional_params(
             functions=functions,

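For orientation, a minimal usage sketch of the routing added above (not part of the commit; the key value is a placeholder): any api_key containing "sk-litellm" now redirects the request through the hosted proxy via the OpenAI-compatible path.

from litellm import completion

# Any api_key containing "sk-litellm" trips the new branch above:
# api_base is pointed at https://proxy.litellm.ai and the call is sent
# through the OpenAI-compatible code path, regardless of the model.
response = completion(
    model="claude-2",  # provider is inferred by get_llm_provider()
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    api_key="sk-litellm-1234",  # placeholder proxy key
)
print(response)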
litellm/tests/test_completion.py

@@ -133,22 +133,22 @@ def test_completion_with_litellm_call_id():
 #     pytest.fail(f"Error occurred: {e}")

 # using Non TGI or conversational LLMs
-def hf_test_completion():
-    try:
-        # litellm.set_verbose=True
-        user_message = "My name is Merve and my favorite"
-        messages = [{ "content": user_message,"role": "user"}]
-        response = completion(
-            model="huggingface/roneneldan/TinyStories-3M",
-            messages=messages,
-            api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
-            task=None,
-        )
-        # Add any assertions here to check the response
-        print(response)
+# def hf_test_completion():
+#     try:
+#         # litellm.set_verbose=True
+#         user_message = "My name is Merve and my favorite"
+#         messages = [{ "content": user_message,"role": "user"}]
+#         response = completion(
+#             model="huggingface/roneneldan/TinyStories-3M",
+#             messages=messages,
+#             api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
+#             task=None,
+#         )
+#         # Add any assertions here to check the response
+#         print(response)
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
 # hf_test_completion()
@@ -427,6 +427,20 @@ def test_completion_azure_deployment_id():
         pytest.fail(f"Error occurred: {e}")
 # test_completion_azure_deployment_id()

+# def test_completion_anthropic_litellm_proxy():
+#     try:
+#         response = completion(
+#             model="claude-2",
+#             messages=messages,
+#             api_key="sk-litellm-1234"
+#         )
+#         # Add any assertions here to check the response
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+# test_completion_anthropic_litellm_proxy()
+
 # def test_hf_conversational_task():
 #     try:
 #         messages = [{ "content": "There's a llama in my garden 😱 What should I do?","role": "user"}]

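If one wanted to enable the commented-out proxy test without a hard-coded key, a guarded variant is sketched below; LITELLM_PROXY_API_KEY is a hypothetical environment variable used only for this sketch.

import os

import pytest

from litellm import completion

@pytest.mark.skipif(
    "sk-litellm" not in os.environ.get("LITELLM_PROXY_API_KEY", ""),  # hypothetical env var
    reason="no litellm proxy key configured",
)
def test_completion_anthropic_litellm_proxy():
    messages = [{"content": "Hey, how's it going?", "role": "user"}]
    response = completion(
        model="claude-2",
        messages=messages,
        api_key=os.environ["LITELLM_PROXY_API_KEY"],
    )
    print(response)  # smoke check; the call should not raise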
litellm/utils.py

@@ -1129,6 +1129,81 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
     except Exception as e:
         raise e

+def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
+    api_key = (dynamic_api_key or litellm.api_key)
+    # openai
+    if llm_provider == "openai" or llm_provider == "text-completion-openai":
+        api_key = (
+            api_key or
+            litellm.openai_key or
+            get_secret("OPENAI_API_KEY")
+        )
+    # anthropic
+    elif llm_provider == "anthropic":
+        api_key = (
+            api_key or
+            litellm.anthropic_key or
+            get_secret("ANTHROPIC_API_KEY")
+        )
+    # ai21
+    elif llm_provider == "ai21":
+        api_key = (
+            api_key or
+            litellm.ai21_key or
+            get_secret("AI21_API_KEY")
+        )
+    # aleph_alpha
+    elif llm_provider == "aleph_alpha":
+        api_key = (
+            api_key or
+            litellm.aleph_alpha_key or
+            get_secret("ALEPH_ALPHA_API_KEY")
+        )
+    # baseten
+    elif llm_provider == "baseten":
+        api_key = (
+            api_key or
+            litellm.baseten_key or
+            get_secret("BASETEN_API_KEY")
+        )
+    # cohere
+    elif llm_provider == "cohere":
+        api_key = (
+            api_key or
+            litellm.cohere_key or
+            get_secret("COHERE_API_KEY")
+        )
+    # huggingface
+    elif llm_provider == "huggingface":
+        api_key = (
+            api_key or
+            litellm.huggingface_key or
+            get_secret("HUGGINGFACE_API_KEY")
+        )
+    # nlp_cloud
+    elif llm_provider == "nlp_cloud":
+        api_key = (
+            api_key or
+            litellm.nlp_cloud_key or
+            get_secret("NLP_CLOUD_API_KEY")
+        )
+    # replicate
+    elif llm_provider == "replicate":
+        api_key = (
+            api_key or
+            litellm.replicate_key or
+            get_secret("REPLICATE_API_KEY")
+        )
+    # together_ai
+    elif llm_provider == "together_ai":
+        api_key = (
+            api_key or
+            litellm.togetherai_api_key or
+            get_secret("TOGETHERAI_API_KEY") or
+            get_secret("TOGETHER_AI_TOKEN")
+        )
+    return api_key
+
 def get_max_tokens(model: str):
     try:
         return litellm.model_cost[model]

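A short usage sketch of the new helper (the env value below is a placeholder). Resolution precedence, as implemented above: the explicit dynamic key, then the litellm module-level key, then the environment secret.

import os

os.environ["OPENAI_API_KEY"] = "sk-placeholder"  # placeholder value for the sketch

from litellm.utils import get_api_key

# An explicit dynamic key wins over module attributes and the environment.
assert get_api_key(llm_provider="openai", dynamic_api_key="sk-dynamic") == "sk-dynamic"

# With no dynamic key or litellm.openai_key set, the env var is returned.
print(get_api_key(llm_provider="openai", dynamic_api_key=None))  # -> "sk-placeholder"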
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.692"
+version = "0.1.693"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"