diff --git a/litellm/llms/cohere.py b/litellm/llms/cohere.py index 960dc66d3..a09e249af 100644 --- a/litellm/llms/cohere.py +++ b/litellm/llms/cohere.py @@ -300,8 +300,7 @@ def embedding( for text in input: input_tokens += len(encoding.encode(text)) - model_response["usage"] = { - "prompt_tokens": input_tokens, - "total_tokens": input_tokens, - } + model_response["usage"] = Usage( + prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + ) return model_response diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index c32a55353..7eecca60b 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -117,6 +117,8 @@ def test_openai_azure_embedding_simple(): print("Calculated request cost=", request_cost) + assert isinstance(response.usage, litellm.Usage) + except Exception as e: pytest.fail(f"Error occurred: {e}") @@ -204,6 +206,8 @@ def test_cohere_embedding(): input=["good morning from litellm", "this is another item"], ) print(f"response:", response) + + assert isinstance(response.usage, litellm.Usage) except Exception as e: pytest.fail(f"Error occurred: {e}") @@ -269,6 +273,8 @@ def test_bedrock_embedding_titan(): assert end_time - start_time < 0.1 litellm.disable_cache() + + assert isinstance(response.usage, litellm.Usage) except Exception as e: pytest.fail(f"Error occurred: {e}") @@ -295,6 +301,8 @@ def test_bedrock_embedding_cohere(): isinstance(x, float) for x in response["data"][0]["embedding"] ), "Expected response to be a list of floats" # print(f"response:", response) + + assert isinstance(response.usage, litellm.Usage) except Exception as e: pytest.fail(f"Error occurred: {e}") @@ -331,6 +339,8 @@ def test_hf_embedding(): input=["good morning from litellm", "this is another item"], ) print(f"response:", response) + + assert isinstance(response.usage, litellm.Usage) except Exception as e: # Note: Huggingface inference API is unstable and fails with "model loading errors all the time" pass @@ -386,6 +396,8 @@ def test_aembedding_azure(): response._hidden_params["custom_llm_provider"], ) assert response._hidden_params["custom_llm_provider"] == "azure" + + assert isinstance(response.usage, litellm.Usage) except Exception as e: pytest.fail(f"Error occurred: {e}") @@ -440,6 +452,7 @@ def test_mistral_embeddings(): input=["good morning from litellm"], ) print(f"response: {response}") + assert isinstance(response.usage, litellm.Usage) except Exception as e: pytest.fail(f"Error occurred: {e}")