fix(bedrock_httpx.py): working async bedrock command r calls

This commit is contained in:
Krrish Dholakia 2024-05-11 16:45:20 -07:00
parent 59c8c0adff
commit 49ab1a1d3f
6 changed files with 374 additions and 78 deletions

View file

@ -16,6 +16,7 @@ from litellm.utils import (
Message,
Choices,
get_secret,
Logging,
)
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
@ -255,6 +256,70 @@ class BedrockLLM(BaseLLM):
return session.get_credentials()
def process_response(
self,
model: str,
response: requests.Response | httpx.Response,
model_response: ModelResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise BedrockError(message=response.text, status_code=422)
try:
model_response.choices[0].message.content = completion_response["text"] # type: ignore
except Exception as e:
raise BedrockError(message=response.text, status_code=422)
## CALCULATING USAGE - bedrock returns usage in the headers
prompt_tokens = int(
response.headers.get(
"x-amzn-bedrock-input-token-count",
len(encoding.encode("".join(m.get("content", "") for m in messages))),
)
)
completion_tokens = int(
response.headers.get(
"x-amzn-bedrock-output-token-count",
len(
encoding.encode(
model_response.choices[0].message.content, # type: ignore
disallowed_special=(),
)
),
)
)
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def completion(
self,
model: str,
@ -268,8 +333,9 @@ class BedrockLLM(BaseLLM):
timeout: Optional[Union[float, httpx.Timeout]],
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
extra_headers: Optional[dict] = None,
client: Optional[HTTPHandler] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
try:
import boto3
@ -381,13 +447,39 @@ class BedrockLLM(BaseLLM):
## COMPLETION CALL
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=endpoint_url, data=data, headers=headers
)
sigv4.add_auth(request)
prepped = request.prepare()
if client is None:
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
if isinstance(client, HTTPHandler):
client = None
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data,
api_base=prepped.url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=False,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
if client is None or isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
@ -416,7 +508,62 @@ class BedrockLLM(BaseLLM):
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=response.text)
return response
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
optional_params=optional_params,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
async def async_completion(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> ModelResponse:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
self.client = AsyncHTTPHandler(**_params) # type: ignore
else:
self.client = client # type: ignore
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
def embedding(self, *args, **kwargs):
return super().embedding(*args, **kwargs)