fix(vertex_ai.py): add async embedding support for vertex ai

Krrish Dholakia 2024-02-03 10:35:17 -08:00
parent 5bf51a6058
commit 0ffdf57dec
3 changed files with 102 additions and 0 deletions

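A minimal usage sketch (not part of this commit), assuming litellm's async entrypoint litellm.aembedding() routes "vertex_ai/..." models into the new aembedding=True path shown in the diff below; the project, region, and model names are placeholder values:

    import asyncio

    import litellm

    # Placeholder GCP config; substitute your own project and region.
    litellm.vertex_project = "my-gcp-project"
    litellm.vertex_location = "us-central1"

    async def main():
        # Returns the OpenAI-compliant dict built by the new async_embedding() path:
        # {"object": "list", "data": [{"object": "embedding", "index": 0, "embedding": [...]}], ...}
        response = await litellm.aembedding(
            model="vertex_ai/textembedding-gecko",
            input=["hello world", "async vertex embeddings"],
        )
        print(response["data"][0]["embedding"][:5])

    asyncio.run(main())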

@@ -945,6 +945,7 @@ def embedding(
    encoding=None,
    vertex_project=None,
    vertex_location=None,
    aembedding=False,
):
    # logic for parsing in - calling - parsing out model embedding calls
    try:
@@ -972,9 +973,95 @@ def embedding(
    try:
        llm_model = TextEmbeddingModel.from_pretrained(model)
    except Exception as e:
        raise VertexAIError(status_code=422, message=str(e))

    if aembedding == True:
        return async_embedding(
            model=model,
            client=llm_model,
            input=input,
            logging_obj=logging_obj,
            model_response=model_response,
            optional_params=optional_params,
            encoding=encoding,
        )

    request_str = f"""embeddings = llm_model.get_embeddings({input})"""
    ## LOGGING PRE-CALL
    logging_obj.pre_call(
        input=input,
        api_key=None,
        additional_args={
            "complete_input_dict": optional_params,
            "request_str": request_str,
        },
    )

    try:
        embeddings = llm_model.get_embeddings(input)
    except Exception as e:
        raise VertexAIError(status_code=500, message=str(e))

    ## LOGGING POST-CALL
    logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
    ## Populate OpenAI compliant dictionary
    embedding_response = []
    for idx, embedding in enumerate(embeddings):
        embedding_response.append(
            {
                "object": "embedding",
                "index": idx,
                "embedding": embedding.values,
            }
        )
    model_response["object"] = "list"
    model_response["data"] = embedding_response
    model_response["model"] = model

    input_tokens = 0
    input_str = "".join(input)
    input_tokens += len(encoding.encode(input_str))
    usage = Usage(
        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
    )
    model_response.usage = usage

    return model_response


async def async_embedding(
    model: str,
    input: Union[list, str],
    logging_obj=None,
    model_response=None,
    optional_params=None,
    encoding=None,
    client=None,
):
    """
    Async embedding implementation
    """
    request_str = f"""embeddings = llm_model.get_embeddings({input})"""
    ## LOGGING PRE-CALL
    logging_obj.pre_call(
        input=input,
        api_key=None,
        additional_args={
            "complete_input_dict": optional_params,
            "request_str": request_str,
        },
    )

    try:
        embeddings = await client.get_embeddings_async(input)
    except Exception as e:
        raise VertexAIError(status_code=500, message=str(e))

    ## LOGGING POST-CALL
    logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
    ## Populate OpenAI compliant dictionary
    embedding_response = []
    for idx, embedding in enumerate(embeddings):