Mirror of https://github.com/BerriAI/litellm.git

commit 152bc67d22 (parent: f7d9cce536)
refactor(azure.py): working azure client init on audio speech endpoint

5 changed files with 63 additions and 46 deletions
@@ -60,6 +60,7 @@ def get_litellm_params(
     merge_reasoning_content_in_choices: Optional[bool] = None,
+    **kwargs,
 ) -> dict:

     litellm_params = {
         "acompletion": acompletion,
         "api_key": api_key,

@@ -99,5 +100,11 @@ def get_litellm_params(
         "async_call": async_call,
         "ssl_verify": ssl_verify,
         "merge_reasoning_content_in_choices": merge_reasoning_content_in_choices,
+        "azure_ad_token": kwargs.get("azure_ad_token"),
+        "tenant_id": kwargs.get("tenant_id"),
+        "client_id": kwargs.get("client_id"),
+        "client_secret": kwargs.get("client_secret"),
+        "azure_username": kwargs.get("azure_username"),
+        "azure_password": kwargs.get("azure_password"),
     }
     return litellm_params
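Taken together, the two hunks above thread Azure credential kwargs through get_litellm_params. A minimal, self-contained sketch of the resulting behavior (an illustration only; the real function takes many more explicit parameters):

from typing import Optional

def get_litellm_params_sketch(
    acompletion: Optional[bool] = None,
    api_key: Optional[str] = None,
    **kwargs,
) -> dict:
    # Same pattern as the diff: explicit args, plus Azure credential
    # kwargs pulled out of **kwargs (missing keys default to None).
    return {
        "acompletion": acompletion,
        "api_key": api_key,
        "azure_ad_token": kwargs.get("azure_ad_token"),
        "tenant_id": kwargs.get("tenant_id"),
        "client_id": kwargs.get("client_id"),
        "client_secret": kwargs.get("client_secret"),
        "azure_username": kwargs.get("azure_username"),
        "azure_password": kwargs.get("azure_password"),
    }

params = get_litellm_params_sketch(api_key="sk-...", tenant_id="my-tenant")
assert params["tenant_id"] == "my-tenant"
assert params["azure_ad_token"] is None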
@@ -153,27 +153,16 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
         timeout: Union[float, httpx.Timeout],
         client: Optional[Any],
         client_type: Literal["sync", "async"],
+        litellm_params: Optional[dict] = None,
     ):
         # init AzureOpenAI Client
-        azure_client_params: Dict[str, Any] = {
-            "api_version": api_version,
-            "azure_endpoint": api_base,
-            "azure_deployment": model,
-            "http_client": litellm.client_session,
-            "max_retries": max_retries,
-            "timeout": timeout,
-        }
-        azure_client_params = select_azure_base_url_or_endpoint(
-            azure_client_params=azure_client_params
-        )
-        if api_key is not None:
-            azure_client_params["api_key"] = api_key
-        elif azure_ad_token is not None:
-            if azure_ad_token.startswith("oidc/"):
-                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
-            azure_client_params["azure_ad_token"] = azure_ad_token
-        elif azure_ad_token_provider is not None:
-            azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
+        azure_client_params: Dict[str, Any] = self.initialize_azure_sdk_client(
+            litellm_params=litellm_params or {},
+            api_key=api_key,
+            model_name=model,
+            api_version=api_version,
+            api_base=api_base,
+        )
         if client is None:
             if client_type == "sync":
                 azure_client = AzureOpenAI(**azure_client_params)  # type: ignore

@@ -780,6 +769,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
         client=None,
         aembedding=None,
         headers: Optional[dict] = None,
+        litellm_params: Optional[dict] = None,
     ) -> EmbeddingResponse:
         if headers:
             optional_params["extra_headers"] = headers

@@ -795,29 +785,14 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
         )

         # init AzureOpenAI Client
-        azure_client_params = {
-            "api_version": api_version,
-            "azure_endpoint": api_base,
-            "azure_deployment": model,
-            "max_retries": max_retries,
-            "timeout": timeout,
-        }
-        azure_client_params = select_azure_base_url_or_endpoint(
-            azure_client_params=azure_client_params
-        )
-        if aembedding:
-            azure_client_params["http_client"] = litellm.aclient_session
-        else:
-            azure_client_params["http_client"] = litellm.client_session
-        if api_key is not None:
-            azure_client_params["api_key"] = api_key
-        elif azure_ad_token is not None:
-            if azure_ad_token.startswith("oidc/"):
-                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
-            azure_client_params["azure_ad_token"] = azure_ad_token
-        elif azure_ad_token_provider is not None:
-            azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
+        azure_client_params = self.initialize_azure_sdk_client(
+            litellm_params=litellm_params or {},
+            api_key=api_key,
+            model_name=model,
+            api_version=api_version,
+            api_base=api_base,
+        )

         ## LOGGING
         logging_obj.pre_call(
             input=input,

@@ -1282,6 +1257,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
         azure_ad_token_provider: Optional[Callable] = None,
         aspeech: Optional[bool] = None,
         client=None,
+        litellm_params: Optional[dict] = None,
     ) -> HttpxBinaryResponseContent:

         max_retries = optional_params.pop("max_retries", 2)

@@ -1300,6 +1276,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
             max_retries=max_retries,
             timeout=timeout,
             client=client,
+            litellm_params=litellm_params,
         )  # type: ignore

         azure_client: AzureOpenAI = self._get_sync_azure_client(

@@ -1313,6 +1290,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
             timeout=timeout,
             client=client,
             client_type="sync",
+            litellm_params=litellm_params,
         )  # type: ignore

         response = azure_client.audio.speech.create(

@@ -1337,6 +1315,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
         max_retries: int,
         timeout: Union[float, httpx.Timeout],
         client=None,
+        litellm_params: Optional[dict] = None,
     ) -> HttpxBinaryResponseContent:

         azure_client: AsyncAzureOpenAI = self._get_sync_azure_client(

@@ -1350,6 +1329,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
             timeout=timeout,
             client=client,
             client_type="async",
+            litellm_params=litellm_params,
         )  # type: ignore

         azure_response = await azure_client.audio.speech.create(
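Every azure.py hunk above replaces a hand-rolled azure_client_params dict, and the duplicated api_key / azure_ad_token / token-provider branches, with one call to initialize_azure_sdk_client inherited from BaseAzureLLM. A rough sketch of what such a consolidation does; this is a reimplementation under assumptions, not the library's actual helper (which presumably also absorbs the OIDC exchange and token-provider logic deleted above):

from typing import Any, Dict, Optional

def initialize_azure_sdk_client_sketch(
    litellm_params: dict,
    api_key: Optional[str],
    model_name: str,
    api_version: Optional[str],
    api_base: Optional[str],
) -> Dict[str, Any]:
    # One place to build AzureOpenAI/AsyncAzureOpenAI kwargs for every
    # endpoint (chat, embeddings, audio speech, ...), instead of each
    # endpoint assembling its own dict.
    client_params: Dict[str, Any] = {
        "api_version": api_version,
        "azure_endpoint": api_base,
        "azure_deployment": model_name,
    }
    if api_key is not None:
        client_params["api_key"] = api_key
    elif litellm_params.get("azure_ad_token") is not None:
        # The real helper would also resolve "oidc/..." tokens and
        # tenant_id/client_id/client_secret credentials here.
        client_params["azure_ad_token"] = litellm_params["azure_ad_token"]
    return client_params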
@@ -58,6 +58,7 @@ def get_azure_openai_client(
                data[k] = v
        if "api_version" not in data:
            data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
+
        if _is_async is True:
            openai_client = AsyncAzureOpenAI(**data)
        else:
@@ -1162,6 +1162,12 @@ def completion(  # type: ignore # noqa: PLR0915
         merge_reasoning_content_in_choices=kwargs.get(
             "merge_reasoning_content_in_choices", None
         ),
+        azure_ad_token=kwargs.get("azure_ad_token"),
+        tenant_id=kwargs.get("tenant_id"),
+        client_id=kwargs.get("client_id"),
+        client_secret=kwargs.get("client_secret"),
+        azure_username=kwargs.get("azure_username"),
+        azure_password=kwargs.get("azure_password"),
     )
     logging.update_environment_variables(
         model=model,

@@ -3411,6 +3417,7 @@ def embedding(  # noqa: PLR0915
            aembedding=aembedding,
            max_retries=max_retries,
            headers=headers or extra_headers,
+            litellm_params=litellm_params_dict,
        )
    elif (
        model in litellm.open_ai_embedding_models

@@ -5002,6 +5009,7 @@ def transcription(
         custom_llm_provider=custom_llm_provider,
         drop_params=drop_params,
     )
+    litellm_params_dict = get_litellm_params(**kwargs)

     litellm_logging_obj.update_environment_variables(
         model=model,

@@ -5198,7 +5206,7 @@ def speech(

     if max_retries is None:
         max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES
-
+    litellm_params_dict = get_litellm_params(**kwargs)
     logging_obj = kwargs.get("litellm_logging_obj", None)
     logging_obj.update_environment_variables(
         model=model,

@@ -5315,6 +5323,7 @@ def speech(
            timeout=timeout,
            client=client,  # pass AsyncOpenAI, OpenAI client
            aspeech=aspeech,
+            litellm_params=litellm_params_dict,
        )
    elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta":

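With litellm_params_dict now forwarded in main.py, Azure AD credentials supplied as plain kwargs should reach client construction on the speech path — the endpoint the commit title targets. A hypothetical call (deployment name, endpoint, and api_version are placeholders):

import litellm

audio = litellm.speech(
    model="azure/tts-1",
    input="Hello, world!",
    voice="alloy",
    api_base="https://my-endpoint.openai.azure.com",
    api_version="2024-02-15-preview",
    # forwarded via get_litellm_params(**kwargs) into the Azure client init
    azure_ad_token="oidc/...",
)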
@@ -219,8 +219,12 @@ def test_select_azure_base_url_called(setup_mocks):
         CallTypes.acompletion,
         CallTypes.atext_completion,
         CallTypes.aembedding,
-        CallTypes.arerank,
-        CallTypes.atranscription,
+        # CallTypes.arerank,
+        # CallTypes.atranscription,
+        CallTypes.aspeech,
+        CallTypes.aimage_generation,
+        # BATCHES ENDPOINTS
+        # ASSISTANT ENDPOINTS
     ],
 )
 @pytest.mark.asyncio

@@ -255,15 +259,20 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
        "aembedding": {"input": "Hello, how are you?"},
        "arerank": {"input": "Hello, how are you?"},
        "atranscription": {"file": "path/to/file"},
+        "aspeech": {"input": "Hello, how are you?", "voice": "female"},
    }

    # Get appropriate input for this call type
    input_kwarg = test_inputs.get(call_type.value, {})

+    patch_target = "litellm.main.azure_chat_completions.initialize_azure_sdk_client"
+    if call_type == CallTypes.atranscription:
+        patch_target = (
+            "litellm.main.azure_audio_transcriptions.initialize_azure_sdk_client"
+        )
+
    # Mock the initialize_azure_sdk_client function
-    with patch(
-        "litellm.main.azure_chat_completions.initialize_azure_sdk_client"
-    ) as mock_init_azure:
+    with patch(patch_target) as mock_init_azure:
        # Also mock async_function_with_fallbacks to prevent actual API calls
        # Call the appropriate router method
        try:

@@ -271,6 +280,7 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
                model="gpt-3.5-turbo",
                **input_kwarg,
                num_retries=0,
+                azure_ad_token="oidc/test-token",
            )
        except Exception as e:
            print(e)

@@ -282,6 +292,16 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
    calls = mock_init_azure.call_args_list
    azure_calls = [call for call in calls]

+    litellm_params = azure_calls[0].kwargs["litellm_params"]
+    print("litellm_params", litellm_params)
+
+    assert (
+        "azure_ad_token" in litellm_params
+    ), "azure_ad_token not found in parameters"
+    assert (
+        litellm_params["azure_ad_token"] == "oidc/test-token"
+    ), "azure_ad_token is not correct"
+
    # More detailed verification (optional)
    for call in azure_calls:
        assert "api_key" in call.kwargs, "api_key not found in parameters"
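The test's assertion style — patch the initializer, trigger the router call, then inspect call_args_list — is a general way to verify that credentials survive a long call chain. A self-contained illustration of the same pattern (toy stand-ins, not litellm code):

from unittest.mock import patch

def initialize(**kwargs):  # toy stand-in for initialize_azure_sdk_client
    return {}

def do_work(**kwargs):
    # Stand-in for a router call that forwards credentials downstream.
    initialize(litellm_params={"azure_ad_token": kwargs.get("azure_ad_token")})

with patch(f"{__name__}.initialize") as mock_init:
    do_work(azure_ad_token="oidc/test-token")

call = mock_init.call_args_list[0]
assert call.kwargs["litellm_params"]["azure_ad_token"] == "oidc/test-token"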