diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index b24385b53..4856b5553 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -10,7 +10,7 @@ sys.path.insert(
 ) # Adds the parent directory to the system path
 import pytest
 import litellm
-from litellm import embedding, completion, text_completion
+from litellm import embedding, completion, text_completion, completion_cost
 
 litellm.vertex_project = "pathrise-convert-1606954137718"
 litellm.vertex_location = "us-central1"
@@ -185,14 +185,18 @@ def test_completion_cohere_stream():
 def test_completion_openai():
     try:
         response = completion(model="gpt-3.5-turbo", messages=messages)
-
+
         response_str = response["choices"][0]["message"]["content"]
         response_str_2 = response.choices[0].message.content
+        print("response\n", response)
+        cost = completion_cost(completion_response=response)
+        print("Cost for completion call with gpt-3.5-turbo: ", f"${float(cost):.10f}")
         assert response_str == response_str_2
         assert type(response_str) == str
         assert len(response_str) > 1
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+# test_completion_openai()
 
 
 def test_completion_openai_prompt():
@@ -361,6 +365,8 @@ def test_completion_replicate_llama_2():
             custom_llm_provider="replicate"
         )
         print(response)
+        cost = completion_cost(completion_response=response)
+        print("Cost for completion call with llama-2: ", f"${float(cost):.10f}")
         # Add any assertions here to check the response
         response_str = response["choices"][0]["message"]["content"]
         print(response_str)
@@ -432,9 +438,11 @@ def test_completion_together_ai():
         response = completion(model=model_name, messages=messages, max_tokens=256, logger_fn=logger_fn)
         # Add any assertions here to check the response
         print(response)
+        cost = completion_cost(completion_response=response)
+        print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-
+# test_completion_together_ai()
 # def test_customprompt_together_ai():
 #     try:
 #         litellm.register_prompt_template(
diff --git a/litellm/tests/test_completion_cost.py b/litellm/tests/test_completion_cost.py
deleted file mode 100644
index f77b6e8ef..000000000
--- a/litellm/tests/test_completion_cost.py
+++ /dev/null
@@ -1,68 +0,0 @@
-
-import sys, os
-import traceback
-from dotenv import load_dotenv
-
-load_dotenv()
-import os
-
-sys.path.insert(
-    0, os.path.abspath("../..")
-) # Adds the parent directory to the system path
-import pytest
-import litellm
-from litellm import embedding, completion, text_completion
-from litellm.utils import completion_cost
-
-
-user_message = "Write a short poem about the sky"
-messages = [{"content": user_message, "role": "user"}]
-
-
-def test_completion_togetherai_cost():
-    try:
-        response = completion(
-            model="together_ai/togethercomputer/llama-2-70b-chat",
-            messages=messages,
-            request_timeout=200,
-        )
-        # Add any assertions here to check the response
-        print(response)
-        print("Completion Cost: for togethercomputer/llama-2-70b-chat")
-        cost = completion_cost(completion_response=response)
-        formatted_string = f"${float(cost):.10f}"
-        print(formatted_string)
-
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-# test_completion_togetherai_cost()
-
-
-def test_completion_replicate_llama_2():
-    model_name = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf"
-    try:
-        response = completion(
-            model=model_name,
-            messages=messages,
-            max_tokens=20,
-            custom_llm_provider="replicate"
-        )
-        print(response)
-        # Add any assertions here to check the response
-        response_str = response["choices"][0]["message"]["content"]
-        print(response_str)
-
-        # Add any assertions here to check the response
-        print(response)
-        print("Completion Cost: for togethercomputer/llama-2-70b-chat")
-        cost = completion_cost(completion_response=response)
-        formatted_string = f"${float(cost):.10f}"
-        print(formatted_string)
-
-        if type(response_str) != str:
-            pytest.fail(f"Error occurred: {e}")
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-# v1
-# test_completion_replicate_llama_2()
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 8e0d7bd79..bced7067e 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -594,20 +594,19 @@ def get_model_params_and_category(model_name):
     return None
 
 
-def get_replicate_completion_pricing(completion_response=None, run_time_in_seconds=0.0):
+def get_replicate_completion_pricing(completion_response=None, total_time=0.0):
    # see https://replicate.com/pricing
    a100_40gb_price_per_second_public = 0.001150
 
-    # for all litellm currently supported LLMs, almost all requests go to a100_80gb
-    a100_80gb_price_per_second_public = 0.001400
+    a100_80gb_price_per_second_public = 0.001400 # assume all calls sent to A100 80GB for now
+    if total_time == 0.0:
+        start_time = completion_response['created']
+        end_time = completion_response["ended"]
+        total_time = end_time - start_time
 
-    start_time = completion_response['created']
-    end_time = completion_response["ended"]
-    run_time_in_seconds = end_time - start_time
+    print("total_replicate_run_time", total_time)
 
-    print("total_replicate_run_time", run_time_in_seconds)
-
-    return a100_80gb_price_per_second_public*run_time_in_seconds
+    return a100_80gb_price_per_second_public*total_time
 
 
 def token_counter(model, text):
@@ -657,10 +656,11 @@ def cost_per_token(model="gpt-3.5-turbo", prompt_tokens=0, completion_tokens=0):
 
 
 def completion_cost(
+    completion_response=None,
     model="gpt-3.5-turbo",
     prompt="",
     completion="",
-    completion_response=None
+    total_time=0.0, # used for replicate
 ):
 
     # Handle Inputs to completion_cost
@@ -686,8 +686,7 @@ def completion_cost(
         model in litellm.replicate_models or
         "replicate" in model
     ):
-        return get_replicate_completion_pricing(completion_response)
-
+        return get_replicate_completion_pricing(completion_response, total_time)
     prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = cost_per_token(
        model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
    )
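
Reviewer note: a minimal usage sketch of the calling convention this patch settles on. Model names, parameters, and the cost format string mirror the tests above; the total_time value of 5.0 is a made-up example, and configured API keys are assumed.

    # usage sketch for the reworked completion_cost API (illustrative, not part of the patch)
    from litellm import completion, completion_cost

    messages = [{"content": "Write a short poem about the sky", "role": "user"}]

    # Token-priced models (e.g. gpt-3.5-turbo): cost is derived from the
    # prompt/completion token counts in the response, via cost_per_token.
    response = completion(model="gpt-3.5-turbo", messages=messages)
    cost = completion_cost(completion_response=response)
    print("Cost for completion call with gpt-3.5-turbo: ", f"${float(cost):.10f}")

    # Replicate models: cost is run-time based (A100 80GB price per second).
    # total_time lets the caller pass a measured duration; when it is left at 0.0,
    # the created/ended timestamps on the response are used instead.
    response = completion(
        model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
        messages=messages,
        max_tokens=20,
        custom_llm_provider="replicate",
    )
    cost = completion_cost(completion_response=response, total_time=5.0)  # 5.0 is illustrative
    print("Cost for completion call with llama-2: ", f"${float(cost):.10f}")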