diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index bf7289fa1d..5285aaebe0 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -171,6 +171,76 @@ class AzureChatCompletion(BaseLLM):
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e
+        except Exception as e:
+            if exception_mapping_worked:
+                raise e
+            else:
+                import traceback
+                raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
+
+    def embedding(self,
+                model: str,
+                input: list,
+                api_key: Optional[str] = None,
+                api_base: Optional[str] = None,
+                api_version: Optional[str] = None,
+                logging_obj=None,
+                model_response=None,
+                optional_params=None,):
+        super().embedding()
+        exception_mapping_worked = False
+        try:
+            headers = self.validate_environment(api_key)
+            # Ensure api_base ends with a trailing slash
+            if not api_base.endswith('/'):
+                api_base += '/'
+
+            api_base = api_base + f"openai/deployments/{model}/embeddings?api-version={api_version}"
+            model = model
+            data = {
+                "model": model,
+                "input": input,
+                **optional_params
+            }
+
+            ## LOGGING
+            logging_obj.pre_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+            )
+            ## COMPLETION CALL
+            response = self._client_session.post(
+                api_base, headers=headers, json=data
+            )
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=response,
+            )
+
+            if response.status_code!=200:
+                raise AzureOpenAIError(message=response.text, status_code=response.status_code)
+            embedding_response = response.json()
+            output_data = []
+            for idx, embedding in enumerate(embedding_response["data"]):
+                output_data.append(
+                    {
+                        "object": embedding["object"],
+                        "index": embedding["index"],
+                        "embedding": embedding["embedding"]
+                    }
+                )
+            model_response["object"] = "list"
+            model_response["data"] = output_data
+            model_response["model"] = model
+            model_response["usage"] = embedding_response["usage"]
+            return model_response
+        except AzureOpenAIError as e:
+            exception_mapping_worked = True
+            raise e
         except Exception as e:
             if exception_mapping_worked:
                 raise e
diff --git a/litellm/main.py b/litellm/main.py
index 5f86a8f491..2d6d200b8f 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1681,28 +1681,17 @@ def embedding(
             litellm.azure_key
             or get_secret("AZURE_API_KEY")
         )
-        ## LOGGING
-        logging.pre_call(
-            input=input,
-            api_key=openai.api_key,
-            additional_args={
-                "api_type": openai.api_type,
-                "api_base": openai.api_base,
-                "api_version": openai.api_version,
-            },
-        )
         ## EMBEDDING CALL
-        response = openai.Embedding.create(
-            input=input,
-            engine=model,
-            api_key=api_key,
+        response = azure_chat_completions.embedding(
+            model=model,
+            input=input,
             api_base=api_base,
+            api_key=api_key,
             api_version=api_version,
-            api_type=api_type,
+            logging_obj=logging,
+            model_response=EmbeddingResponse(),
+            optional_params=kwargs
         )
-
-        ## LOGGING
-        logging.post_call(input=input, api_key=openai.api_key, original_response=response)
     elif model in litellm.open_ai_embedding_models:
         api_base = (
             api_base