fix(bedrock_httpx.py): move bedrock ai21 calls to being async

Author: Krrish Dholakia, 2024-05-16 22:21:30 -07:00
Parent: 180bc46ca4
Commit: 0293f7766a

5 changed files with 88 additions and 71 deletions
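
For context, this commit routes Bedrock AI21 (Jurassic) models through the async, httpx-based BedrockLLM client, the same path already used for the Bedrock cohere and anthropic models. Below is a minimal sketch of the call the updated tests exercise; it assumes AWS credentials for Bedrock are available in the environment, and the model id and parameters are taken from the test parametrization in this diff.

    import asyncio

    import litellm


    async def main():
        # Async Bedrock AI21 call; with this change it takes the httpx-based
        # code path shown in the diff below.
        response = await litellm.acompletion(
            model="bedrock/ai21.j2-mid",
            messages=[{"role": "user", "content": "Hey! how's it going?"}],
            temperature=0.2,
            max_tokens=200,
        )
        print(response)


    asyncio.run(main())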


@@ -421,7 +421,12 @@ class BedrockLLM(BaseLLM):
                     setattr(model_response, "usage", _usage)
                 else:
                     outputText = completion_response["completion"]
                     model_response["finish_reason"] = completion_response["stop_reason"]
+            elif provider == "ai21":
+                outputText = (
+                    completion_response.get("completions")[0].get("data").get("text")
+                )
         except Exception as e:
             raise BedrockError(
                 message="Error processing={}, Received error={}".format(
@@ -430,6 +435,49 @@ class BedrockLLM(BaseLLM):
                 status_code=422,
             )
+
+        try:
+            if (
+                len(outputText) > 0
+                and hasattr(model_response.choices[0], "message")
+                and getattr(model_response.choices[0].message, "tool_calls", None)
+                is None
+            ):
+                model_response["choices"][0]["message"]["content"] = outputText
+            elif (
+                hasattr(model_response.choices[0], "message")
+                and getattr(model_response.choices[0].message, "tool_calls", None)
+                is not None
+            ):
+                pass
+            else:
+                raise Exception()
+        except:
+            raise BedrockError(
+                message=json.dumps(outputText), status_code=response.status_code
+            )
+
+        if stream and provider == "ai21":
+            streaming_model_response = ModelResponse(stream=True)
+            streaming_model_response.choices[0].finish_reason = model_response.choices[  # type: ignore
+                0
+            ].finish_reason
+            # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+            streaming_choice = litellm.utils.StreamingChoices()
+            streaming_choice.index = model_response.choices[0].index
+            delta_obj = litellm.utils.Delta(
+                content=getattr(model_response.choices[0].message, "content", None),
+                role=model_response.choices[0].message.role,
+            )
+            streaming_choice.delta = delta_obj
+            streaming_model_response.choices = [streaming_choice]
+            mri = ModelResponseIterator(model_response=streaming_model_response)
+            return CustomStreamWrapper(
+                completion_stream=mri,
+                model=model,
+                custom_llm_provider="cached_response",
+                logging_obj=logging_obj,
+            )
 
         ## CALCULATING USAGE - bedrock returns usage in the headers
         bedrock_input_tokens = response.headers.get(
             "x-amzn-bedrock-input-token-count", None
@@ -489,6 +537,7 @@ class BedrockLLM(BaseLLM):
 
         ## SETUP ##
         stream = optional_params.pop("stream", None)
+        provider = model.split(".")[0]
 
         ## CREDENTIALS ##
         # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
@@ -544,14 +593,13 @@ class BedrockLLM(BaseLLM):
         else:
             endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
 
-        if stream is not None and stream == True:
+        if (stream is not None and stream == True) and provider != "ai21":
             endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream"
         else:
             endpoint_url = f"{endpoint_url}/model/{model}/invoke"
 
         sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
 
-        provider = model.split(".")[0]
         prompt, chat_history = self.convert_messages_to_prompt(
             model, messages, provider, custom_prompt_dict
         )
@@ -633,6 +681,16 @@ class BedrockLLM(BaseLLM):
                 ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                     inference_params[k] = v
             data = json.dumps({"prompt": prompt, **inference_params})
+        elif provider == "ai21":
+            ## LOAD CONFIG
+            config = litellm.AmazonAI21Config.get_config()
+            for k, v in config.items():
+                if (
+                    k not in inference_params
+                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+                    inference_params[k] = v
+            data = json.dumps({"prompt": prompt, **inference_params})
         else:
             raise Exception("UNSUPPORTED PROVIDER")
@@ -662,7 +720,7 @@ class BedrockLLM(BaseLLM):
         if acompletion:
             if isinstance(client, HTTPHandler):
                 client = None
-            if stream:
+            if stream == True and provider != "ai21":
                 return self.async_streaming(
                     model=model,
                     messages=messages,
@@ -691,7 +749,7 @@ class BedrockLLM(BaseLLM):
                 encoding=encoding,
                 logging_obj=logging_obj,
                 optional_params=optional_params,
-                stream=False,
+                stream=stream,  # type: ignore
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 headers=prepped.headers,
@@ -708,7 +766,7 @@ class BedrockLLM(BaseLLM):
             self.client = HTTPHandler(**_params)  # type: ignore
         else:
             self.client = client
-        if stream is not None and stream == True:
+        if (stream is not None and stream == True) and provider != "ai21":
             response = self.client.post(
                 url=prepped.url,
                 headers=prepped.headers,  # type: ignore
@@ -787,7 +845,7 @@ class BedrockLLM(BaseLLM):
             model=model,
             response=response,
             model_response=model_response,
-            stream=stream,
+            stream=stream if isinstance(stream, bool) else False,
             logging_obj=logging_obj,
             api_key="",
             data=data,


@@ -328,7 +328,7 @@ async def acompletion(
         or custom_llm_provider == "predibase"
         or (
             custom_llm_provider == "bedrock"
-            and ("cohere" in model or "anthropic" in model)
+            and ("cohere" in model or "anthropic" in model or "ai21" in model)
         )
         or custom_llm_provider in litellm.openai_compatible_providers
     ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
@@ -1982,7 +1982,7 @@ def completion(
             # boto3 reads keys from .env
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
 
-            if "cohere" in model or "anthropic" in model:
+            if "cohere" in model or "anthropic" in model or "ai21" in model:
                 response = bedrock_chat_completion.completion(
                     model=model,
                     messages=messages,


@@ -2665,7 +2665,12 @@ def response_format_tests(response: litellm.ModelResponse):
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.parametrize(
     "model",
-    ["bedrock/cohere.command-r-plus-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"],
+    [
+        "bedrock/cohere.command-r-plus-v1:0",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "anthropic.claude-instant-v1",
+        "bedrock/ai21.j2-mid",
+    ],
 )
 @pytest.mark.asyncio
 async def test_completion_bedrock_httpx_models(sync_mode, model):
@@ -2675,6 +2680,8 @@ async def test_completion_bedrock_httpx_models(sync_mode, model):
         response = completion(
             model=model,
             messages=[{"role": "user", "content": "Hey! how's it going?"}],
+            temperature=0.2,
+            max_tokens=200,
         )
 
         assert isinstance(response, litellm.ModelResponse)
@@ -2684,6 +2691,8 @@ async def test_completion_bedrock_httpx_models(sync_mode, model):
         response = await litellm.acompletion(
             model=model,
             messages=[{"role": "user", "content": "Hey! how's it going?"}],
+            temperature=0.2,
+            max_tokens=200,
         )
 
         assert isinstance(response, litellm.ModelResponse)
@@ -2740,48 +2749,9 @@ def test_completion_bedrock_titan():
 # test_completion_bedrock_titan()
 
 
-def test_completion_bedrock_claude():
-    print("calling claude")
-    try:
-        response = completion(
-            model="anthropic.claude-instant-v1",
-            messages=messages,
-            max_tokens=10,
-            temperature=0.1,
-            logger_fn=logger_fn,
-        )
-        # Add any assertions here to check the response
-        print(response)
-    except RateLimitError:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
 # test_completion_bedrock_claude()
 
 
-def test_completion_bedrock_cohere():
-    print("calling bedrock cohere")
-    litellm.set_verbose = True
-    try:
-        response = completion(
-            model="bedrock/cohere.command-text-v14",
-            messages=[{"role": "user", "content": "hi"}],
-            temperature=0.1,
-            max_tokens=10,
-            stream=True,
-        )
-        # Add any assertions here to check the response
-        print(response)
-        for chunk in response:
-            print(chunk)
-    except RateLimitError:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
 # test_completion_bedrock_cohere()
@@ -2804,23 +2774,6 @@ def test_completion_bedrock_cohere():
 #         pytest.fail(f"Error occurred: {e}")
 # test_completion_bedrock_claude_stream()
 
-
-# def test_completion_bedrock_ai21():
-#     try:
-#         litellm.set_verbose = False
-#         response = completion(
-#             model="bedrock/ai21.j2-mid",
-#             messages=messages,
-#             temperature=0.2,
-#             top_p=0.2,
-#             max_tokens=20
-#         )
-#         # Add any assertions here to check the response
-#         print(response)
-#     except RateLimitError:
-#         pass
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
 
 ######## Test VLLM ########
 # def test_completion_vllm():


@@ -1044,8 +1044,10 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
 @pytest.mark.parametrize(
     "model",
     [
-        # "bedrock/cohere.command-r-plus-v1:0",
-        "anthropic.claude-3-sonnet-20240229-v1:0"
+        "bedrock/cohere.command-r-plus-v1:0",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "anthropic.claude-instant-v1",
+        "bedrock/ai21.j2-mid",
     ],
 )
 @pytest.mark.asyncio


@@ -11510,7 +11510,11 @@ class CustomStreamWrapper:
             or self.custom_llm_provider == "predibase"
             or (
                 self.custom_llm_provider == "bedrock"
-                and ("cohere" in self.model or "anthropic" in self.model)
+                and (
+                    "cohere" in self.model
+                    or "anthropic" in self.model
+                    or "ai21" in self.model
+                )
             )
             or self.custom_llm_provider in litellm.openai_compatible_endpoints
         ):