refactor(openai.py): moving embedding calls to http

This commit is contained in:
Krrish Dholakia 2023-11-08 19:01:17 -08:00
parent c2cbdb23fd
commit 70311502c8
2 changed files with 75 additions and 21 deletions

View file

@ -269,6 +269,71 @@ class OpenAIChatCompletion(BaseLLM):
else:
import traceback
raise OpenAIError(status_code=500, message=traceback.format_exc())
def embedding(self,
model: str,
input: list,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
logging_obj=None,
model_response=None,
optional_params=None,):
super().embedding()
exception_mapping_worked = False
try:
headers = self.validate_environment(api_key)
api_base = f"{api_base}/embeddings"
model = model
data = {
"model": model,
"input": input,
**optional_params
}
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
response = self._client_session.post(
api_base, headers=headers, json=data
)
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
if response.status_code!=200:
raise OpenAIError(message=response.text, status_code=response.status_code)
embedding_response = response.json()
output_data = []
for idx, embedding in enumerate(embedding_response["data"]):
output_data.append(
{
"object": embedding["object"],
"index": embedding["index"],
"embedding": embedding["embedding"]
}
)
model_response["object"] = "list"
model_response["data"] = output_data
model_response["model"] = model
model_response["usage"] = embedding_response["usage"]
return model_response
except OpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if exception_mapping_worked:
raise e
else:
import traceback
raise OpenAIError(status_code=500, message=traceback.format_exc())
class OpenAITextCompletion(BaseLLM):