mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Merge pull request #5296 from BerriAI/litellm_azure_json_schema_support
feat(azure.py): support 'json_schema' for older models
This commit is contained in:
commit
02eb6455b2
3 changed files with 101 additions and 31 deletions
|
@ -2162,37 +2162,44 @@ def test_completion_openai():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_completion_openai_pydantic():
|
||||
@pytest.mark.parametrize("model", ["gpt-4o-2024-08-06", "azure/chatgpt-v-2"])
|
||||
def test_completion_openai_pydantic(model):
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
from pydantic import BaseModel
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "List 5 important events in the XIX century"}
|
||||
]
|
||||
|
||||
class CalendarEvent(BaseModel):
|
||||
name: str
|
||||
date: str
|
||||
participants: list[str]
|
||||
|
||||
print(f"api key: {os.environ['OPENAI_API_KEY']}")
|
||||
litellm.api_key = os.environ["OPENAI_API_KEY"]
|
||||
response = completion(
|
||||
model="gpt-4o-2024-08-06",
|
||||
messages=[{"role": "user", "content": "Hey"}],
|
||||
max_tokens=10,
|
||||
metadata={"hi": "bye"},
|
||||
response_format=CalendarEvent,
|
||||
)
|
||||
class EventsList(BaseModel):
|
||||
events: list[CalendarEvent]
|
||||
|
||||
litellm.enable_json_schema_validation = True
|
||||
for _ in range(3):
|
||||
try:
|
||||
response = completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
metadata={"hi": "bye"},
|
||||
response_format=EventsList,
|
||||
)
|
||||
break
|
||||
except litellm.JSONSchemaValidationError:
|
||||
print("ERROR OCCURRED! INVALID JSON")
|
||||
|
||||
print("This is the response object\n", response)
|
||||
|
||||
response_str = response["choices"][0]["message"]["content"]
|
||||
response_str_2 = response.choices[0].message.content
|
||||
|
||||
cost = completion_cost(completion_response=response)
|
||||
print("Cost for completion call with gpt-3.5-turbo: ", f"${float(cost):.10f}")
|
||||
assert response_str == response_str_2
|
||||
assert type(response_str) == str
|
||||
assert len(response_str) > 1
|
||||
print(f"response_str: {response_str}")
|
||||
json.loads(response_str) # check valid json is returned
|
||||
|
||||
litellm.api_key = None
|
||||
except Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue