fix(bedrock_httpx.py): logging fixes

This commit is contained in:
Krrish Dholakia 2024-05-16 23:20:51 -07:00
parent 92c2e2af6a
commit 21f2ba6f1f
2 changed files with 31 additions and 2 deletions

View file

@ -735,7 +735,19 @@ class BedrockLLM(BaseLLM):
inference_params[k] = v
data = json.dumps({"prompt": prompt, **inference_params})
else:
raise Exception("UNSUPPORTED PROVIDER")
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": inference_params,
},
)
raise Exception(
"Bedrock HTTPX: Unsupported provider={}, model={}".format(
provider, model
)
)
## COMPLETION CALL
@ -822,6 +834,14 @@ class BedrockLLM(BaseLLM):
status_code=response.status_code, message=response.text
)
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=response.text,
additional_args={"complete_input_dict": data},
)
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
@ -940,6 +960,15 @@ class BedrockLLM(BaseLLM):
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=streaming_response,
additional_args={"complete_input_dict": data},
)
return streaming_response
def embedding(self, *args, **kwargs):

View file

@ -558,7 +558,7 @@ async def test_async_chat_bedrock_stream():
continue
except:
pass
time.sleep(1)
await asyncio.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []