fix(utils.py): support json schema validation
parent 05dfc63b88
commit b699d9a8b9

5 changed files with 216 additions and 12 deletions
@@ -849,6 +849,7 @@ from .exceptions import (
     APIResponseValidationError,
     UnprocessableEntityError,
     InternalServerError,
+    JSONSchemaValidationError,
     LITELLM_EXCEPTION_TYPES,
 )
 from .budget_manager import BudgetManager
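Since the package root now re-exports the exception, downstream code can import it directly. A minimal sketch of the new import path (assumes a litellm build containing this commit):

    # JSONSchemaValidationError is re-exported at the package root
    from litellm import JSONSchemaValidationError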
@@ -551,7 +551,7 @@ class APIError(openai.APIError):  # type: ignore
         message,
         llm_provider,
         model,
-        request: httpx.Request,
+        request: Optional[httpx.Request] = None,
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
@@ -563,6 +563,8 @@ class APIError(openai.APIError):  # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        if request is None:
+            request = httpx.Request(method="POST", url="https://api.openai.com/v1")
         super().__init__(self.message, request=request, body=None)  # type: ignore

     def __str__(self):
@@ -664,6 +666,22 @@ class OpenAIError(openai.OpenAIError):  # type: ignore
         self.llm_provider = "openai"


+class JSONSchemaValidationError(APIError):
+    def __init__(
+        self, model: str, llm_provider: str, raw_response: str, schema: str
+    ) -> None:
+        self.raw_response = raw_response
+        self.schema = schema
+        self.model = model
+        message = "litellm.JSONSchemaValidationError: model={}, returned an invalid response={}, for schema={}.\nAccess raw response with `e.raw_response`".format(
+            model, raw_response, schema
+        )
+        self.message = message
+        super().__init__(
+            model=model, message=message, llm_provider=llm_provider, status_code=500
+        )
+
+
 LITELLM_EXCEPTION_TYPES = [
     AuthenticationError,
     NotFoundError,
@@ -682,6 +700,7 @@ LITELLM_EXCEPTION_TYPES = [
     APIResponseValidationError,
     OpenAIError,
     InternalServerError,
+    JSONSchemaValidationError,
 ]

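The new exception keeps the offending output and the schema on the instance, as its message hints (`e.raw_response`). A minimal sketch of constructing and inspecting it directly; the empty model/llm_provider values mirror what validate_schema() passes in the new file below:

    from litellm import JSONSchemaValidationError

    err = JSONSchemaValidationError(
        model="",
        llm_provider="",
        raw_response='[{"recipe_world": "Chocolate Chip Cookies"}]',
        schema='{"type": "array"}',
    )
    print(err.raw_response)  # the invalid model output, verbatim
    print(err.schema)        # the JSON schema it was checked against

Note this also exercises the APIError change above: no httpx.Request is passed, so the new default request is built instead of raising.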
litellm/litellm_core_utils/json_validation_rule.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+import json
+
+
+def validate_schema(schema: dict, response: str):
+    """
+    Validate if the returned json response follows the schema.
+
+    Params:
+    - schema - dict: JSON schema
+    - response - str: Received json response as string.
+    """
+    from jsonschema import ValidationError, validate
+
+    from litellm import JSONSchemaValidationError
+
+    response_dict = json.loads(response)
+
+    try:
+        validate(response_dict, schema=schema)
+    except ValidationError:
+        raise JSONSchemaValidationError(
+            model="", llm_provider="", raw_response=response, schema=json.dumps(schema)
+        )
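A minimal sketch of exercising the new helper on its own (requires the `jsonschema` package; the schema and payloads are illustrative, echoing the test fixtures below):

    import litellm
    from litellm.litellm_core_utils.json_validation_rule import validate_schema

    schema = {
        "type": "array",
        "items": {
            "type": "object",
            "properties": {"recipe_name": {"type": "string"}},
            "required": ["recipe_name"],
        },
    }

    # Conforming response: returns without raising.
    validate_schema(schema=schema, response='[{"recipe_name": "Sugar Cookies"}]')

    # Non-conforming response (wrong key): raises JSONSchemaValidationError.
    try:
        validate_schema(schema=schema, response='[{"recipe_world": "Sugar Cookies"}]')
    except litellm.JSONSchemaValidationError as e:
        print(e.raw_response)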
@@ -880,15 +880,133 @@ Using this JSON schema:
     mock_call.assert_called_once()


+def vertex_httpx_mock_post_valid_response(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "candidates": [
+            {
+                "content": {
+                    "role": "model",
+                    "parts": [
+                        {
+                            "text": '[{"recipe_name": "Chocolate Chip Cookies"}, {"recipe_name": "Oatmeal Raisin Cookies"}, {"recipe_name": "Peanut Butter Cookies"}, {"recipe_name": "Sugar Cookies"}, {"recipe_name": "Snickerdoodles"}]\n'
+                        }
+                    ],
+                },
+                "finishReason": "STOP",
+                "safetyRatings": [
+                    {
+                        "category": "HARM_CATEGORY_HATE_SPEECH",
+                        "probability": "NEGLIGIBLE",
+                        "probabilityScore": 0.09790669,
+                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
+                        "severityScore": 0.11736965,
+                    },
+                    {
+                        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                        "probability": "NEGLIGIBLE",
+                        "probabilityScore": 0.1261379,
+                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
+                        "severityScore": 0.08601588,
+                    },
+                    {
+                        "category": "HARM_CATEGORY_HARASSMENT",
+                        "probability": "NEGLIGIBLE",
+                        "probabilityScore": 0.083441176,
+                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
+                        "severityScore": 0.0355444,
+                    },
+                    {
+                        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                        "probability": "NEGLIGIBLE",
+                        "probabilityScore": 0.071981624,
+                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
+                        "severityScore": 0.08108212,
+                    },
+                ],
+            }
+        ],
+        "usageMetadata": {
+            "promptTokenCount": 60,
+            "candidatesTokenCount": 55,
+            "totalTokenCount": 115,
+        },
+    }
+    return mock_response
+
+
+def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "candidates": [
+            {
+                "content": {
+                    "role": "model",
+                    "parts": [
+                        {"text": '[{"recipe_world": "Chocolate Chip Cookies"}]\n'}
+                    ],
+                },
+                "finishReason": "STOP",
+                "safetyRatings": [
+                    {
+                        "category": "HARM_CATEGORY_HATE_SPEECH",
+                        "probability": "NEGLIGIBLE",
+                        "probabilityScore": 0.09790669,
+                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
+                        "severityScore": 0.11736965,
+                    },
+                    {
+                        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                        "probability": "NEGLIGIBLE",
+                        "probabilityScore": 0.1261379,
+                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
+                        "severityScore": 0.08601588,
+                    },
+                    {
+                        "category": "HARM_CATEGORY_HARASSMENT",
+                        "probability": "NEGLIGIBLE",
+                        "probabilityScore": 0.083441176,
+                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
+                        "severityScore": 0.0355444,
+                    },
+                    {
+                        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                        "probability": "NEGLIGIBLE",
+                        "probabilityScore": 0.071981624,
+                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
+                        "severityScore": 0.08108212,
+                    },
+                ],
+            }
+        ],
+        "usageMetadata": {
+            "promptTokenCount": 60,
+            "candidatesTokenCount": 55,
+            "totalTokenCount": 115,
+        },
+    }
+    return mock_response
+
+
 @pytest.mark.parametrize(
     "model, supports_response_schema",
     [
         ("vertex_ai_beta/gemini-1.5-pro-001", True),
         ("vertex_ai_beta/gemini-1.5-flash", False),
     ],
-)  # "vertex_ai",
+)
+@pytest.mark.parametrize(
+    "invalid_response",
+    [True, False],
+)
 @pytest.mark.asyncio
-async def test_gemini_pro_json_schema_httpx(model, supports_response_schema):
+async def test_gemini_pro_json_schema_args_sent_httpx(
+    model, supports_response_schema, invalid_response
+):
     load_vertex_ai_credentials()
     os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
     litellm.model_cost = litellm.get_model_cost_map(url="")
@@ -912,7 +1030,12 @@ async def test_gemini_pro_json_schema_httpx(model, supports_response_schema):

     client = HTTPHandler()

-    with patch.object(client, "post", new=MagicMock()) as mock_call:
+    httpx_response = MagicMock()
+    if invalid_response is True:
+        httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response
+    else:
+        httpx_response.side_effect = vertex_httpx_mock_post_valid_response
+    with patch.object(client, "post", new=httpx_response) as mock_call:
         try:
             _ = completion(
                 model=model,
@@ -923,8 +1046,11 @@ async def test_gemini_pro_json_schema_httpx(model, supports_response_schema):
                 },
                 client=client,
             )
-        except Exception:
-            pass
+            if invalid_response is True:
+                pytest.fail("Expected this to fail")
+        except litellm.JSONSchemaValidationError as e:
+            if invalid_response is False:
+                pytest.fail("Expected this to pass. Got={}".format(e))

     mock_call.assert_called_once()
     print(mock_call.call_args.kwargs)
@@ -48,6 +48,7 @@ from tokenizers import Tokenizer
 import litellm
 import litellm._service_logger  # for storing API inputs, outputs, and metadata
 import litellm.litellm_core_utils
+import litellm.litellm_core_utils.json_validation_rule
 from litellm.caching import DualCache
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe
@@ -579,7 +580,7 @@ def client(original_function):
         else:
             return False

-    def post_call_processing(original_response, model):
+    def post_call_processing(original_response, model, optional_params: Optional[dict]):
         try:
             if original_response is None:
                 pass
@@ -594,11 +595,41 @@ def client(original_function):
                 pass
             else:
                 if isinstance(original_response, ModelResponse):
-                    model_response = original_response.choices[
+                    model_response: Optional[str] = original_response.choices[
                         0
-                    ].message.content
-                    ### POST-CALL RULES ###
-                    rules_obj.post_call_rules(input=model_response, model=model)
+                    ].message.content  # type: ignore
+                    if model_response is not None:
+                        ### POST-CALL RULES ###
+                        rules_obj.post_call_rules(
+                            input=model_response, model=model
+                        )
+                        ### JSON SCHEMA VALIDATION ###
+                        if (
+                            optional_params is not None
+                            and "response_format" in optional_params
+                            and isinstance(
+                                optional_params["response_format"], dict
+                            )
+                            and "type" in optional_params["response_format"]
+                            and optional_params["response_format"]["type"]
+                            == "json_object"
+                            and "response_schema"
+                            in optional_params["response_format"]
+                            and isinstance(
+                                optional_params["response_format"][
+                                    "response_schema"
+                                ],
+                                dict,
+                            )
+                        ):
+                            # schema given, json response expected
+                            litellm.litellm_core_utils.json_validation_rule.validate_schema(
+                                schema=optional_params["response_format"][
+                                    "response_schema"
+                                ],
+                                response=model_response,
+                            )
+
         except Exception as e:
             raise e
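Validation is gated on the exact shape checked above: optional_params["response_format"] must be a dict whose "type" is "json_object" and whose "response_schema" is itself a dict. A hedged sketch of a call that would flow through this path (model name and schema are illustrative, borrowed from the test above):

    import litellm

    response = litellm.completion(
        model="vertex_ai_beta/gemini-1.5-pro-001",
        messages=[{"role": "user", "content": "List 5 cookie recipes"}],
        response_format={
            "type": "json_object",
            "response_schema": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {"recipe_name": {"type": "string"}},
                },
            },
        },
    )
    # If the returned content does not match response_schema,
    # post_call_processing raises litellm.JSONSchemaValidationError.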
@@ -867,7 +898,11 @@ def client(original_function):
             return result

         ### POST-CALL RULES ###
-        post_call_processing(original_response=result, model=model or None)
+        post_call_processing(
+            original_response=result,
+            model=model or None,
+            optional_params=kwargs,
+        )

         # [OPTIONAL] ADD TO CACHE
         if (