Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
fix(vertex_ai.py): add async embedding support for vertex ai
commit 0ffdf57dec (parent 5bf51a6058)

3 changed files with 102 additions and 0 deletions
@@ -945,6 +945,7 @@ def embedding(
     encoding=None,
     vertex_project=None,
     vertex_location=None,
+    aembedding=False,
 ):
     # logic for parsing in - calling - parsing out model embedding calls
     try:
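The new aembedding flag makes the synchronous embedding entrypoint dual-mode: when the flag is set, the function returns an un-awaited coroutine instead of a finished response, and the async caller awaits it. A minimal self-contained sketch of that pattern (names here are illustrative, not litellm's):

import asyncio

async def _embed_async(texts):
    # stand-in for a real async SDK call
    return [[0.1, 0.2, 0.3] for _ in texts]

def embed(texts, aembedding=False):
    if aembedding:
        # hand back the un-awaited coroutine; the async caller awaits it
        return _embed_async(texts)
    # unchanged synchronous path
    return [[0.1, 0.2, 0.3] for _ in texts]

async def main():
    vectors = await embed(["hello"], aembedding=True)
    print(vectors)

asyncio.run(main())

Returning the coroutine from the sync function keeps a single code path for argument parsing while still letting the event loop drive the network call.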
@@ -972,9 +973,95 @@ def embedding(
 
     try:
         llm_model = TextEmbeddingModel.from_pretrained(model)
+    except Exception as e:
+        raise VertexAIError(status_code=422, message=str(e))
+
+    if aembedding == True:
+        return async_embedding(
+            model=model,
+            client=llm_model,
+            input=input,
+            logging_obj=logging_obj,
+            model_response=model_response,
+            optional_params=optional_params,
+            encoding=encoding,
+        )
+
+    request_str = f"""embeddings = llm_model.get_embeddings({input})"""
+    ## LOGGING PRE-CALL
+    logging_obj.pre_call(
+        input=input,
+        api_key=None,
+        additional_args={
+            "complete_input_dict": optional_params,
+            "request_str": request_str,
+        },
+    )
+
+    try:
         embeddings = llm_model.get_embeddings(input)
     except Exception as e:
         raise VertexAIError(status_code=500, message=str(e))
+
+    ## LOGGING POST-CALL
+    logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
+    ## Populate OpenAI compliant dictionary
+    embedding_response = []
+    for idx, embedding in enumerate(embeddings):
+        embedding_response.append(
+            {
+                "object": "embedding",
+                "index": idx,
+                "embedding": embedding.values,
+            }
+        )
+    model_response["object"] = "list"
+    model_response["data"] = embedding_response
+    model_response["model"] = model
+    input_tokens = 0
+
+    input_str = "".join(input)
+
+    input_tokens += len(encoding.encode(input_str))
+
+    usage = Usage(
+        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+    )
+    model_response.usage = usage
+
+    return model_response
+
+
+async def async_embedding(
+    model: str,
+    input: Union[list, str],
+    logging_obj=None,
+    model_response=None,
+    optional_params=None,
+    encoding=None,
+    client=None,
+):
+    """
+    Async embedding implementation
+    """
+    request_str = f"""embeddings = llm_model.get_embeddings({input})"""
+    ## LOGGING PRE-CALL
+    logging_obj.pre_call(
+        input=input,
+        api_key=None,
+        additional_args={
+            "complete_input_dict": optional_params,
+            "request_str": request_str,
+        },
+    )
+
+    try:
+        embeddings = await client.get_embeddings_async(input)
+    except Exception as e:
+        raise VertexAIError(status_code=500, message=str(e))
+
+    ## LOGGING POST-CALL
+    logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
     ## Populate OpenAI compliant dictionary
     embedding_response = []
     for idx, embedding in enumerate(embeddings):
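For reference, a standalone sketch of the SDK surface this new path relies on. It assumes google-cloud-aiplatform is installed and Google Cloud credentials are configured; the model name is the one used by the commit's test, and get_embeddings_async is the awaitable counterpart of the get_embeddings call on the sync path (both appear in the diff above):

import asyncio
from vertexai.language_models import TextEmbeddingModel

async def main():
    # same model class the patched code loads via from_pretrained
    llm_model = TextEmbeddingModel.from_pretrained("textembedding-gecko@001")
    # awaitable counterpart of get_embeddings, as used in async_embedding
    embeddings = await llm_model.get_embeddings_async(
        ["good morning from litellm", "this is another item"]
    )
    for idx, embedding in enumerate(embeddings):
        # embedding.values is the raw float vector that the
        # OpenAI-compliant response loop above packs into "data"
        print(idx, len(embedding.values))

asyncio.run(main())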
@@ -2211,6 +2211,7 @@ async def aembedding(*args, **kwargs):
         or custom_llm_provider == "deepinfra"
         or custom_llm_provider == "perplexity"
         or custom_llm_provider == "ollama"
+        or custom_llm_provider == "vertex_ai"
     ): # currently implemented aiohttp calls for just azure and openai, soon all.
         # Await normally
         init_response = await loop.run_in_executor(None, func_with_context)
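vertex_ai.embedding is still a synchronous function, so the router runs it in a thread executor; with aembedding=True the executor hands back a coroutine rather than a finished response, which then has to be awaited on the event loop. A minimal sketch of that two-step dispatch, using hypothetical names rather than litellm's internals:

import asyncio
import functools

def provider_embedding(aembedding=False):
    async def _async_impl():
        return {"object": "list", "data": []}
    # mirrors the patched vertex_ai.embedding: a coroutine when aembedding=True
    return _async_impl() if aembedding else {"object": "list", "data": []}

async def aembedding_router():
    loop = asyncio.get_running_loop()
    func_with_context = functools.partial(provider_embedding, aembedding=True)
    # Await normally: run the blocking entrypoint in the default executor
    init_response = await loop.run_in_executor(None, func_with_context)
    # the provider handed back an un-awaited coroutine; finish it on this loop
    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response

print(asyncio.run(aembedding_router()))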
@@ -2549,6 +2550,7 @@ def embedding(
             model_response=EmbeddingResponse(),
             vertex_project=vertex_ai_project,
             vertex_location=vertex_ai_location,
+            aembedding=aembedding,
         )
     elif custom_llm_provider == "oobabooga":
         response = oobabooga.embedding(
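The synchronous path is unchanged: litellm.embedding forwards aembedding (false by default), so vertex_ai.embedding takes the blocking get_embeddings branch. A usage sketch mirroring the existing sync test, assuming Vertex AI credentials and project settings are configured for litellm:

import litellm

# synchronous call: aembedding is not set, so the provider's blocking
# get_embeddings branch runs and returns a finished EmbeddingResponse
response = litellm.embedding(
    model="textembedding-gecko@001",
    input=["good morning from litellm", "this is another item"],
)
print(response)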
@@ -243,6 +243,19 @@ def test_vertexai_embedding():
         pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio
+async def test_vertexai_aembedding():
+    try:
+        # litellm.set_verbose=True
+        response = await litellm.aembedding(
+            model="textembedding-gecko@001",
+            input=["good morning from litellm", "this is another item"],
+        )
+        print(f"response: {response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
 
 
 def test_bedrock_embedding_titan():
     try:
         # this tests if we support str input for bedrock embedding
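The new test drives the whole chain end to end; the @pytest.mark.asyncio marker assumes the pytest-asyncio plugin is installed. Outside pytest, the same call can be driven with a plain event loop:

import asyncio
import litellm

async def main():
    # identical call to the new test; requires Vertex AI credentials
    response = await litellm.aembedding(
        model="textembedding-gecko@001",
        input=["good morning from litellm", "this is another item"],
    )
    print(f"response: {response}")

asyncio.run(main())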