refactor(openai.py): move embedding calls to http

This commit is contained in:
Krrish Dholakia 2023-11-08 19:01:17 -08:00
parent c2cbdb23fd
commit 70311502c8
2 changed files with 75 additions and 21 deletions

View file

@ -269,6 +269,71 @@ class OpenAIChatCompletion(BaseLLM):
else: else:
import traceback import traceback
raise OpenAIError(status_code=500, message=traceback.format_exc()) raise OpenAIError(status_code=500, message=traceback.format_exc())
def embedding(self,
              model: str,
              input: list,
              api_key: Optional[str] = None,
              api_base: Optional[str] = None,
              logging_obj=None,
              model_response=None,
              optional_params=None):
    """Call the OpenAI /embeddings endpoint over raw HTTP.

    Args:
        model: Embedding model name (e.g. "text-embedding-ada-002").
        input: List of strings/tokens to embed.
        api_key: OpenAI API key; forwarded to validate_environment for headers.
        api_base: Base URL of the API; "/embeddings" is appended here.
        logging_obj: Logger with pre_call/post_call hooks — assumed non-None,
            callers must supply it (TODO confirm against call sites).
        model_response: Mutable mapping populated in place with the
            OpenAI-shaped embedding response; also returned.
        optional_params: Extra request-body fields merged into the payload.

    Returns:
        model_response, filled with "object", "data", "model" and "usage".

    Raises:
        OpenAIError: on non-200 responses (with upstream status/text) or,
            with status_code=500, on any unexpected local failure.
    """
    super().embedding()  # logging helper on the base class
    try:
        headers = self.validate_environment(api_key)
        api_base = f"{api_base}/embeddings"
        data = {
            "model": model,
            "input": input,
            # Guard against the None default — original would TypeError on **None.
            **(optional_params or {}),
        }
        ## LOGGING
        logging_obj.pre_call(
            input=input,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
        )
        ## COMPLETION CALL
        response = self._client_session.post(
            api_base, headers=headers, json=data
        )
        ## LOGGING
        logging_obj.post_call(
            input=input,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=response,
        )
        if response.status_code != 200:
            # Surface the upstream error verbatim with its real status code.
            raise OpenAIError(message=response.text, status_code=response.status_code)
        embedding_response = response.json()
        # Re-shape into the OpenAI embedding response format.
        output_data = [
            {
                "object": embedding["object"],
                "index": embedding["index"],
                "embedding": embedding["embedding"],
            }
            for embedding in embedding_response["data"]
        ]
        model_response["object"] = "list"
        model_response["data"] = output_data
        model_response["model"] = model
        model_response["usage"] = embedding_response["usage"]
        return model_response
    except OpenAIError:
        # Already mapped — propagate unchanged.
        raise
    except Exception:
        # NOTE: the original tracked `exception_mapping_worked`, but it could
        # never be True here (the OpenAIError handler re-raises immediately),
        # so that branch was dead code and is removed.
        import traceback
        raise OpenAIError(status_code=500, message=traceback.format_exc())
class OpenAITextCompletion(BaseLLM): class OpenAITextCompletion(BaseLLM):

View file

@ -1725,28 +1725,17 @@ def embedding(
api_type = "openai" api_type = "openai"
api_version = None api_version = None
## LOGGING
logging.pre_call(
input=input,
api_key=api_key,
additional_args={
"api_type": api_type,
"api_base": api_base,
"api_version": api_version,
},
)
## EMBEDDING CALL
response = openai.Embedding.create(
input=input,
model=model,
api_key=api_key,
api_base=api_base,
api_version=api_version,
api_type=api_type,
)
## LOGGING ## EMBEDDING CALL
logging.post_call(input=input, api_key=api_key, original_response=response) response = openai_chat_completions.embedding(
model=model,
input=input,
api_base=api_base,
api_key=api_key,
logging_obj=logging,
model_response=EmbeddingResponse(),
optional_params=kwargs
)
elif model in litellm.cohere_embedding_models: elif model in litellm.cohere_embedding_models:
cohere_key = ( cohere_key = (
api_key api_key