refactor(azure.py): moving embeddings to http call

Krrish Dholakia 2023-11-08 19:07:21 -08:00
parent 880768f83d
commit 678249ee09
2 changed files with 77 additions and 18 deletions


@@ -171,6 +171,76 @@ class AzureChatCompletion(BaseLLM):
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e
+        except Exception as e:
+            if exception_mapping_worked:
+                raise e
+            else:
+                import traceback
+                raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
+
+    def embedding(self,
+                  model: str,
+                  input: list,
+                  api_key: Optional[str] = None,
+                  api_base: Optional[str] = None,
+                  api_version: Optional[str] = None,
+                  logging_obj=None,
+                  model_response=None,
+                  optional_params=None,):
+        super().embedding()
+        exception_mapping_worked = False
+        try:
+            headers = self.validate_environment(api_key)
+            # Ensure api_base ends with a trailing slash
+            if not api_base.endswith('/'):
+                api_base += '/'
+
+            api_base = api_base + f"openai/deployments/{model}/embeddings?api-version={api_version}"
+            model = model
+            data = {
+                "model": model,
+                "input": input,
+                **optional_params
+            }
+
+            ## LOGGING
+            logging_obj.pre_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+            )
+            ## COMPLETION CALL
+            response = self._client_session.post(
+                api_base, headers=headers, json=data
+            )
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=response,
+            )
+
+            if response.status_code!=200:
+                raise AzureOpenAIError(message=response.text, status_code=response.status_code)
+            embedding_response = response.json()
+            output_data = []
+            for idx, embedding in enumerate(embedding_response["data"]):
+                output_data.append(
+                    {
+                        "object": embedding["object"],
+                        "index": embedding["index"],
+                        "embedding": embedding["embedding"]
+                    }
+                )
+            model_response["object"] = "list"
+            model_response["data"] = output_data
+            model_response["model"] = model
+            model_response["usage"] = embedding_response["usage"]
+            return model_response
+        except AzureOpenAIError as e:
+            exception_mapping_worked = True
+            raise e
         except Exception as e:
             if exception_mapping_worked:
                 raise e
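Note: the new embedding() above goes straight to the Azure OpenAI embeddings REST endpoint instead of going through the openai SDK. Below is a minimal standalone sketch of the equivalent raw request using plain requests, with made-up resource, deployment, and key values; the real headers come from validate_environment() (not shown in this hunk), and Azure OpenAI authenticates REST calls with an api-key header.

import requests

# Hypothetical values for illustration; litellm resolves these from the caller or env vars.
api_base = "https://my-resource.openai.azure.com/"
deployment = "text-embedding-ada-002"
api_version = "2023-07-01-preview"
api_key = "<AZURE_API_KEY>"

# Same URL shape the new embedding() builds:
# {api_base}openai/deployments/{deployment}/embeddings?api-version={api_version}
url = f"{api_base}openai/deployments/{deployment}/embeddings?api-version={api_version}"
headers = {"api-key": api_key, "Content-Type": "application/json"}
data = {"model": deployment, "input": ["good morning from litellm"]}

response = requests.post(url, headers=headers, json=data)
if response.status_code != 200:
    raise RuntimeError(response.text)

# The response JSON mirrors what embedding() copies into model_response:
# {"object": "list", "data": [{"object": "embedding", "index": 0, "embedding": [...]}], "usage": {...}}
print(len(response.json()["data"][0]["embedding"]))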


@@ -1681,28 +1681,17 @@ def embedding(
             litellm.azure_key or
             get_secret("AZURE_API_KEY")
         )
-        ## LOGGING
-        logging.pre_call(
-            input=input,
-            api_key=openai.api_key,
-            additional_args={
-                "api_type": openai.api_type,
-                "api_base": openai.api_base,
-                "api_version": openai.api_version,
-            },
-        )
-        ## EMBEDDING CALL
-        response = openai.Embedding.create(
-            input=input,
-            engine=model,
-            api_key=api_key,
-            api_base=api_base,
-            api_version=api_version,
-            api_type=api_type,
-        )
-
-        ## LOGGING
-        logging.post_call(input=input, api_key=openai.api_key, original_response=response)
+        ## EMBEDDING CALL
+        response = azure_chat_completions.embedding(
+            model=model,
+            input=input,
+            api_base=api_base,
+            api_key=api_key,
+            api_version=api_version,
+            logging_obj=logging,
+            model_response=EmbeddingResponse(),
+            optional_params=kwargs
+        )
     elif model in litellm.open_ai_embedding_models:
         api_base = (
             api_base
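For completeness, the refactored path above is reached through litellm's top-level embedding() function. A rough usage sketch, assuming the usual azure/<deployment> model prefix for routing and env-var configuration (only AZURE_API_KEY is read explicitly in this hunk; the other variable names and the sample deployment are assumptions):

import os
import litellm

# Assumed configuration; replace with real values.
os.environ["AZURE_API_KEY"] = "<azure-api-key>"
os.environ["AZURE_API_BASE"] = "https://my-resource.openai.azure.com/"
os.environ["AZURE_API_VERSION"] = "2023-07-01-preview"

# The azure/ prefix routes into the branch above, which now calls
# azure_chat_completions.embedding() over HTTP instead of openai.Embedding.create().
response = litellm.embedding(
    model="azure/text-embedding-ada-002",
    input=["good morning from litellm"],
)
print(response["model"], len(response["data"][0]["embedding"]))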