clean up azure implementation

This commit is contained in:
ishaan-jaff 2023-09-05 15:31:57 -07:00
parent 463e57d0a8
commit 2250d1375e
5 changed files with 61 additions and 90 deletions

File diff suppressed because one or more lines are too long

View file

@ -10,11 +10,12 @@
| gpt-4 | `completion('gpt-4', messages)` | `os.environ['OPENAI_API_KEY']` |
## Azure OpenAI Chat Completion Models
For Azure calls add the `azure/` prefix to `model`. If your azure deployment name is `gpt-v-2` set `model` = `azure/gpt-v-2`
| Model Name | Function Call | Required OS Variables |
|------------------|-----------------------------------------|-------------------------------------------|
| gpt-3.5-turbo | `completion('gpt-3.5-turbo', messages, azure=True)` | `os.environ['AZURE_API_KEY']`,`os.environ['AZURE_API_BASE']`,`os.environ['AZURE_API_VERSION']` |
| gpt-4 | `completion('gpt-4', messages, azure=True)` | `os.environ['AZURE_API_KEY']`,`os.environ['AZURE_API_BASE']`,`os.environ['AZURE_API_VERSION']` |
| gpt-3.5-turbo | `completion('azure/gpt-3.5-turbo-deployment', messages)` | `os.environ['AZURE_API_KEY']`,`os.environ['AZURE_API_BASE']`,`os.environ['AZURE_API_VERSION']` |
| gpt-4 | `completion('azure/gpt-4-deployment', messages)` | `os.environ['AZURE_API_KEY']`,`os.environ['AZURE_API_BASE']`,`os.environ['AZURE_API_VERSION']` |
### OpenAI Text Completion Models

View file

@ -115,10 +115,9 @@ def completion(
model
] # update the model to the actual value if an alias has been passed in
model_response = ModelResponse()
if azure: # this flag is deprecated, remove once notebooks are also updated.
custom_llm_provider = "azure"
if deployment_id:
if deployment_id != None:
model=deployment_id
custom_llm_provider="azure"
elif (
model.split("/", 1)[0] in litellm.provider_list
): # allow custom provider to be passed in via the model name "azure/chatgpt-test"

View file

@ -83,7 +83,7 @@ def test_bad_azure_embedding():
# def test_good_azure_embedding():
# try:
# response = embedding(model='azure-embedding-model', input=[user_message], azure=True, logger_fn=logger_fn)
# response = embedding(model='azure/azure-embedding-model', input=[user_message], logger_fn=logger_fn)
# # Add any assertions here to check the response
# print(f"response: {str(response)[:50]}")
# except Exception as e:

View file

@ -326,10 +326,10 @@ def test_completion_openai_with_functions():
def test_completion_azure():
try:
print("azure gpt-3.5 test\n\n")
response = completion(
model="chatgpt-test",
model="azure/chatgpt-v-2",
messages=messages,
custom_llm_provider="azure",
)
# Add any assertions here to check the response
print(response)
@ -340,10 +340,9 @@ def test_completion_azure():
def test_completion_azure_deployment_id():
try:
response = completion(
model="chatgpt-3.5-turbo",
deployment_id="chatgpt-v-2",
model="gpt-3.5-turbo",
messages=messages,
azure=True,
)
# Add any assertions here to check the response
print(response)