fix(bedrock_httpx.py): working async bedrock command r calls

Krrish Dholakia 2024-05-11 16:45:20 -07:00
parent 59c8c0adff
commit 49ab1a1d3f
6 changed files with 374 additions and 78 deletions

@@ -151,19 +151,120 @@ class AnthropicChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
+    def process_streaming_response(
+        self,
+        model: str,
+        response: requests.Response | httpx.Response,
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: litellm.utils.Logging,
+        optional_params: dict,
+        api_key: str,
+        data: dict | str,
+        messages: List,
+        print_verbose,
+        encoding,
+    ) -> CustomStreamWrapper:
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
+        ## RESPONSE OBJECT
+        try:
+            completion_response = response.json()
+        except:
+            raise AnthropicError(
+                message=response.text, status_code=response.status_code
+            )
+        text_content = ""
+        tool_calls = []
+        for content in completion_response["content"]:
+            if content["type"] == "text":
+                text_content += content["text"]
+            ## TOOL CALLING
+            elif content["type"] == "tool_use":
+                tool_calls.append(
+                    {
+                        "id": content["id"],
+                        "type": "function",
+                        "function": {
+                            "name": content["name"],
+                            "arguments": json.dumps(content["input"]),
+                        },
+                    }
+                )
+        if "error" in completion_response:
+            raise AnthropicError(
+                message=str(completion_response["error"]),
+                status_code=response.status_code,
+            )
+        print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
+        # return an iterator
+        streaming_model_response = ModelResponse(stream=True)
+        streaming_model_response.choices[0].finish_reason = model_response.choices[  # type: ignore
+            0
+        ].finish_reason
+        # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+        streaming_choice = litellm.utils.StreamingChoices()
+        streaming_choice.index = model_response.choices[0].index
+        _tool_calls = []
+        print_verbose(
+            f"type of model_response.choices[0]: {type(model_response.choices[0])}"
+        )
+        print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
+        if isinstance(model_response.choices[0], litellm.Choices):
+            if getattr(
+                model_response.choices[0].message, "tool_calls", None
+            ) is not None and isinstance(
+                model_response.choices[0].message.tool_calls, list
+            ):
+                for tool_call in model_response.choices[0].message.tool_calls:
+                    _tool_call = {**tool_call.dict(), "index": 0}
+                    _tool_calls.append(_tool_call)
+            delta_obj = litellm.utils.Delta(
+                content=getattr(model_response.choices[0].message, "content", None),
+                role=model_response.choices[0].message.role,
+                tool_calls=_tool_calls,
+            )
+            streaming_choice.delta = delta_obj
+            streaming_model_response.choices = [streaming_choice]
+            completion_stream = ModelResponseIterator(
+                model_response=streaming_model_response
+            )
+            print_verbose(
+                "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
+            )
+            return CustomStreamWrapper(
+                completion_stream=completion_stream,
+                model=model,
+                custom_llm_provider="cached_response",
+                logging_obj=logging_obj,
+            )
+        else:
+            raise AnthropicError(
+                status_code=422,
+                message="Unprocessable response object - {}".format(response.text),
+            )
+
     def process_response(
         self,
-        model,
-        response,
-        model_response,
-        _is_function_call,
-        stream,
-        logging_obj,
-        api_key,
-        data,
-        messages,
+        model: str,
+        response: requests.Response | httpx.Response,
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: litellm.utils.Logging,
+        optional_params: dict,
+        api_key: str,
+        data: dict | str,
+        messages: List,
         print_verbose,
-    ):
+        encoding,
+    ) -> ModelResponse:
         ## LOGGING
         logging_obj.post_call(
             input=messages,
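
Note on the technique used in the new process_streaming_response above: when the provider has already returned a complete tool-call response, the finished ModelResponse is wrapped in an iterator and handed to CustomStreamWrapper with custom_llm_provider="cached_response", so callers can still consume it through the streaming interface. A minimal sketch of that wrapping idea, assuming only standard Python iteration protocols (the class name below is hypothetical; the real ModelResponseIterator is defined elsewhere in the codebase):

class SingleResponseIterator:
    # Hypothetical illustration: yield one pre-built ModelResponse, then stop,
    # so an already-complete response can be replayed as a one-chunk stream.
    def __init__(self, model_response):
        self.model_response = model_response
        self.is_done = False

    def __iter__(self):
        return self

    def __next__(self):
        if self.is_done:
            raise StopIteration
        self.is_done = True
        return self.model_response

    # async variants, so the same object also works for async streaming consumers
    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.is_done:
            raise StopAsyncIteration
        self.is_done = True
        return self.model_response
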
@@ -216,51 +317,6 @@ class AnthropicChatCompletion(BaseLLM):
                 completion_response["stop_reason"]
             )
         print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
-        if _is_function_call and stream:
-            print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
-            # return an iterator
-            streaming_model_response = ModelResponse(stream=True)
-            streaming_model_response.choices[0].finish_reason = model_response.choices[
-                0
-            ].finish_reason
-            # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
-            streaming_choice = litellm.utils.StreamingChoices()
-            streaming_choice.index = model_response.choices[0].index
-            _tool_calls = []
-            print_verbose(
-                f"type of model_response.choices[0]: {type(model_response.choices[0])}"
-            )
-            print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
-            if isinstance(model_response.choices[0], litellm.Choices):
-                if getattr(
-                    model_response.choices[0].message, "tool_calls", None
-                ) is not None and isinstance(
-                    model_response.choices[0].message.tool_calls, list
-                ):
-                    for tool_call in model_response.choices[0].message.tool_calls:
-                        _tool_call = {**tool_call.dict(), "index": 0}
-                        _tool_calls.append(_tool_call)
-                delta_obj = litellm.utils.Delta(
-                    content=getattr(model_response.choices[0].message, "content", None),
-                    role=model_response.choices[0].message.role,
-                    tool_calls=_tool_calls,
-                )
-                streaming_choice.delta = delta_obj
-                streaming_model_response.choices = [streaming_choice]
-                completion_stream = ModelResponseIterator(
-                    model_response=streaming_model_response
-                )
-                print_verbose(
-                    "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
-                )
-                return CustomStreamWrapper(
-                    completion_stream=completion_stream,
-                    model=model,
-                    custom_llm_provider="cached_response",
-                    logging_obj=logging_obj,
-                )
         ## CALCULATING USAGE
         prompt_tokens = completion_response["usage"]["input_tokens"]
         completion_tokens = completion_response["usage"]["output_tokens"]
@@ -273,7 +329,7 @@ class AnthropicChatCompletion(BaseLLM):
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
         )
-        model_response.usage = usage
+        setattr(model_response, "usage", usage)  # type: ignore
         return model_response
 
     async def acompletion_stream_function(
@@ -289,7 +345,7 @@ class AnthropicChatCompletion(BaseLLM):
         logging_obj,
         stream,
         _is_function_call,
-        data=None,
+        data: dict,
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
@@ -331,12 +387,12 @@ class AnthropicChatCompletion(BaseLLM):
         logging_obj,
         stream,
         _is_function_call,
-        data=None,
-        optional_params=None,
+        data: dict,
+        optional_params: dict,
         litellm_params=None,
         logger_fn=None,
         headers={},
-    ):
+    ) -> ModelResponse:
         self.async_handler = AsyncHTTPHandler(
             timeout=httpx.Timeout(timeout=600.0, connect=5.0)
         )
@@ -347,13 +403,14 @@ class AnthropicChatCompletion(BaseLLM):
             model=model,
             response=response,
             model_response=model_response,
-            _is_function_call=_is_function_call,
             stream=stream,
             logging_obj=logging_obj,
             api_key=api_key,
             data=data,
             messages=messages,
             print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
         )
 
     def completion(
@@ -367,7 +424,7 @@ class AnthropicChatCompletion(BaseLLM):
         encoding,
         api_key,
         logging_obj,
-        optional_params=None,
+        optional_params: dict,
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
@@ -526,17 +583,33 @@ class AnthropicChatCompletion(BaseLLM):
                 raise AnthropicError(
                     status_code=response.status_code, message=response.text
                 )
+        if stream and _is_function_call:
+            return self.process_streaming_response(
+                model=model,
+                response=response,
+                model_response=model_response,
+                stream=stream,
+                logging_obj=logging_obj,
+                api_key=api_key,
+                data=data,
+                messages=messages,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                encoding=encoding,
+            )
         return self.process_response(
             model=model,
             response=response,
             model_response=model_response,
-            _is_function_call=_is_function_call,
             stream=stream,
             logging_obj=logging_obj,
             api_key=api_key,
             data=data,
             messages=messages,
             print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
         )
 
     def embedding(self):
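
For reference, a hedged usage sketch (not part of this commit) of the call pattern the stream + tool-call branch above serves, using the standard litellm.completion interface; the tool definition below is made up purely for illustration:

import litellm

response = litellm.completion(
    model="claude-3-opus-20240229",
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",  # hypothetical tool, for illustration only
                "description": "Get the current weather for a location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
    stream=True,  # with a tool call in the response, this exercises the fake-streamed 'cached_response' path
)

for chunk in response:
    print(chunk)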