From dc0c0848133dbe7f616f05354cc23e28d476b4bd Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Wed, 6 Sep 2023 12:29:32 -0700
Subject: [PATCH] improve replicate usage

---
 docs/my-website/docs/providers/replicate.md | 16 ++++++++++
 litellm/main.py                             | 34 +++++++++------------
 2 files changed, 30 insertions(+), 20 deletions(-)

diff --git a/docs/my-website/docs/providers/replicate.md b/docs/my-website/docs/providers/replicate.md
index f55408d35..c0671fc1e 100644
--- a/docs/my-website/docs/providers/replicate.md
+++ b/docs/my-website/docs/providers/replicate.md
@@ -9,6 +9,22 @@
 os.environ["REPLICATE_API_KEY"] = ""
 ```
 
+### Example Call
+
+```python
+import os
+from litellm import completion
+
+## set ENV variables
+os.environ["REPLICATE_API_KEY"] = "replicate key"
+
+messages = [{"content": "Hello, how are you?", "role": "user"}]
+
+# replicate llama-2 call
+response = completion("replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf", messages)
+```
+
+
 ### Replicate Models
 liteLLM supports all replicate LLMs. For replicate models ensure to add a `replicate` prefix to the `model` arg. liteLLM detects it using this arg.
 Below are examples on how to call replicate LLMs using liteLLM
diff --git a/litellm/main.py b/litellm/main.py
index d27daf403..539ccdd12 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -116,7 +116,7 @@ def completion(
             model
         ]  # update the model to the actual value if an alias has been passed in
         model_response = ModelResponse()
-        if deployment_id != None:
+        if deployment_id != None: # azure llms
             model=deployment_id
             custom_llm_provider="azure"
         elif (
@@ -124,10 +124,7 @@
         ): # allow custom provider to be passed in via the model name "azure/chatgpt-test"
             custom_llm_provider = model.split("/", 1)[0]
             model = model.split("/", 1)[1]
-            if (
-                "replicate" == custom_llm_provider and "/" not in model
-            ): # handle the "replicate/llama2..." edge-case
-                model = custom_llm_provider + "/" + model
+
         # check if user passed in any of the OpenAI optional params
         optional_params = get_optional_params(
             functions=functions,
@@ -340,22 +337,19 @@
             model_response["model"] = model
             model_response["usage"] = response["usage"]
             response = model_response
-        elif "replicate" in model or custom_llm_provider == "replicate":
-            # import replicate/if it fails then pip install replicate
-
-
+        elif (
+            "replicate" in model or
+            custom_llm_provider == "replicate" or
+            model in litellm.replicate_models
+        ):
             # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN")
-            replicate_key = os.environ.get("REPLICATE_API_TOKEN")
-            if replicate_key == None:
-                # user did not set REPLICATE_API_TOKEN in .env
-                replicate_key = (
-                    get_secret("REPLICATE_API_KEY")
-                    or get_secret("REPLICATE_API_TOKEN")
-                    or api_key
-                    or litellm.replicate_key
-                )
-            # set replicate key
-            os.environ["REPLICATE_API_TOKEN"] = str(replicate_key)
+            replicate_key = None
+            replicate_key = (
+                get_secret("REPLICATE_API_KEY")
+                or get_secret("REPLICATE_API_TOKEN")
+                or api_key
+                or litellm.replicate_key
+            )
             model_response = replicate.completion(
                 model=model,
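
A note on the new key-resolution branch in `litellm/main.py`: the patch replaces the env-var-only lookup (with a conditional fallback) by a single `or` chain, so the first non-empty source wins. Below is a minimal, self-contained sketch of that resolution order, assuming `get_secret` reduces to a plain environment lookup (a stand-in for litellm's real helper); `api_key` and `module_replicate_key` are hypothetical stand-ins for the `api_key` argument and `litellm.replicate_key`.

```python
import os

def get_secret(name):
    # stand-in for litellm's get_secret: plain env-var lookup here
    return os.environ.get(name)

api_key = None               # hypothetical: per-call api_key argument
module_replicate_key = None  # hypothetical: litellm.replicate_key

os.environ["REPLICATE_API_TOKEN"] = "r8_example-token"

# first truthy source wins, mirroring the order in the patched branch
replicate_key = (
    get_secret("REPLICATE_API_KEY")
    or get_secret("REPLICATE_API_TOKEN")
    or api_key
    or module_replicate_key
)
print(replicate_key)  # -> r8_example-token
```

One behavioral difference visible in the diff itself: the removed code wrote the resolved key back to `os.environ["REPLICATE_API_TOKEN"]`, while the new code only binds it to `replicate_key`, so the key presumably has to be handed to the replicate handler explicitly.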
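The patch also deletes the `replicate/llama2...` edge case that re-attached the provider prefix after splitting. With that gone, `model.split("/", 1)` leaves only what follows the first `/`, including any `:version` hash. A small sketch of the split behavior (plain string handling, no litellm imports):

```python
# provider detection splits on the first "/" only, so the model name
# and its ":version" hash survive intact in the second piece
raw = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf"

custom_llm_provider, model = raw.split("/", 1)
assert custom_llm_provider == "replicate"
assert model.startswith("llama-2-70b-chat:2796ee94")
print(custom_llm_provider, model)
```

Downstream code can therefore no longer rely on the `replicate/` prefix being re-added to `model`, which fits the broadened `elif` that now also matches `model in litellm.replicate_models`.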