mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
refactor(openai.py): moving embedding calls to http
This commit is contained in:
parent
c2cbdb23fd
commit
70311502c8
2 changed files with 75 additions and 21 deletions
|
@ -270,6 +270,71 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
import traceback
|
||||
raise OpenAIError(status_code=500, message=traceback.format_exc())
|
||||
|
||||
def embedding(self,
              model: str,
              input: list,
              api_key: Optional[str] = None,
              api_base: Optional[str] = None,
              logging_obj=None,
              model_response=None,
              optional_params=None,):
    """Request embeddings over HTTP and copy them into ``model_response``.

    Posts ``{"model": model, "input": input, **optional_params}`` as JSON to
    ``<api_base>/embeddings`` using ``self._client_session`` and the headers
    returned by ``self.validate_environment(api_key)``.

    Args:
        model: Model name sent in the request body and echoed back in the
            returned response under ``"model"``.
        input: Items to embed; sent unchanged as the ``"input"`` field.
        api_key: Key handed to ``validate_environment`` and to the logging
            hooks.
        api_base: Base URL of the API. When ``None`` or empty, the public
            OpenAI base URL is used. A trailing ``/`` is tolerated.
        logging_obj: Object exposing ``pre_call`` / ``post_call``. When
            ``None``, logging is skipped.
        model_response: Mapping-like object that receives the ``"object"``,
            ``"data"``, ``"model"`` and ``"usage"`` keys. When ``None``, a
            plain ``dict`` is used.
        optional_params: Extra fields merged into the request body. ``None``
            is treated as "no extra fields".

    Returns:
        ``model_response`` populated from the HTTP response.

    Raises:
        OpenAIError: With the HTTP status code and body when the endpoint
            answers with a non-200 status, or with status 500 and the
            formatted traceback for any other failure.
    """
    # Base-class hook; what it does is not visible from this block.
    super().embedding()
    try:
        headers = self.validate_environment(api_key)

        # `api_base` defaults to None in the signature; without a fallback
        # the request URL would literally be "None/embeddings".
        base_url = (api_base or "https://api.openai.com/v1").rstrip("/")
        url = f"{base_url}/embeddings"

        # `optional_params` defaults to None, and `**None` raises TypeError.
        data = {
            "model": model,
            "input": input,
            **(optional_params or {}),
        }

        if model_response is None:
            model_response = {}

        ## LOGGING
        if logging_obj is not None:
            logging_obj.pre_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
            )

        ## COMPLETION CALL
        response = self._client_session.post(
            url, headers=headers, json=data
        )

        ## LOGGING
        if logging_obj is not None:
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=response,
            )

        if response.status_code != 200:
            raise OpenAIError(message=response.text, status_code=response.status_code)

        embedding_response = response.json()

        # Keep only the three fields this method exposes for each item.
        model_response["object"] = "list"
        model_response["data"] = [
            {
                "object": item["object"],
                "index": item["index"],
                "embedding": item["embedding"],
            }
            for item in embedding_response["data"]
        ]
        model_response["model"] = model
        model_response["usage"] = embedding_response["usage"]
        return model_response
    except OpenAIError:
        # Already in the provider's error type; propagate untouched.
        raise
    except Exception as e:
        import traceback

        # Any other failure is surfaced as a 500 carrying the traceback text;
        # chaining keeps the original exception available to debuggers.
        raise OpenAIError(status_code=500, message=traceback.format_exc()) from e
|
||||
|
||||
|
||||
class OpenAITextCompletion(BaseLLM):
|
||||
_client_session: requests.Session
|
||||
|
|
|
@ -1725,28 +1725,17 @@ def embedding(
|
|||
api_type = "openai"
|
||||
api_version = None
|
||||
|
||||
## LOGGING
|
||||
logging.pre_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"api_type": api_type,
|
||||
"api_base": api_base,
|
||||
"api_version": api_version,
|
||||
},
|
||||
)
|
||||
## EMBEDDING CALL
|
||||
response = openai.Embedding.create(
|
||||
input=input,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
api_type=api_type,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging.post_call(input=input, api_key=api_key, original_response=response)
|
||||
## EMBEDDING CALL
|
||||
response = openai_chat_completions.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
logging_obj=logging,
|
||||
model_response=EmbeddingResponse(),
|
||||
optional_params=kwargs
|
||||
)
|
||||
elif model in litellm.cohere_embedding_models:
|
||||
cohere_key = (
|
||||
api_key
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue