(fix) improve cohere error handling

This commit is contained in:
ishaan-jaff 2023-11-02 10:07:11 -07:00
parent 744e69f01f
commit 724e169f32

View file

@ -137,6 +137,10 @@ def completion(
response = requests.post( response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
) )
## error handling for cohere calls
if response.status_code!=200:
raise CohereError(message=response.text, status_code=response.status_code)
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines() return response.iter_lines()
else: else:
@ -210,7 +214,6 @@ def embedding(
response = requests.post( response = requests.post(
embed_url, headers=headers, data=json.dumps(data) embed_url, headers=headers, data=json.dumps(data)
) )
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input, input=input,
@ -230,6 +233,8 @@ def embedding(
'usage' 'usage'
} }
""" """
if response.status_code!=200:
raise CohereError(message=response.text, status_code=response.status_code)
embeddings = response.json()['embeddings'] embeddings = response.json()['embeddings']
output_data = [] output_data = []
for idx, embedding in enumerate(embeddings): for idx, embedding in enumerate(embeddings):