(v0) fixes for Azure GPT Vision enhancements

This commit is contained in:
ishaan-jaff 2024-01-17 09:57:16 -08:00
parent 6cb6bf0727
commit b95d6ec207
2 changed files with 37 additions and 5 deletions

View file

@ -95,6 +95,25 @@ class AzureOpenAIConfig(OpenAIConfig):
) )
def select_azure_base_url_or_endpoint(azure_client_params: dict):
    """Re-key the configured Azure URL as ``base_url`` when it is a full base URL.

    Reads ``azure_client_params["azure_endpoint"]``. If that URL's *path*
    contains ``/openai`` (for example
    ``https://<resource>.openai.azure.com/openai/deployments/<model>/extensions``),
    the value is moved from the ``"azure_endpoint"`` key to the ``"base_url"``
    key. Otherwise the dict is left untouched.

    Only the path is inspected, not the whole URL: a plain resource endpoint
    whose host name starts with ``openai`` (``https://openai-foo...``)
    contains the substring ``/openai`` right after the scheme's ``//`` and
    must not be mistaken for a base URL.

    Args:
        azure_client_params: Keyword arguments being assembled for the Azure
            client. Callers are expected to pass keys such as
            ``api_version``, ``azure_endpoint``, ``azure_deployment``,
            ``http_client``, ``max_retries`` and ``timeout``; only
            ``azure_endpoint`` is read here.

    Returns:
        The same dict object that was passed in (mutated in place when the
        URL is re-keyed), so callers may either use the return value or rely
        on the mutation.
    """
    # Function-scope import keeps this block self-contained; the module's
    # top-level import section is not touched.
    from urllib.parse import urlsplit

    azure_endpoint = azure_client_params.get("azure_endpoint", None)
    if azure_endpoint is None:
        # Key missing or explicitly None: nothing to decide.
        return azure_client_params

    try:
        url_path = urlsplit(azure_endpoint).path
    except ValueError:
        # Malformed URL (e.g. unbalanced IPv6 brackets): fall back to
        # scanning the raw string rather than raising from a helper that
        # never raised before.
        url_path = azure_endpoint

    if "/openai" in url_path:
        # this is base_url, not an azure_endpoint
        azure_client_params["base_url"] = azure_endpoint
        azure_client_params.pop("azure_endpoint")

    return azure_client_params
class AzureChatCompletion(BaseLLM): class AzureChatCompletion(BaseLLM):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
@ -239,6 +258,9 @@ class AzureChatCompletion(BaseLLM):
"max_retries": max_retries, "max_retries": max_retries,
"timeout": timeout, "timeout": timeout,
} }
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None: if api_key is not None:
azure_client_params["api_key"] = api_key azure_client_params["api_key"] = api_key
elif azure_ad_token is not None: elif azure_ad_token is not None:

View file

@ -229,7 +229,7 @@ def test_completion_azure_gpt4_vision():
litellm.set_verbose = True litellm.set_verbose = True
response = completion( response = completion(
model="azure/gpt-4-vision", model="azure/gpt-4-vision",
timeout=1, timeout=5,
messages=[ messages=[
{ {
"role": "user", "role": "user",
@ -244,21 +244,31 @@ def test_completion_azure_gpt4_vision():
], ],
} }
], ],
base_url="https://gpt-4-vision-resource.openai.azure.com/", base_url="https://gpt-4-vision-resource.openai.azure.com/openai/deployments/gpt-4-vision/extensions",
api_key=os.getenv("AZURE_VISION_API_KEY"), api_key=os.getenv("AZURE_VISION_API_KEY"),
enhancements={"ocr": {"enabled": True}, "grounding": {"enabled": True}},
dataSources=[
{
"type": "AzureComputerVision",
"parameters": {
"endpoint": "https://gpt-4-vision-enhancement.cognitiveservices.azure.com/",
"key": "efcd55c055ca47e08f61a8c54ba1707b",
},
}
],
) )
print(response) print(response)
except openai.APITimeoutError: except openai.APITimeoutError:
print("got a timeout error") print("got a timeout error")
pass pass
except openai.RateLimitError: except openai.RateLimitError as e:
print("got a rate liimt error") print("got a rate liimt error", e)
pass pass
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# test_completion_azure_gpt4_vision() test_completion_azure_gpt4_vision()
@pytest.mark.skip(reason="this test is flaky") @pytest.mark.skip(reason="this test is flaky")