From 2b132c6befdd2a863c81ab8fbf00df45d2f52bcc Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 6 Aug 2024 18:16:07 -0700
Subject: [PATCH] feat(utils.py): support passing response_format as pydantic
 model

Related issue - https://github.com/BerriAI/litellm/issues/5074
---
 litellm/main.py                  |  2 +-
 litellm/tests/test_completion.py | 37 ++++++++++++++++++++++++++++++++
 litellm/utils.py                 | 37 ++++++++++++++++++++++++++++++++
 3 files changed, 75 insertions(+), 1 deletion(-)

diff --git a/litellm/main.py b/litellm/main.py
index 1209306c8b..01e3d2f953 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -608,7 +608,7 @@ def completion(
     logit_bias: Optional[dict] = None,
     user: Optional[str] = None,
     # openai v1.0+ new params
-    response_format: Optional[dict] = None,
+    response_format: Optional[Union[dict, type[BaseModel]]] = None,
     seed: Optional[int] = None,
     tools: Optional[List] = None,
     tool_choice: Optional[Union[str, dict]] = None,
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index eec163f26a..04b260c2e8 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -2123,6 +2123,43 @@ def test_completion_openai():
         pytest.fail(f"Error occurred: {e}")
 
 
+def test_completion_openai_pydantic():
+    try:
+        litellm.set_verbose = True
+        from pydantic import BaseModel
+
+        class CalendarEvent(BaseModel):
+            name: str
+            date: str
+            participants: list[str]
+
+        print(f"api key: {os.environ['OPENAI_API_KEY']}")
+        litellm.api_key = os.environ["OPENAI_API_KEY"]
+        response = completion(
+            model="gpt-4o-2024-08-06",
+            messages=[{"role": "user", "content": "Hey"}],
+            max_tokens=10,
+            metadata={"hi": "bye"},
+            response_format=CalendarEvent,
+        )
+        print("This is the response object\n", response)
+
+        response_str = response["choices"][0]["message"]["content"]
+        response_str_2 = response.choices[0].message.content
+
+        cost = completion_cost(completion_response=response)
+        print("Cost for completion call with gpt-3.5-turbo: ", f"${float(cost):.10f}")
+        assert response_str == response_str_2
+        assert type(response_str) == str
+        assert len(response_str) > 1
+
+        litellm.api_key = None
+    except Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_openai_organization():
     try:
         litellm.set_verbose = True
diff --git a/litellm/utils.py b/litellm/utils.py
index 20beb47dc2..ed155ab143 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -45,6 +45,8 @@ import requests
 import tiktoken
 from httpx import Proxy
 from httpx._utils import get_environment_proxies
+from openai.lib import _parsing, _pydantic
+from openai.types.chat.completion_create_params import ResponseFormat
 from pydantic import BaseModel
 from tokenizers import Tokenizer
 
@@ -2806,6 +2808,11 @@ def get_optional_params(
             message=f"Function calling is not supported by {custom_llm_provider}.",
         )
 
+    if "response_format" in non_default_params:
+        non_default_params["response_format"] = type_to_response_format_param(
+            response_format=non_default_params["response_format"]
+        )
+
     if "tools" in non_default_params and isinstance(
         non_default_params, list
     ):  # fixes https://github.com/BerriAI/litellm/issues/4933
@@ -6112,6 +6119,36 @@ def _should_retry(status_code: int):
     return False
 
 
+def type_to_response_format_param(
+    response_format: Optional[Union[type[BaseModel], dict]],
+) -> Optional[dict]:
+    """
+    Re-implementation of openai's 'type_to_response_format_param' function
+
+    Used for converting pydantic object to api schema.
+    """
+    if response_format is None:
+        return None
+
+    if isinstance(response_format, dict):
+        return response_format
+
+    # type checkers don't narrow the negation of a `TypeGuard` as it isn't
+    # a safe default behaviour but we know that at this point the `response_format`
+    # can only be a `type`
+    if not _parsing._completions.is_basemodel_type(response_format):
+        raise TypeError(f"Unsupported response_format type - {response_format}")
+
+    return {
+        "type": "json_schema",
+        "json_schema": {
+            "schema": _pydantic.to_strict_json_schema(response_format),
+            "name": response_format.__name__,
+            "strict": True,
+        },
+    }
+
+
 def _get_retry_after_from_exception_header(
     response_headers: Optional[httpx.Headers] = None,
 ):