diff --git a/litellm/__init__.py b/litellm/__init__.py
index 837ca2b01..1ef9b9af2 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -339,7 +339,7 @@ cohere_embedding_models: List = [
     "embed-english-light-v2.0",
     "embed-multilingual-v2.0",
 ]
-bedrock_embedding_models: List = ["amazon.titan-embed-text-v1"]
+bedrock_embedding_models: List = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"]
 
 all_embedding_models = open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py
index 30aa3e6ce..2559027a0 100644
--- a/litellm/llms/bedrock.py
+++ b/litellm/llms/bedrock.py
@@ -2,7 +2,7 @@ import json, copy, types
 import os
 from enum import Enum
 import time
-from typing import Callable, Optional
+from typing import Callable, Optional, Any
 import litellm
 from litellm.utils import ModelResponse, get_secret, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -205,15 +205,25 @@ class AmazonLlamaConfig():
 
 def init_bedrock_client(
     region_name = None,
-    aws_access_key_id = None,
-    aws_secret_access_key = None,
-    aws_region_name=None,
-    aws_bedrock_runtime_endpoint=None,
+    aws_access_key_id: Optional[str] = None,
+    aws_secret_access_key: Optional[str] = None,
+    aws_region_name: Optional[str] = None,
+    aws_bedrock_runtime_endpoint: Optional[str] = None,
 ):
-    # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
-    litellm_aws_region_name = get_secret("AWS_REGION_NAME")
-    standard_aws_region_name = get_secret("AWS_REGION")
+    litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
+    standard_aws_region_name = get_secret("AWS_REGION", None)
+
+    ## CHECK IF 'os.environ/' was passed in
+    # Define the list of parameters to check
+    params_to_check = [aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint]
+
+    # Iterate over parameters and resolve any "os.environ/" references
+    for i, param in enumerate(params_to_check):
+        if param and param.startswith("os.environ/"):
+            params_to_check[i] = get_secret(param)
+    # Assign updated values back to parameters
+    aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint = params_to_check
 
     if region_name:
         pass
     elif aws_region_name:
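
Note: the `os.environ/` convention added to `init_bedrock_client` lets callers pass a reference to an environment variable (e.g. `aws_region_name="os.environ/AWS_REGION_NAME_2"`) instead of a literal value. A minimal standalone sketch of the resolution step, using `os.getenv` in place of litellm's `get_secret` (which can also consult configured secret managers); `resolve_param` is a hypothetical helper name, not part of the PR:

```python
import os
from typing import Optional

def resolve_param(param: Optional[str]) -> Optional[str]:
    # "os.environ/SOME_VAR" -> value of SOME_VAR; anything else passes through
    if param and param.startswith("os.environ/"):
        return os.getenv(param[len("os.environ/"):])
    return param

assert resolve_param("us-east-1") == "us-east-1"
# resolve_param("os.environ/AWS_REGION_NAME_2") reads AWS_REGION_NAME_2 at call time
```

This makes it possible to point different model deployments at different AWS regions or credentials from config alone, which is what the new test at the bottom of this diff exercises.
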
inference_params.get("input_type", "search_document") # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3 + data = {"texts": [input], **inference_params} + body = json.dumps(data).encode("utf-8") + ## LOGGING + request_str = f""" + response = client.invoke_model( + body={body}, + modelId={model}, + accept="*/*", + contentType="application/json", + )""" # type: ignore + logging_obj.pre_call( + input=input, + api_key="", # boto3 is used for init. + additional_args={"complete_input_dict": {"model": model, + "texts": input}, "request_str": request_str}, ) - - input = input.replace(os.linesep, " ") - body = json.dumps({"inputText": input}) try: response = client.invoke_model( body=body, - modelId=model, - accept="application/json", + modelId="cohere.embed-multilingual-v3", + accept="*/*", contentType="application/json", ) response_body = json.loads(response.get("body").read()) - return response_body.get("embedding") + ## LOGGING + logging_obj.post_call( + input=input, + api_key="", + additional_args={"complete_input_dict": data}, + original_response=response_body, + ) + if provider == "cohere": + return response_body.get("embeddings") + elif provider == "amazon": + return response_body.get("embedding") except Exception as e: raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500) @@ -576,17 +605,21 @@ def embedding( optional_params=None, encoding=None, ): + ### BOTO3 INIT ### + # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_region_name = optional_params.pop("aws_region_name", None) - ## LOGGING - logging_obj.pre_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": {"model": model, - "texts": input}}, - ) + # use passed in BedrockRuntime.Client if provided, otherwise create a new one + client = init_bedrock_client( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_region_name=aws_region_name, + ) ## Embedding Call - embeddings = [_embedding_func_single(model, i, optional_params) for i in input] + embeddings = [_embedding_func_single(model, i, optional_params=optional_params, client=client, logging_obj=logging_obj) for i in input] # [TODO]: make these parallel calls ## Populate OpenAI compliant dictionary diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index f3d3c5fc0..2d32df1cc 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -843,6 +843,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap data["user"] = user_api_key_dict.user_id if "metadata" in data: + print(f'received metadata: {data["metadata"]}') data["metadata"]["user_api_key"] = user_api_key_dict.api_key data["metadata"]["headers"] = request.headers else: diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index f958d6cfc..5f1043728 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -151,7 +151,7 @@ def test_cohere_embedding3(): # test_cohere_embedding3() -def test_bedrock_embedding(): +def test_bedrock_embedding_titan(): try: response = embedding( model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data", @@ -162,6 +162,19 @@ def 
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index f3d3c5fc0..2d32df1cc 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -843,6 +843,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
             data["user"] = user_api_key_dict.user_id
 
         if "metadata" in data:
+            print(f'received metadata: {data["metadata"]}')
             data["metadata"]["user_api_key"] = user_api_key_dict.api_key
             data["metadata"]["headers"] = request.headers
         else:
diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py
index f958d6cfc..5f1043728 100644
--- a/litellm/tests/test_embedding.py
+++ b/litellm/tests/test_embedding.py
@@ -151,7 +151,7 @@ def test_cohere_embedding3():
 
 # test_cohere_embedding3()
 
-def test_bedrock_embedding():
+def test_bedrock_embedding_titan():
     try:
         response = embedding(
             model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data",
@@ -162,6 +162,19 @@ def test_bedrock_embedding():
         pytest.fail(f"Error occurred: {e}")
 
 # test_bedrock_embedding()
 
+def test_bedrock_embedding_cohere():
+    try:
+        # litellm.set_verbose=True
+        response = embedding(
+            model="cohere.embed-multilingual-v3", input=["good morning from litellm, attempting to embed data",
+                "lets test a second string for good measure"],
+            aws_region_name="os.environ/AWS_REGION_NAME_2"
+        )
+        # print(f"response:", response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+test_bedrock_embedding_cohere()
+
 # comment out hf tests - since hf endpoints are unstable
 def test_hf_embedding():
     try: