(feat) Embedding: Async Azure

ishaan-jaff 2023-11-29 19:43:47 -08:00
parent 53554bae85
commit c05da0797b
3 changed files with 45 additions and 3 deletions
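
With this change, Azure embedding calls can be driven asynchronously through litellm.aembedding, the same entry point the new test below exercises. A minimal usage sketch, assuming Azure credentials (e.g. AZURE_API_KEY, AZURE_API_BASE, AZURE_API_VERSION) are set in the environment and that azure-embedding-model is the name of an existing Azure deployment:

import asyncio
import litellm

async def main():
    # routed through the new async Azure embedding path
    response = await litellm.aembedding(
        model="azure/azure-embedding-model",  # placeholder deployment name, matching the test below
        input=["good morning from litellm", "this is another item"],
    )
    print(response)

asyncio.run(main())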

View file

@@ -279,6 +279,24 @@ class AzureChatCompletion(BaseLLM):
             async for transformed_chunk in streamwrapper:
                 yield transformed_chunk
 
+    async def aembedding(
+        self,
+        data: dict,
+        model_response: ModelResponse,
+        azure_client_params: dict,
+        client=None,
+    ):
+        response = None
+        try:
+            if client is None:
+                openai_aclient = AsyncAzureOpenAI(**azure_client_params)
+            else:
+                openai_aclient = client
+            response = await openai_aclient.embeddings.create(**data)
+            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding")
+        except Exception as e:
+            raise e
+
     def embedding(self,
         model: str,
         input: list,
@@ -290,7 +308,8 @@ class AzureChatCompletion(BaseLLM):
         model_response=None,
         optional_params=None,
         azure_ad_token: Optional[str]=None,
-        client = None
+        client = None,
+        aembedding=None,
     ):
         super().embedding()
         exception_mapping_worked = False
@@ -319,6 +338,9 @@ class AzureChatCompletion(BaseLLM):
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
             azure_client_params["azure_ad_token"] = azure_ad_token
+        if aembedding == True:
+            response = self.aembedding(data=data, model_response=model_response, azure_client_params=azure_client_params)
+            return response
         if client is None:
             azure_client = AzureOpenAI(**azure_client_params) # type: ignore
         else:
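
Note that the new `if aembedding == True:` branch calls the coroutine function self.aembedding without awaiting it, so the synchronous embedding() method hands a coroutine object back to its async caller, which then awaits it. A stripped-down sketch of that dispatch pattern, with hypothetical names rather than litellm's actual classes:

import asyncio

class EmbeddingClient:
    async def aembedding(self, data: dict):
        # stand-in for the real async HTTP call to the embeddings endpoint
        await asyncio.sleep(0)
        return {"object": "list", "data": data}

    def embedding(self, data: dict, aembedding=None):
        if aembedding == True:
            # calling a coroutine function does not execute it; it returns a
            # coroutine object for the async caller to await
            return self.aembedding(data=data)
        # ...synchronous path would go here...
        return {"object": "list", "data": data}

async def main():
    client = EmbeddingClient()
    response = await client.embedding(data={"input": ["hi"]}, aembedding=True)
    print(response)

asyncio.run(main())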

View file

@@ -1801,7 +1801,8 @@ def embedding(
             timeout=timeout,
             model_response=EmbeddingResponse(),
             optional_params=optional_params,
-            client=client
+            client=client,
+            aembedding=aembedding
         )
     elif model in litellm.open_ai_embedding_models or custom_llm_provider == "openai":
         api_base = (
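
On the caller side, threading aembedding=aembedding through means an async wrapper can invoke the synchronous dispatcher and await whatever comes back when it is a coroutine. An illustrative sketch of that pattern, using hypothetical names rather than litellm's actual aembedding wrapper (which is not part of this diff):

import asyncio

def embedding(input, aembedding=None):
    # synchronous dispatcher: the async path returns a coroutine instead of a result
    async def _provider_call():
        await asyncio.sleep(0)
        return {"input": input, "async": True}
    if aembedding == True:
        return _provider_call()
    return {"input": input, "async": False}

async def aembedding(input):
    init_response = embedding(input, aembedding=True)
    # await the provider coroutine when one was returned
    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response

print(asyncio.run(aembedding(["hello world"])))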

View file

@@ -192,7 +192,26 @@ def test_aembedding():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_aembedding()
+# test_aembedding()
+
+def test_aembedding_azure():
+    try:
+        import asyncio
+        async def embedding_call():
+            try:
+                response = await litellm.aembedding(
+                    model="azure/azure-embedding-model",
+                    input=["good morning from litellm", "this is another item"]
+                )
+                print(response)
+            except Exception as e:
+                pytest.fail(f"Error occurred: {e}")
+        asyncio.run(embedding_call())
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+# test_aembedding_azure()
 
 # def test_custom_openai_embedding():
 #     litellm.set_verbose=True
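
Like test_aembedding above it, the new test's module-level call is left commented out. One way to exercise it locally, assuming valid Azure credentials (e.g. AZURE_API_KEY and AZURE_API_BASE) are exported, is to select it by name from the directory containing the test file:

pytest -k test_aembedding_azure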