forked from phoenix/litellm-mirror
fix(bedrock_httpx.py): move bedrock ai21 calls to being async
parent 180bc46ca4
commit 0293f7766a

5 changed files with 88 additions and 71 deletions
@@ -419,9 +419,14 @@ class BedrockLLM(BaseLLM):
                         + completion_response["usage"]["output_tokens"],
                     )
                     setattr(model_response, "usage", _usage)
                 else:
                     outputText = completion_response["completion"]
-                    model_response["finish_reason"] = completion_response["stop_reason"]
+                    model_response["finish_reason"] = completion_response["stop_reason"]
+            elif provider == "ai21":
+                outputText = (
+                    completion_response.get("completions")[0].get("data").get("text")
+                )
         except Exception as e:
             raise BedrockError(
                 message="Error processing={}, Received error={}".format(
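The new ai21 branch above pulls the generated text out of a Jurassic-2 style response body. A minimal sketch of that lookup, assuming a body of the shape {"completions": [{"data": {"text": ...}}]}; the helper name and sample payload are illustrative, not part of litellm:

# Sketch only: same lookup chain as the diff (completions[0].data.text).
import json

def extract_ai21_text(raw_body: str) -> str:
    completion_response = json.loads(raw_body)
    return completion_response.get("completions")[0].get("data").get("text")

example = '{"completions": [{"data": {"text": "Hello from Jurassic-2"}}]}'
print(extract_ai21_text(example))  # -> Hello from Jurassic-2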
@@ -430,6 +435,49 @@ class BedrockLLM(BaseLLM):
                 status_code=422,
             )

+        try:
+            if (
+                len(outputText) > 0
+                and hasattr(model_response.choices[0], "message")
+                and getattr(model_response.choices[0].message, "tool_calls", None)
+                is None
+            ):
+                model_response["choices"][0]["message"]["content"] = outputText
+            elif (
+                hasattr(model_response.choices[0], "message")
+                and getattr(model_response.choices[0].message, "tool_calls", None)
+                is not None
+            ):
+                pass
+            else:
+                raise Exception()
+        except:
+            raise BedrockError(
+                message=json.dumps(outputText), status_code=response.status_code
+            )
+
+        if stream and provider == "ai21":
+            streaming_model_response = ModelResponse(stream=True)
+            streaming_model_response.choices[0].finish_reason = model_response.choices[  # type: ignore
+                0
+            ].finish_reason
+            # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+            streaming_choice = litellm.utils.StreamingChoices()
+            streaming_choice.index = model_response.choices[0].index
+            delta_obj = litellm.utils.Delta(
+                content=getattr(model_response.choices[0].message, "content", None),
+                role=model_response.choices[0].message.role,
+            )
+            streaming_choice.delta = delta_obj
+            streaming_model_response.choices = [streaming_choice]
+            mri = ModelResponseIterator(model_response=streaming_model_response)
+            return CustomStreamWrapper(
+                completion_stream=mri,
+                model=model,
+                custom_llm_provider="cached_response",
+                logging_obj=logging_obj,
+            )
+
         ## CALCULATING USAGE - bedrock returns usage in the headers
         bedrock_input_tokens = response.headers.get(
             "x-amzn-bedrock-input-token-count", None
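Because ai21 never hits the invoke-with-response-stream route here, the added block fakes a stream: the already-complete response is copied into a single delta chunk and handed to CustomStreamWrapper through ModelResponseIterator with custom_llm_provider="cached_response". A standalone sketch of the same single-chunk idea, with illustrative names only (not litellm internals):

# Sketch: wrap one finished response as a one-chunk sync/async iterator.
class SingleChunkIterator:
    def __init__(self, chunk):
        self._chunks = [chunk]

    def __iter__(self):
        return iter(self._chunks)

    def __aiter__(self):
        self._idx = 0
        return self

    async def __anext__(self):
        if self._idx >= len(self._chunks):
            raise StopAsyncIteration
        self._idx += 1
        return self._chunks[self._idx - 1]

# A caller iterating this "stream" sees exactly one delta carrying the full
# text, followed by end-of-stream, which is what the ai21 branch produces.
for delta in SingleChunkIterator("full ai21 response text"):
    print(delta)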
@@ -489,6 +537,7 @@ class BedrockLLM(BaseLLM):

         ## SETUP ##
         stream = optional_params.pop("stream", None)
+        provider = model.split(".")[0]

         ## CREDENTIALS ##
         # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
@@ -544,14 +593,13 @@ class BedrockLLM(BaseLLM):
         else:
             endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"

-        if stream is not None and stream == True:
+        if (stream is not None and stream == True) and provider != "ai21":
             endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream"
         else:
             endpoint_url = f"{endpoint_url}/model/{model}/invoke"

         sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)

-        provider = model.split(".")[0]
         prompt, chat_history = self.convert_messages_to_prompt(
             model, messages, provider, custom_prompt_dict
         )
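With the provider check added above, a streaming request for an ai21 model still targets the plain invoke endpoint; only the other providers use the response-stream route. A small sketch of that selection logic (the function name is illustrative):

# Sketch of the endpoint choice after this change.
def bedrock_invoke_url(endpoint_url: str, model: str, stream, provider: str) -> str:
    if (stream is not None and stream == True) and provider != "ai21":
        return f"{endpoint_url}/model/{model}/invoke-with-response-stream"
    return f"{endpoint_url}/model/{model}/invoke"

print(bedrock_invoke_url("https://bedrock-runtime.us-east-1.amazonaws.com",
                         "ai21.j2-mid", True, "ai21"))
# -> .../model/ai21.j2-mid/invoke  (ai21 never hits the response-stream route)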
@@ -633,6 +681,16 @@ class BedrockLLM(BaseLLM):
             ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                     inference_params[k] = v
             data = json.dumps({"prompt": prompt, **inference_params})
+        elif provider == "ai21":
+            ## LOAD CONFIG
+            config = litellm.AmazonAI21Config.get_config()
+            for k, v in config.items():
+                if (
+                    k not in inference_params
+                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+                    inference_params[k] = v
+
+            data = json.dumps({"prompt": prompt, **inference_params})
         else:
             raise Exception("UNSUPPORTED PROVIDER")

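The ai21 branch merges AmazonAI21Config defaults into the request without overriding anything the caller set, then serializes the prompt plus params. A sketch of that merge with a stand-in config dict (the helper name and sample values are illustrative; litellm.AmazonAI21Config.get_config() supplies the real defaults):

# Sketch: caller-supplied params win over config defaults.
import json

def build_ai21_payload(prompt: str, inference_params: dict, config: dict) -> str:
    for k, v in config.items():
        if k not in inference_params:
            inference_params[k] = v
    return json.dumps({"prompt": prompt, **inference_params})

print(build_ai21_payload("Hey!", {"maxTokens": 200},
                         {"maxTokens": 256, "temperature": 0.7}))
# -> {"prompt": "Hey!", "maxTokens": 200, "temperature": 0.7}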
@@ -662,7 +720,7 @@ class BedrockLLM(BaseLLM):
         if acompletion:
             if isinstance(client, HTTPHandler):
                 client = None
-            if stream:
+            if stream == True and provider != "ai21":
                 return self.async_streaming(
                     model=model,
                     messages=messages,
@@ -691,7 +749,7 @@ class BedrockLLM(BaseLLM):
                 encoding=encoding,
                 logging_obj=logging_obj,
                 optional_params=optional_params,
-                stream=False,
+                stream=stream,  # type: ignore
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 headers=prepped.headers,
@@ -708,7 +766,7 @@ class BedrockLLM(BaseLLM):
             self.client = HTTPHandler(**_params)  # type: ignore
         else:
             self.client = client
-        if stream is not None and stream == True:
+        if (stream is not None and stream == True) and provider != "ai21":
             response = self.client.post(
                 url=prepped.url,
                 headers=prepped.headers,  # type: ignore
@@ -787,7 +845,7 @@ class BedrockLLM(BaseLLM):
             model=model,
             response=response,
             model_response=model_response,
-            stream=stream,
+            stream=stream if isinstance(stream, bool) else False,
             logging_obj=logging_obj,
             api_key="",
             data=data,
@@ -328,7 +328,7 @@ async def acompletion(
             or custom_llm_provider == "predibase"
             or (
                 custom_llm_provider == "bedrock"
-                and ("cohere" in model or "anthropic" in model)
+                and ("cohere" in model or "anthropic" in model or "ai21" in model)
             )
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
|
@ -1982,7 +1982,7 @@ def completion(
|
||||||
# boto3 reads keys from .env
|
# boto3 reads keys from .env
|
||||||
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
||||||
|
|
||||||
if "cohere" in model or "anthropic" in model:
|
if "cohere" in model or "anthropic" in model or "ai21" in model:
|
||||||
response = bedrock_chat_completion.completion(
|
response = bedrock_chat_completion.completion(
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|
|
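With "ai21" added to both routing checks, a bedrock/ai21.* model now reaches BedrockLLM through the async httpx path in acompletion as well as the sync completion path. A usage sketch, assuming AWS credentials and region are already configured for Bedrock (model id and parameters taken from the tests below):

# Usage sketch: calling an ai21 Bedrock model through the async path.
import asyncio
import litellm

async def main():
    response = await litellm.acompletion(
        model="bedrock/ai21.j2-mid",
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
        temperature=0.2,
        max_tokens=200,
    )
    print(response.choices[0].message.content)

asyncio.run(main())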
@@ -2665,7 +2665,12 @@ def response_format_tests(response: litellm.ModelResponse):
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.parametrize(
     "model",
-    ["bedrock/cohere.command-r-plus-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"],
+    [
+        "bedrock/cohere.command-r-plus-v1:0",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "anthropic.claude-instant-v1",
+        "bedrock/ai21.j2-mid",
+    ],
 )
 @pytest.mark.asyncio
 async def test_completion_bedrock_httpx_models(sync_mode, model):
@@ -2675,6 +2680,8 @@ async def test_completion_bedrock_httpx_models(sync_mode, model):
         response = completion(
             model=model,
             messages=[{"role": "user", "content": "Hey! how's it going?"}],
+            temperature=0.2,
+            max_tokens=200,
         )

         assert isinstance(response, litellm.ModelResponse)
@@ -2684,6 +2691,8 @@ async def test_completion_bedrock_httpx_models(sync_mode, model):
         response = await litellm.acompletion(
             model=model,
             messages=[{"role": "user", "content": "Hey! how's it going?"}],
+            temperature=0.2,
+            max_tokens=200,
         )

         assert isinstance(response, litellm.ModelResponse)
@@ -2740,48 +2749,9 @@ def test_completion_bedrock_titan():
 # test_completion_bedrock_titan()


-def test_completion_bedrock_claude():
-    print("calling claude")
-    try:
-        response = completion(
-            model="anthropic.claude-instant-v1",
-            messages=messages,
-            max_tokens=10,
-            temperature=0.1,
-            logger_fn=logger_fn,
-        )
-        # Add any assertions here to check the response
-        print(response)
-    except RateLimitError:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
 # test_completion_bedrock_claude()


-def test_completion_bedrock_cohere():
-    print("calling bedrock cohere")
-    litellm.set_verbose = True
-    try:
-        response = completion(
-            model="bedrock/cohere.command-text-v14",
-            messages=[{"role": "user", "content": "hi"}],
-            temperature=0.1,
-            max_tokens=10,
-            stream=True,
-        )
-        # Add any assertions here to check the response
-        print(response)
-        for chunk in response:
-            print(chunk)
-    except RateLimitError:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
 # test_completion_bedrock_cohere()


@@ -2804,23 +2774,6 @@ def test_completion_bedrock_cohere():
 #         pytest.fail(f"Error occurred: {e}")
 # test_completion_bedrock_claude_stream()

-# def test_completion_bedrock_ai21():
-#     try:
-#         litellm.set_verbose = False
-#         response = completion(
-#             model="bedrock/ai21.j2-mid",
-#             messages=messages,
-#             temperature=0.2,
-#             top_p=0.2,
-#             max_tokens=20
-#         )
-#         # Add any assertions here to check the response
-#         print(response)
-#     except RateLimitError:
-#         pass
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-

 ######## Test VLLM ########
 # def test_completion_vllm():
@@ -1044,8 +1044,10 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
 @pytest.mark.parametrize(
     "model",
     [
-        # "bedrock/cohere.command-r-plus-v1:0",
-        "anthropic.claude-3-sonnet-20240229-v1:0"
+        "bedrock/cohere.command-r-plus-v1:0",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "anthropic.claude-instant-v1",
+        "bedrock/ai21.j2-mid",
     ],
 )
 @pytest.mark.asyncio
@@ -11510,7 +11510,11 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "predibase"
                 or (
                     self.custom_llm_provider == "bedrock"
-                    and ("cohere" in self.model or "anthropic" in self.model)
+                    and (
+                        "cohere" in self.model
+                        or "anthropic" in self.model
+                        or "ai21" in self.model
+                    )
                 )
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
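The same provider list now gates the async iteration path in CustomStreamWrapper. A hypothetical helper expressing that check (illustrative only, not litellm API):

# Sketch: which Bedrock model families take the async httpx stream path.
def bedrock_uses_async_httpx_stream(custom_llm_provider: str, model: str) -> bool:
    return custom_llm_provider == "bedrock" and any(
        family in model for family in ("cohere", "anthropic", "ai21")
    )

assert bedrock_uses_async_httpx_stream("bedrock", "ai21.j2-mid") is True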