diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py
index 7e7ee2ffc..ee655174c 100644
--- a/litellm/llms/bedrock.py
+++ b/litellm/llms/bedrock.py
@@ -483,15 +483,11 @@ def completion(
         import traceback
         raise BedrockError(status_code=500, message=traceback.format_exc())
-
-
-def embedding(
-    model: str,
-    input: list,
-    logging_obj=None,
-    model_response=None,
-    optional_params=None,
-    encoding=None,
+def _embedding_func_single(
+    model: str,
+    input: str,
+    optional_params=None,
+    encoding=None,
 ):
     # logic for parsing in - calling - parsing out model embedding calls
     # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
@@ -509,38 +505,61 @@ def embedding(
             aws_region_name=aws_region_name,
         ),
     )
-
-    # translate to bedrock
-    # bedrock only accepts (str) for inputText
-    if type(input) == list:
-        if len(input) > 1: # input is a list with more than 1 elem, raise Exception, Bedrock only supports one element
-            raise BedrockError(message="Bedrock cannot embed() more than one string - len(input) must always == 1, input = ['hi from litellm']", status_code=400)
-        input_str = "".join(input)
-    response = client.invoke_model(
-        body=json.dumps({
-            "inputText": input_str
-        }),
-        modelId=model,
-        accept="*/*",
-        contentType="application/json"
+    input = input.replace(os.linesep, " ")
+    body = json.dumps({"inputText": input})
+    try:
+        response = client.invoke_model(
+            body=body,
+            modelId=model,
+            accept="application/json",
+            contentType="application/json",
+        )
+        response_body = json.loads(response.get("body").read())
+        return response_body.get("embedding")
+    except Exception as e:
+        raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500)
+
+def embedding(
+    model: str,
+    input: list,
+    api_key: Optional[str] = None,
+    logging_obj=None,
+    model_response=None,
+    optional_params=None,
+    encoding=None,
+):
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": {"model": model,
+                                                 "texts": input}},
     )
-    response_body = json.loads(response.get('body').read())
+    ## Embedding Call
+    embeddings = [_embedding_func_single(model, i, optional_params) for i in input]
 
-    embedding_response = response_body["embedding"]
+    ## Populate OpenAI compliant dictionary
+    embedding_response = []
+    for idx, embedding in enumerate(embeddings):
+        embedding_response.append(
+            {
+                "object": "embedding",
+                "index": idx,
+                "embedding": embedding,
+            }
+        )
     model_response["object"] = "list"
     model_response["data"] = embedding_response
     model_response["model"] = model
     input_tokens = 0
 
-    input_tokens+=len(encoding.encode(input_str))
+    input_str = "".join(input)
 
-    model_response["usage"] = {
-        "prompt_tokens": input_tokens,
-        "total_tokens": input_tokens,
-    }
+    input_tokens+=len(encoding.encode(input_str))
 
     usage = Usage(
         prompt_tokens=input_tokens,
@@ -549,4 +568,13 @@ def embedding(
     )
     model_response.usage = usage
 
+    ## LOGGING
+    logging_obj.post_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": {"model": model,
+                                                 "texts": input}},
+        original_response=embeddings,
+    )
+
     return model_response
diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py
index d533b7c42..546c3dd59 100644
--- a/litellm/tests/test_embedding.py
+++ b/litellm/tests/test_embedding.py
@@ -95,7 +95,8 @@ def test_cohere_embedding3():
 def test_bedrock_embedding():
     try:
         response = embedding(
-            model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data"]
+            model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data",
+                                                       "lets test a second string for good measure"]
         )
         print(f"response:", response)
     except Exception as e: