fix(bedrock_httpx.py): move bedrock ai21 calls to being async
This commit is contained in: parent 180bc46ca4, commit 0293f7766a
5 changed files with 88 additions and 71 deletions
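
What the change means in practice: Bedrock AI21 models join Cohere and Anthropic on the async httpx `BedrockLLM` path, so `litellm.acompletion` can await the provider call directly instead of going through the synchronous fallback. A minimal usage sketch (model ID and parameters are taken from the tests further down; AWS credentials are assumed to be configured in the environment):

import asyncio

import litellm


async def main():
    # After this commit, bedrock/ai21.* routes through BedrockLLM's async httpx client.
    response = await litellm.acompletion(
        model="bedrock/ai21.j2-mid",
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
        temperature=0.2,
        max_tokens=200,
    )
    print(response.choices[0].message.content)


asyncio.run(main())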
@@ -421,7 +421,12 @@ class BedrockLLM(BaseLLM):
                     setattr(model_response, "usage", _usage)
                 else:
                     outputText = completion_response["completion"]
+
                     model_response["finish_reason"] = completion_response["stop_reason"]
+            elif provider == "ai21":
+                outputText = (
+                    completion_response.get("completions")[0].get("data").get("text")
+                )
         except Exception as e:
             raise BedrockError(
                 message="Error processing={}, Received error={}".format(
@@ -430,6 +435,49 @@ class BedrockLLM(BaseLLM):
                 status_code=422,
             )

+        try:
+            if (
+                len(outputText) > 0
+                and hasattr(model_response.choices[0], "message")
+                and getattr(model_response.choices[0].message, "tool_calls", None)
+                is None
+            ):
+                model_response["choices"][0]["message"]["content"] = outputText
+            elif (
+                hasattr(model_response.choices[0], "message")
+                and getattr(model_response.choices[0].message, "tool_calls", None)
+                is not None
+            ):
+                pass
+            else:
+                raise Exception()
+        except:
+            raise BedrockError(
+                message=json.dumps(outputText), status_code=response.status_code
+            )
+
+        if stream and provider == "ai21":
+            streaming_model_response = ModelResponse(stream=True)
+            streaming_model_response.choices[0].finish_reason = model_response.choices[  # type: ignore
+                0
+            ].finish_reason
+            # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+            streaming_choice = litellm.utils.StreamingChoices()
+            streaming_choice.index = model_response.choices[0].index
+            delta_obj = litellm.utils.Delta(
+                content=getattr(model_response.choices[0].message, "content", None),
+                role=model_response.choices[0].message.role,
+            )
+            streaming_choice.delta = delta_obj
+            streaming_model_response.choices = [streaming_choice]
+            mri = ModelResponseIterator(model_response=streaming_model_response)
+            return CustomStreamWrapper(
+                completion_stream=mri,
+                model=model,
+                custom_llm_provider="cached_response",
+                logging_obj=logging_obj,
+            )
+
         ## CALCULATING USAGE - bedrock returns usage in the headers
         bedrock_input_tokens = response.headers.get(
             "x-amzn-bedrock-input-token-count", None
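The block added above is how `stream=True` keeps working for AI21: the provider is excluded from the `invoke-with-response-stream` endpoint (see the endpoint change below), so the fully materialized response is wrapped in a `ModelResponseIterator` and returned through `CustomStreamWrapper` with `custom_llm_provider="cached_response"`. A consumer-side sketch, assuming the same model ID used in the tests (prompt and token limits are illustrative):

import litellm

# stream=True still yields an iterable of chunks for AI21 on Bedrock;
# under the hood it is one complete response replayed as a pseudo-stream.
response = litellm.completion(
    model="bedrock/ai21.j2-mid",
    messages=[{"role": "user", "content": "hi"}],
    max_tokens=10,
    stream=True,
)
for chunk in response:
    print(chunk)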
@@ -489,6 +537,7 @@ class BedrockLLM(BaseLLM):

         ## SETUP ##
         stream = optional_params.pop("stream", None)
+        provider = model.split(".")[0]

         ## CREDENTIALS ##
         # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
@@ -544,14 +593,13 @@ class BedrockLLM(BaseLLM):
         else:
             endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"

-        if stream is not None and stream == True:
+        if (stream is not None and stream == True) and provider != "ai21":
             endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream"
         else:
             endpoint_url = f"{endpoint_url}/model/{model}/invoke"

         sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)

-        provider = model.split(".")[0]
         prompt, chat_history = self.convert_messages_to_prompt(
             model, messages, provider, custom_prompt_dict
         )
@@ -633,6 +681,16 @@ class BedrockLLM(BaseLLM):
                 ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                     inference_params[k] = v
             data = json.dumps({"prompt": prompt, **inference_params})
+        elif provider == "ai21":
+            ## LOAD CONFIG
+            config = litellm.AmazonAI21Config.get_config()
+            for k, v in config.items():
+                if (
+                    k not in inference_params
+                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+                    inference_params[k] = v
+
+            data = json.dumps({"prompt": prompt, **inference_params})
         else:
             raise Exception("UNSUPPORTED PROVIDER")

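For the AI21 branch added above, the request body is the same prompt-plus-params JSON shape used by the other Bedrock text-completion providers. A small sketch of the payload assembly (the prompt string and parameter names here are illustrative stand-ins, not output copied from litellm):

import json

# In the diff, `prompt` comes from convert_messages_to_prompt() and
# `inference_params` is optional_params merged with litellm.AmazonAI21Config.get_config().
prompt = "Human: Hey! how's it going?\n\nAssistant:"  # illustrative
inference_params = {"maxTokens": 200, "temperature": 0.2}  # illustrative keys

data = json.dumps({"prompt": prompt, **inference_params})
print(data)  # {"prompt": "...", "maxTokens": 200, "temperature": 0.2}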
@@ -662,7 +720,7 @@ class BedrockLLM(BaseLLM):
         if acompletion:
             if isinstance(client, HTTPHandler):
                 client = None
-            if stream:
+            if stream == True and provider != "ai21":
                 return self.async_streaming(
                     model=model,
                     messages=messages,
@@ -691,7 +749,7 @@ class BedrockLLM(BaseLLM):
                 encoding=encoding,
                 logging_obj=logging_obj,
                 optional_params=optional_params,
-                stream=False,
+                stream=stream,  # type: ignore
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 headers=prepped.headers,
@@ -708,7 +766,7 @@ class BedrockLLM(BaseLLM):
             self.client = HTTPHandler(**_params)  # type: ignore
         else:
             self.client = client
-        if stream is not None and stream == True:
+        if (stream is not None and stream == True) and provider != "ai21":
             response = self.client.post(
                 url=prepped.url,
                 headers=prepped.headers,  # type: ignore
@@ -787,7 +845,7 @@ class BedrockLLM(BaseLLM):
             model=model,
             response=response,
             model_response=model_response,
-            stream=stream,
+            stream=stream if isinstance(stream, bool) else False,
             logging_obj=logging_obj,
             api_key="",
             data=data,
@@ -328,7 +328,7 @@ async def acompletion(
         or custom_llm_provider == "predibase"
         or (
             custom_llm_provider == "bedrock"
-            and ("cohere" in model or "anthropic" in model)
+            and ("cohere" in model or "anthropic" in model or "ai21" in model)
         )
         or custom_llm_provider in litellm.openai_compatible_providers
     ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
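The widened condition above is what lets `acompletion` treat Bedrock AI21 like the other natively-async providers. A simplified sketch of the dispatch pattern this condition gates (not the exact litellm code; `completion_fn` and `provider_supports_async` are hypothetical names for illustration):

import asyncio
from functools import partial


async def dispatch(provider_supports_async: bool, completion_fn, **kwargs):
    if provider_supports_async:
        # e.g. bedrock + cohere/anthropic/ai21 after this commit:
        # the underlying call already returns an awaitable.
        return await completion_fn(**kwargs)
    # Otherwise run the blocking completion in a worker thread.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, partial(completion_fn, **kwargs))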
@@ -1982,7 +1982,7 @@ def completion(
             # boto3 reads keys from .env
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict

-            if "cohere" in model or "anthropic" in model:
+            if "cohere" in model or "anthropic" in model or "ai21" in model:
                 response = bedrock_chat_completion.completion(
                     model=model,
                     messages=messages,
@@ -2665,7 +2665,12 @@ def response_format_tests(response: litellm.ModelResponse):
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.parametrize(
     "model",
-    ["bedrock/cohere.command-r-plus-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"],
+    [
+        "bedrock/cohere.command-r-plus-v1:0",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "anthropic.claude-instant-v1",
+        "bedrock/ai21.j2-mid",
+    ],
 )
 @pytest.mark.asyncio
 async def test_completion_bedrock_httpx_models(sync_mode, model):
@@ -2675,6 +2680,8 @@ async def test_completion_bedrock_httpx_models(sync_mode, model):
             response = completion(
                 model=model,
                 messages=[{"role": "user", "content": "Hey! how's it going?"}],
+                temperature=0.2,
+                max_tokens=200,
             )

             assert isinstance(response, litellm.ModelResponse)
@@ -2684,6 +2691,8 @@ async def test_completion_bedrock_httpx_models(sync_mode, model):
             response = await litellm.acompletion(
                 model=model,
                 messages=[{"role": "user", "content": "Hey! how's it going?"}],
+                temperature=0.2,
+                max_tokens=200,
             )

             assert isinstance(response, litellm.ModelResponse)
@@ -2740,48 +2749,9 @@ def test_completion_bedrock_titan():
 # test_completion_bedrock_titan()


-def test_completion_bedrock_claude():
-    print("calling claude")
-    try:
-        response = completion(
-            model="anthropic.claude-instant-v1",
-            messages=messages,
-            max_tokens=10,
-            temperature=0.1,
-            logger_fn=logger_fn,
-        )
-        # Add any assertions here to check the response
-        print(response)
-    except RateLimitError:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
 # test_completion_bedrock_claude()


-def test_completion_bedrock_cohere():
-    print("calling bedrock cohere")
-    litellm.set_verbose = True
-    try:
-        response = completion(
-            model="bedrock/cohere.command-text-v14",
-            messages=[{"role": "user", "content": "hi"}],
-            temperature=0.1,
-            max_tokens=10,
-            stream=True,
-        )
-        # Add any assertions here to check the response
-        print(response)
-        for chunk in response:
-            print(chunk)
-    except RateLimitError:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
 # test_completion_bedrock_cohere()


@@ -2804,23 +2774,6 @@ def test_completion_bedrock_cohere():
 #         pytest.fail(f"Error occurred: {e}")
 # test_completion_bedrock_claude_stream()

-# def test_completion_bedrock_ai21():
-#     try:
-#         litellm.set_verbose = False
-#         response = completion(
-#             model="bedrock/ai21.j2-mid",
-#             messages=messages,
-#             temperature=0.2,
-#             top_p=0.2,
-#             max_tokens=20
-#         )
-#         # Add any assertions here to check the response
-#         print(response)
-#     except RateLimitError:
-#         pass
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-

 ######## Test VLLM ########
 # def test_completion_vllm():
@@ -1044,8 +1044,10 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
 @pytest.mark.parametrize(
     "model",
     [
-        # "bedrock/cohere.command-r-plus-v1:0",
-        "anthropic.claude-3-sonnet-20240229-v1:0"
+        "bedrock/cohere.command-r-plus-v1:0",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "anthropic.claude-instant-v1",
+        "bedrock/ai21.j2-mid",
     ],
 )
 @pytest.mark.asyncio
@@ -11510,7 +11510,11 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "predibase"
                 or (
                     self.custom_llm_provider == "bedrock"
-                    and ("cohere" in self.model or "anthropic" in self.model)
+                    and (
+                        "cohere" in self.model
+                        or "anthropic" in self.model
+                        or "ai21" in self.model
+                    )
                 )
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):