From 8b3b682000a2b76bafbc1649a70b8eaee6472328 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Wed, 6 Sep 2023 18:14:33 -0700
Subject: [PATCH] add replicate pricing

---
 litellm/llms/replicate.py             |  7 +++---
 litellm/tests/test_completion_cost.py | 32 ++++++++++++++++++++++++++-
 litellm/utils.py                      | 29 +++++++++++++++++++++++-
 model_prices_and_context_window.json  |  4 +---
 4 files changed, 63 insertions(+), 9 deletions(-)

diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py
index f3c7a52c8a..e63344492f 100644
--- a/litellm/llms/replicate.py
+++ b/litellm/llms/replicate.py
@@ -125,6 +125,7 @@ def completion(
     ## Step1: Start Prediction: gets a prediction url
     ## Step2: Poll prediction url for response
     ## Step2: is handled with and without streaming
+    model_response["created"] = time.time()  # for pricing, this must remain right before calling the api
     prediction_url = start_prediction(version_id, input_data, api_key, logging_obj=logging_obj)
     print_verbose(prediction_url)
 
@@ -134,7 +135,7 @@ def completion(
         return handle_prediction_response_streaming(prediction_url, api_key, print_verbose)
     else:
         result, logs = handle_prediction_response(prediction_url, api_key, print_verbose)
-
+        model_response["ended"] = time.time()  # for pricing, this must remain right after calling the api
         ## LOGGING
         logging_obj.post_call(
             input=prompt,
@@ -154,8 +155,7 @@
         # Calculate usage
         prompt_tokens = len(encoding.encode(prompt))
         completion_tokens = len(encoding.encode(model_response["choices"][0]["message"]["content"]))
-        model_response["created"] = time.time()
-        model_response["model"] = model
+        model_response["model"] = "replicate/" + model
         model_response["usage"] = {
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
@@ -164,7 +164,6 @@
 
     return model_response
 
-
 # # Example usage:
 # response = completion(
 #     api_key="",
diff --git a/litellm/tests/test_completion_cost.py b/litellm/tests/test_completion_cost.py
index 082156ac63..c90a1df2bc 100644
--- a/litellm/tests/test_completion_cost.py
+++ b/litellm/tests/test_completion_cost.py
@@ -35,4 +35,34 @@ def test_completion_togetherai_cost():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_togetherai_cost()
\ No newline at end of file
+# test_completion_togetherai_cost()
+
+
+def test_completion_replicate_llama_2():
+    model_name = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf"
+    try:
+        response = completion(
+            model=model_name,
+            messages=messages,
+            max_tokens=20,
+            custom_llm_provider="replicate"
+        )
+        print(response)
+        # check that the response content is a string
+        response_str = response["choices"][0]["message"]["content"]
+        print(response_str)
+
+        # check the time-based cost calculated for the response
+        print(response)
+        print("Completion Cost for replicate/llama-2-70b-chat:")
+        cost = completion_cost(completion_response=response)
+        formatted_string = f"${float(cost):.10f}"
+        print(formatted_string)
+
+        if not isinstance(response_str, str):
+            pytest.fail("response content is not a string")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+# test_completion_replicate_llama_2()
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 020cca0934..8e0d7bd794 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -529,7 +529,7 @@ def client(original_function):
                 # TODO: Add to cache for streaming
                 return result
 
-
+        
         # [OPTIONAL] ADD TO CACHE
         if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
             litellm.cache.add_cache(result, *args, **kwargs)
@@ -594,6 +594,21 @@ def get_model_params_and_category(model_name):
     return None
 
 
+def get_replicate_completion_pricing(completion_response=None, run_time_in_seconds=0.0):
+    # see https://replicate.com/pricing
+    a100_40gb_price_per_second_public = 0.001150
+
+    # for the LLMs litellm currently supports, almost all requests run on an a100_80gb
+    a100_80gb_price_per_second_public = 0.001400
+
+    start_time = completion_response["created"]
+    end_time = completion_response["ended"]
+    run_time_in_seconds = end_time - start_time
+
+    print("total_replicate_run_time", run_time_in_seconds)
+
+    return a100_80gb_price_per_second_public * run_time_in_seconds
+
 
 def token_counter(model, text):
     # use tiktoken or anthropic's tokenizer depending on the model
@@ -647,6 +662,8 @@ def completion_cost(
         completion="",
         completion_response=None
     ):
+
+    # Handle inputs to completion_cost
     prompt_tokens = 0
     completion_tokens = 0
     if completion_response != None:
@@ -657,10 +674,20 @@
     else:
         prompt_tokens = token_counter(model=model, text=prompt)
         completion_tokens = token_counter(model=model, text=completion)
+
+    # Calculate cost based on prompt_tokens and completion_tokens
     if "togethercomputer" in model:
         # together ai prices based on size of llm
         # get_model_params_and_category takes a model name and returns the category of LLM size it is in model_prices_and_context_window.json
         model = get_model_params_and_category(model)
+    # replicate llms are priced based on the time the request runs
+    # see https://replicate.com/pricing
+    elif (
+        model in litellm.replicate_models or
+        "replicate" in model
+    ):
+        return get_replicate_completion_pricing(completion_response)
+
     prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_token(
         model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
     )
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index 289f8faf9e..d6bf6416a1 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -65,9 +65,7 @@
         "output_cost_per_token": 0.000015
     },
     "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1": {
-        "max_tokens": 4096,
-        "input_cost_per_token": 0.00000608,
-        "output_cost_per_token": 0.00000608
+        "max_tokens": 4096
     },
     "together-ai-up-to-3b": {
         "input_cost_per_token": 0.0000001,
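Note on the pricing path this patch introduces: Replicate bills by execution time rather than per token, so completion_cost short-circuits to get_replicate_completion_pricing, which multiplies the public A100 80GB per-second price by the elapsed time between the "created" and "ended" timestamps stamped onto the response in replicate.py. The standalone sketch below illustrates that calculation under stated assumptions; the helper name and the mock response dict are illustrative only and are not part of litellm.

import time

# Public A100 80GB per-second price referenced in the patch (see https://replicate.com/pricing).
A100_80GB_PRICE_PER_SECOND = 0.001400

def replicate_cost_from_timestamps(completion_response: dict) -> float:
    # Elapsed wall-clock time between the timestamps recorded immediately
    # before and immediately after the Replicate API call.
    run_time_in_seconds = completion_response["ended"] - completion_response["created"]
    return A100_80GB_PRICE_PER_SECOND * run_time_in_seconds

# Hypothetical response representing a 7.5 second Replicate run (not a real litellm object).
start = time.time()
mock_response = {"created": start, "ended": start + 7.5}
print(f"${replicate_cost_from_timestamps(mock_response):.6f}")  # 0.001400 * 7.5 = $0.010500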