fix(cohere.py): return usage as a pydantic object not dict

This commit is contained in:
Krrish Dholakia 2024-03-15 10:00:22 -07:00
parent 5edf414a5f
commit 4e1dc7d62e
2 changed files with 16 additions and 4 deletions

View file

@@ -300,8 +300,7 @@ def embedding(
         for text in input:
             input_tokens += len(encoding.encode(text))
-        model_response["usage"] = {
-            "prompt_tokens": input_tokens,
-            "total_tokens": input_tokens,
-        }
+        model_response["usage"] = Usage(
+            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+        )
         return model_response

View file

@@ -117,6 +117,8 @@ def test_openai_azure_embedding_simple():
         print("Calculated request cost=", request_cost)
+
+        assert isinstance(response.usage, litellm.Usage)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -204,6 +206,8 @@ def test_cohere_embedding():
             input=["good morning from litellm", "this is another item"],
         )
         print(f"response:", response)
+
+        assert isinstance(response.usage, litellm.Usage)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -269,6 +273,8 @@ def test_bedrock_embedding_titan():
             assert end_time - start_time < 0.1
             litellm.disable_cache()
+
+        assert isinstance(response.usage, litellm.Usage)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -295,6 +301,8 @@ def test_bedrock_embedding_cohere():
             isinstance(x, float) for x in response["data"][0]["embedding"]
         ), "Expected response to be a list of floats"
         # print(f"response:", response)
+
+        assert isinstance(response.usage, litellm.Usage)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -331,6 +339,8 @@ def test_hf_embedding():
             input=["good morning from litellm", "this is another item"],
         )
         print(f"response:", response)
+
+        assert isinstance(response.usage, litellm.Usage)
     except Exception as e:
         # Note: Huggingface inference API is unstable and fails with "model loading errors all the time"
         pass
@@ -386,6 +396,8 @@ def test_aembedding_azure():
             response._hidden_params["custom_llm_provider"],
         )
         assert response._hidden_params["custom_llm_provider"] == "azure"
+
+        assert isinstance(response.usage, litellm.Usage)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -440,6 +452,7 @@ def test_mistral_embeddings():
             input=["good morning from litellm"],
         )
         print(f"response: {response}")
+        assert isinstance(response.usage, litellm.Usage)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")