refactor(azure.py): moving embeddings to http call

This commit is contained in:
Krrish Dholakia 2023-11-08 19:07:21 -08:00
parent 880768f83d
commit 678249ee09
2 changed files with 77 additions and 18 deletions

View file

@ -171,6 +171,76 @@ class AzureChatCompletion(BaseLLM):
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if exception_mapping_worked:
raise e
else:
import traceback
raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
def embedding(self,
model: str,
input: list,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
logging_obj=None,
model_response=None,
optional_params=None,):
super().embedding()
exception_mapping_worked = False
try:
headers = self.validate_environment(api_key)
# Ensure api_base ends with a trailing slash
if not api_base.endswith('/'):
api_base += '/'
api_base = api_base + f"openai/deployments/{model}/embeddings?api-version={api_version}"
model = model
data = {
"model": model,
"input": input,
**optional_params
}
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
response = self._client_session.post(
api_base, headers=headers, json=data
)
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
if response.status_code!=200:
raise AzureOpenAIError(message=response.text, status_code=response.status_code)
embedding_response = response.json()
output_data = []
for idx, embedding in enumerate(embedding_response["data"]):
output_data.append(
{
"object": embedding["object"],
"index": embedding["index"],
"embedding": embedding["embedding"]
}
)
model_response["object"] = "list"
model_response["data"] = output_data
model_response["model"] = model
model_response["usage"] = embedding_response["usage"]
return model_response
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if exception_mapping_worked:
raise e