feat(utils.py): support passing response_format as pydantic model

Related issue: https://github.com/BerriAI/litellm/issues/5074

parent: f3a0eb8eb9
commit: 9cf3d5f568

3 changed files, 75 insertions(+), 1 deletion(-)
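For context, this is what the change enables end to end. A minimal sketch, not taken from the diff itself: the model name mirrors the test below, the prompt is illustrative, and an OPENAI_API_KEY is assumed to be set in the environment.

from pydantic import BaseModel

from litellm import completion


class CalendarEvent(BaseModel):
    name: str
    date: str
    participants: list[str]


# The Pydantic class itself is passed as response_format; litellm now
# converts it into an OpenAI-style json_schema response_format dict.
response = completion(
    model="gpt-4o-2024-08-06",
    messages=[{"role": "user", "content": "Schedule a standup tomorrow with Alice and Bob."}],
    response_format=CalendarEvent,
)
print(response.choices[0].message.content)  # JSON text conforming to CalendarEvent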
@@ -608,7 +608,7 @@ def completion(
     logit_bias: Optional[dict] = None,
     user: Optional[str] = None,
     # openai v1.0+ new params
-    response_format: Optional[dict] = None,
+    response_format: Optional[Union[dict, type[BaseModel]]] = None,
     seed: Optional[int] = None,
     tools: Optional[List] = None,
     tool_choice: Optional[Union[str, dict]] = None,
@@ -2123,6 +2123,43 @@ def test_completion_openai():
         pytest.fail(f"Error occurred: {e}")


+def test_completion_openai_pydantic():
+    try:
+        litellm.set_verbose = True
+        from pydantic import BaseModel
+
+        class CalendarEvent(BaseModel):
+            name: str
+            date: str
+            participants: list[str]
+
+        print(f"api key: {os.environ['OPENAI_API_KEY']}")
+        litellm.api_key = os.environ["OPENAI_API_KEY"]
+        response = completion(
+            model="gpt-4o-2024-08-06",
+            messages=[{"role": "user", "content": "Hey"}],
+            max_tokens=10,
+            metadata={"hi": "bye"},
+            response_format=CalendarEvent,
+        )
+        print("This is the response object\n", response)
+
+        response_str = response["choices"][0]["message"]["content"]
+        response_str_2 = response.choices[0].message.content
+
+        cost = completion_cost(completion_response=response)
+        print("Cost for completion call with gpt-4o-2024-08-06: ", f"${float(cost):.10f}")
+        assert response_str == response_str_2
+        assert type(response_str) == str
+        assert len(response_str) > 1
+
+        litellm.api_key = None
+    except Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_openai_organization():
     try:
         litellm.set_verbose = True
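The test only asserts that a non-empty string comes back. A hedged sketch of the stronger check one could add, assuming Pydantic v2's model_validate_json and a provider that honors the strict schema; the JSON literal here is illustrative, standing in for response.choices[0].message.content:

from pydantic import BaseModel


class CalendarEvent(BaseModel):
    name: str
    date: str
    participants: list[str]


# Stand-in for the message content a schema-constrained call returns.
content = '{"name": "standup", "date": "2024-08-07", "participants": ["Alice", "Bob"]}'

# Round-trips the content through the model; raises ValidationError on mismatch.
event = CalendarEvent.model_validate_json(content)
assert event.participants == ["Alice", "Bob"]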
@@ -45,6 +45,8 @@ import requests
 import tiktoken
 from httpx import Proxy
 from httpx._utils import get_environment_proxies
+from openai.lib import _parsing, _pydantic
+from openai.types.chat.completion_create_params import ResponseFormat
 from pydantic import BaseModel
 from tokenizers import Tokenizer
@@ -2806,6 +2808,11 @@ def get_optional_params(
             message=f"Function calling is not supported by {custom_llm_provider}.",
         )

+    if "response_format" in non_default_params:
+        non_default_params["response_format"] = type_to_response_format_param(
+            response_format=non_default_params["response_format"]
+        )
+
     if "tools" in non_default_params and isinstance(
         non_default_params, list
     ):  # fixes https://github.com/BerriAI/litellm/issues/4933
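In other words, get_optional_params now normalizes response_format once, before any provider-specific handling. A small sketch of the two branches, assuming the helper lives in litellm.utils (the module named in the commit title); plain dicts pass through untouched, so existing {"type": "json_object"} callers are unaffected:

from pydantic import BaseModel

from litellm.utils import type_to_response_format_param  # assumed import path


class CalendarEvent(BaseModel):
    name: str


# Branch 1: an existing dict-style response_format is returned unchanged.
assert type_to_response_format_param({"type": "json_object"}) == {"type": "json_object"}

# Branch 2: a BaseModel subclass is expanded into a json_schema param dict.
param = type_to_response_format_param(CalendarEvent)
assert param["type"] == "json_schema"
assert param["json_schema"]["name"] == "CalendarEvent"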
@@ -6112,6 +6119,36 @@ def _should_retry(status_code: int):
     return False


+def type_to_response_format_param(
+    response_format: Optional[Union[type[BaseModel], dict]],
+) -> Optional[dict]:
+    """
+    Re-implementation of openai's 'type_to_response_format_param' function
+
+    Used for converting pydantic object to api schema.
+    """
+    if response_format is None:
+        return None
+
+    if isinstance(response_format, dict):
+        return response_format
+
+    # type checkers don't narrow the negation of a `TypeGuard` as it isn't
+    # a safe default behaviour but we know that at this point the `response_format`
+    # can only be a `type`
+    if not _parsing._completions.is_basemodel_type(response_format):
+        raise TypeError(f"Unsupported response_format type - {response_format}")
+
+    return {
+        "type": "json_schema",
+        "json_schema": {
+            "schema": _pydantic.to_strict_json_schema(response_format),
+            "name": response_format.__name__,
+            "strict": True,
+        },
+    }
+
+
 def _get_retry_after_from_exception_header(
     response_headers: Optional[httpx.Headers] = None,
 ):
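For reference, the dict this helper emits can be reproduced with the same private openai SDK helpers the diff imports; a sketch, with the caveat that _pydantic is a private module whose behavior may shift between SDK versions:

from openai.lib import _pydantic  # private helper, as imported in the diff
from pydantic import BaseModel


class CalendarEvent(BaseModel):
    name: str
    date: str
    participants: list[str]


param = {
    "type": "json_schema",
    "json_schema": {
        "schema": _pydantic.to_strict_json_schema(CalendarEvent),
        "name": CalendarEvent.__name__,
        "strict": True,
    },
}
# The strict schema marks every field required and sets
# additionalProperties to false, as OpenAI structured outputs expect.
print(param)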