forked from phoenix/litellm-mirror
fix(bedrock.py): adding support for cohere embeddings
This commit is contained in:
parent
cf6ecc03a5
commit
d962d5d4c0
4 changed files with 85 additions and 38 deletions
|
@ -339,7 +339,7 @@ cohere_embedding_models: List = [
|
|||
"embed-english-light-v2.0",
|
||||
"embed-multilingual-v2.0",
|
||||
]
|
||||
bedrock_embedding_models: List = ["amazon.titan-embed-text-v1"]
|
||||
bedrock_embedding_models: List = ["amazon.titan-embed-text-v1", "cohere.embed-english-v3", "cohere.embed-multilingual-v3"]
|
||||
|
||||
all_embedding_models = open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ import json, copy, types
|
|||
import os
|
||||
from enum import Enum
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Optional, Any
|
||||
import litellm
|
||||
from litellm.utils import ModelResponse, get_secret, Usage
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
@ -205,15 +205,25 @@ class AmazonLlamaConfig():
|
|||
|
||||
def init_bedrock_client(
|
||||
region_name = None,
|
||||
aws_access_key_id = None,
|
||||
aws_secret_access_key = None,
|
||||
aws_region_name=None,
|
||||
aws_bedrock_runtime_endpoint=None,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_secret_access_key: Optional[str] = None,
|
||||
aws_region_name: Optional[str] =None,
|
||||
aws_bedrock_runtime_endpoint: Optional[str]=None,
|
||||
):
|
||||
|
||||
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME")
|
||||
standard_aws_region_name = get_secret("AWS_REGION")
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
|
||||
## CHECK IF 'os.environ/' IS PASSED IN
|
||||
# Define the list of parameters to check
|
||||
params_to_check = [aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint]
|
||||
|
||||
# Iterate over parameters and update if needed
|
||||
for i, param in enumerate(params_to_check):
|
||||
if param and param.startswith('os.environ/'):
|
||||
params_to_check[i] = get_secret(param)
|
||||
# Assign updated values back to parameters
|
||||
aws_access_key_id, aws_secret_access_key, aws_region_name, aws_bedrock_runtime_endpoint = params_to_check
|
||||
if region_name:
|
||||
pass
|
||||
elif aws_region_name:
|
||||
|
@ -533,36 +543,55 @@ def completion(
|
|||
def _embedding_func_single(
|
||||
model: str,
|
||||
input: str,
|
||||
client: Any,
|
||||
optional_params=None,
|
||||
encoding=None,
|
||||
logging_obj=None,
|
||||
):
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
|
||||
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
|
||||
client = optional_params.pop(
|
||||
"aws_bedrock_client",
|
||||
# only pass variables that are not None
|
||||
init_bedrock_client(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
),
|
||||
)
|
||||
|
||||
## FORMAT EMBEDDING INPUT ##
|
||||
provider = model.split(".")[0]
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
if provider == "amazon":
|
||||
input = input.replace(os.linesep, " ")
|
||||
body = json.dumps({"inputText": input})
|
||||
data = {"inputText": input, **inference_params}
|
||||
# data = json.dumps(data)
|
||||
elif provider == "cohere":
|
||||
inference_params["input_type"] = inference_params.get("input_type", "search_document") # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
|
||||
data = {"texts": [input], **inference_params}
|
||||
body = json.dumps(data).encode("utf-8")
|
||||
## LOGGING
|
||||
request_str = f"""
|
||||
response = client.invoke_model(
|
||||
body={body},
|
||||
modelId={model},
|
||||
accept="*/*",
|
||||
contentType="application/json",
|
||||
)""" # type: ignore
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key="", # boto3 is used for init.
|
||||
additional_args={"complete_input_dict": {"model": model,
|
||||
"texts": input}, "request_str": request_str},
|
||||
)
|
||||
try:
|
||||
response = client.invoke_model(
|
||||
body=body,
|
||||
modelId=model,
|
||||
accept="application/json",
|
||||
modelId="cohere.embed-multilingual-v3",
|
||||
accept="*/*",
|
||||
contentType="application/json",
|
||||
)
|
||||
response_body = json.loads(response.get("body").read())
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response_body,
|
||||
)
|
||||
if provider == "cohere":
|
||||
return response_body.get("embeddings")
|
||||
elif provider == "amazon":
|
||||
return response_body.get("embedding")
|
||||
except Exception as e:
|
||||
raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500)
|
||||
|
@ -576,17 +605,21 @@ def embedding(
|
|||
optional_params=None,
|
||||
encoding=None,
|
||||
):
|
||||
### BOTO3 INIT ###
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": {"model": model,
|
||||
"texts": input}},
|
||||
# create one BedrockRuntime client up front and reuse it for every input below
|
||||
client = init_bedrock_client(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
)
|
||||
|
||||
## Embedding Call
|
||||
embeddings = [_embedding_func_single(model, i, optional_params) for i in input]
|
||||
embeddings = [_embedding_func_single(model, i, optional_params=optional_params, client=client, logging_obj=logging_obj) for i in input] # [TODO]: make these parallel calls
|
||||
|
||||
|
||||
## Populate OpenAI compliant dictionary
|
||||
|
|
|
@ -843,6 +843,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
|
|||
data["user"] = user_api_key_dict.user_id
|
||||
|
||||
if "metadata" in data:
|
||||
print(f'received metadata: {data["metadata"]}')
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
||||
data["metadata"]["headers"] = request.headers
|
||||
else:
|
||||
|
|
|
@ -151,7 +151,7 @@ def test_cohere_embedding3():
|
|||
|
||||
# test_cohere_embedding3()
|
||||
|
||||
def test_bedrock_embedding():
|
||||
def test_bedrock_embedding_titan():
|
||||
try:
|
||||
response = embedding(
|
||||
model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data",
|
||||
|
@ -162,6 +162,19 @@ def test_bedrock_embedding():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_bedrock_embedding()
|
||||
|
||||
def test_bedrock_embedding_cohere():
|
||||
try:
|
||||
# litellm.set_verbose=True
|
||||
response = embedding(
|
||||
model="cohere.embed-multilingual-v3", input=["good morning from litellm, attempting to embed data", "lets test a second string for good measure"],
|
||||
aws_region_name="os.environ/AWS_REGION_NAME_2"
|
||||
)
|
||||
# print(f"response:", response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
test_bedrock_embedding_cohere()
|
||||
|
||||
# comment out hf tests - since hf endpoints are unstable
|
||||
def test_hf_embedding():
|
||||
try:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue