forked from phoenix/litellm-mirror
(feat) Embedding: Async Azure
This commit is contained in:
parent
53554bae85
commit
c05da0797b
3 changed files with 45 additions and 3 deletions
|
@ -279,6 +279,24 @@ class AzureChatCompletion(BaseLLM):
|
|||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
|
||||
async def aembedding(
    self,
    data: dict,
    model_response: ModelResponse,
    azure_client_params: dict,
    client=None,
):
    """Asynchronously create embeddings against an Azure OpenAI deployment.

    Args:
        data: Keyword arguments forwarded verbatim to
            ``embeddings.create`` (model/deployment, input, etc.).
        model_response: litellm response object to populate with the result.
        azure_client_params: Constructor kwargs for ``AsyncAzureOpenAI``;
            used only when ``client`` is not supplied.
        client: Optional pre-built async Azure client to reuse (e.g. for
            connection pooling); when ``None`` a fresh client is created.

    Returns:
        The ``model_response`` object populated as an embedding response.

    Raises:
        Whatever ``embeddings.create`` raises (propagated unchanged).
    """
    # Reuse the caller-supplied client when available; otherwise build one
    # from the prepared Azure connection parameters.
    if client is None:
        openai_aclient = AsyncAzureOpenAI(**azure_client_params)
    else:
        openai_aclient = client
    response = await openai_aclient.embeddings.create(**data)
    # Round-trip through JSON so the SDK's pydantic object is converted into
    # litellm's provider-agnostic response shape.
    return convert_to_model_response_object(
        response_object=json.loads(response.model_dump_json()),
        model_response_object=model_response,
        response_type="embedding",
    )
|
||||
|
||||
def embedding(self,
|
||||
model: str,
|
||||
input: list,
|
||||
|
@ -290,7 +308,8 @@ class AzureChatCompletion(BaseLLM):
|
|||
model_response=None,
|
||||
optional_params=None,
|
||||
azure_ad_token: Optional[str]=None,
|
||||
client = None
|
||||
client = None,
|
||||
aembedding=None,
|
||||
):
|
||||
super().embedding()
|
||||
exception_mapping_worked = False
|
||||
|
@ -319,6 +338,9 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if aembedding == True:
|
||||
response = self.aembedding(data=data, model_response=model_response, azure_client_params=azure_client_params)
|
||||
return response
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
else:
|
||||
|
|
|
@ -1801,7 +1801,8 @@ def embedding(
|
|||
timeout=timeout,
|
||||
model_response=EmbeddingResponse(),
|
||||
optional_params=optional_params,
|
||||
client=client
|
||||
client=client,
|
||||
aembedding=aembedding
|
||||
)
|
||||
elif model in litellm.open_ai_embedding_models or custom_llm_provider == "openai":
|
||||
api_base = (
|
||||
|
|
|
@ -192,7 +192,26 @@ def test_aembedding():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
test_aembedding()
|
||||
# test_aembedding()
|
||||
|
||||
|
||||
def test_aembedding_azure():
    # Smoke test for the async Azure embedding path: the call should
    # complete and print a response; any exception fails the test.
    try:
        import asyncio

        async def _invoke():
            try:
                result = await litellm.aembedding(
                    model="azure/azure-embedding-model",
                    input=["good morning from litellm", "this is another item"],
                )
                print(result)
            except Exception as e:
                pytest.fail(f"Error occurred: {e}")

        asyncio.run(_invoke())
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# test_aembedding_azure()
|
||||
|
||||
# def test_custom_openai_embedding():
|
||||
# litellm.set_verbose=True
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue