mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
fix(utils.py): support json schema validation
This commit is contained in:
parent
05dfc63b88
commit
b699d9a8b9
5 changed files with 216 additions and 12 deletions
|
@ -849,6 +849,7 @@ from .exceptions import (
|
|||
APIResponseValidationError,
|
||||
UnprocessableEntityError,
|
||||
InternalServerError,
|
||||
JSONSchemaValidationError,
|
||||
LITELLM_EXCEPTION_TYPES,
|
||||
)
|
||||
from .budget_manager import BudgetManager
|
||||
|
|
|
@ -551,7 +551,7 @@ class APIError(openai.APIError): # type: ignore
|
|||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
request: httpx.Request,
|
||||
request: Optional[httpx.Request] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
|
@ -563,6 +563,8 @@ class APIError(openai.APIError): # type: ignore
|
|||
self.litellm_debug_info = litellm_debug_info
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
if request is None:
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
super().__init__(self.message, request=request, body=None) # type: ignore
|
||||
|
||||
def __str__(self):
|
||||
|
@ -664,6 +666,22 @@ class OpenAIError(openai.OpenAIError): # type: ignore
|
|||
self.llm_provider = "openai"
|
||||
|
||||
|
||||
class JSONSchemaValidationError(APIError):
    """Raised when a model's response does not satisfy the requested JSON schema.

    The offending raw model output is preserved on ``self.raw_response`` so
    callers can inspect exactly what the model returned; the schema it was
    checked against is kept on ``self.schema``.
    """

    def __init__(
        self, model: str, llm_provider: str, raw_response: str, schema: str
    ) -> None:
        # Stash the failing response/schema pair before building the message,
        # so they are available even if message construction ever changes.
        self.raw_response = raw_response
        self.schema = schema
        self.model = model
        self.message = (
            f"litellm.JSONSchemaValidationError: model={model}, returned an "
            f"invalid response={raw_response}, for schema={schema}."
            "\nAccess raw response with `e.raw_response`"
        )
        # Reported as a 500 — the provider answered, but with unusable output.
        super().__init__(
            model=model,
            message=self.message,
            llm_provider=llm_provider,
            status_code=500,
        )
|
||||
|
||||
|
||||
LITELLM_EXCEPTION_TYPES = [
|
||||
AuthenticationError,
|
||||
NotFoundError,
|
||||
|
@ -682,6 +700,7 @@ LITELLM_EXCEPTION_TYPES = [
|
|||
APIResponseValidationError,
|
||||
OpenAIError,
|
||||
InternalServerError,
|
||||
JSONSchemaValidationError,
|
||||
]
|
||||
|
||||
|
||||
|
|
23
litellm/litellm_core_utils/json_validation_rule.py
Normal file
23
litellm/litellm_core_utils/json_validation_rule.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
import json
|
||||
|
||||
|
||||
def validate_schema(schema: dict, response: str):
    """
    Validate that a returned JSON response conforms to the given schema.

    Params:
    - schema - dict: JSON schema
    - response - str: Received json response as string.

    Raises:
    - JSONSchemaValidationError: if the parsed response does not satisfy
      the schema. (A non-JSON `response` raises `json.JSONDecodeError`.)
    """
    # Imported lazily: keeps `jsonschema` an optional dependency and avoids
    # a circular import with the top-level `litellm` package.
    from jsonschema import ValidationError, validate

    from litellm import JSONSchemaValidationError

    response_dict = json.loads(response)

    try:
        validate(response_dict, schema=schema)
    except ValidationError as e:
        # Chain the original jsonschema error so the precise validation
        # failure stays visible in the traceback.
        raise JSONSchemaValidationError(
            model="", llm_provider="", raw_response=response, schema=json.dumps(schema)
        ) from e
|
|
@ -880,15 +880,133 @@ Using this JSON schema:
|
|||
mock_call.assert_called_once()
|
||||
|
||||
|
||||
def vertex_httpx_mock_post_valid_response(*args, **kwargs):
    """Return a MagicMock of a Vertex AI POST response whose JSON text
    matches the expected recipe schema (i.e. validation should pass)."""
    # The four safety ratings differ only in category and scores — build
    # them from one template instead of repeating the dict four times.
    rating_values = [
        ("HARM_CATEGORY_HATE_SPEECH", 0.09790669, 0.11736965),
        ("HARM_CATEGORY_DANGEROUS_CONTENT", 0.1261379, 0.08601588),
        ("HARM_CATEGORY_HARASSMENT", 0.083441176, 0.0355444),
        ("HARM_CATEGORY_SEXUALLY_EXPLICIT", 0.071981624, 0.08108212),
    ]
    payload = {
        "candidates": [
            {
                "content": {
                    "role": "model",
                    "parts": [
                        {
                            "text": '[{"recipe_name": "Chocolate Chip Cookies"}, {"recipe_name": "Oatmeal Raisin Cookies"}, {"recipe_name": "Peanut Butter Cookies"}, {"recipe_name": "Sugar Cookies"}, {"recipe_name": "Snickerdoodles"}]\n'
                        }
                    ],
                },
                "finishReason": "STOP",
                "safetyRatings": [
                    {
                        "category": category,
                        "probability": "NEGLIGIBLE",
                        "probabilityScore": probability_score,
                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
                        "severityScore": severity_score,
                    }
                    for category, probability_score, severity_score in rating_values
                ],
            }
        ],
        "usageMetadata": {
            "promptTokenCount": 60,
            "candidatesTokenCount": 55,
            "totalTokenCount": 115,
        },
    }
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.json.return_value = payload
    return mock_response
|
||||
|
||||
|
||||
def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
    """Return a MagicMock of a Vertex AI POST response whose JSON text uses
    the wrong key ("recipe_world"), so schema validation should fail."""
    # Safety ratings share one shape; only category and scores vary.
    rating_values = [
        ("HARM_CATEGORY_HATE_SPEECH", 0.09790669, 0.11736965),
        ("HARM_CATEGORY_DANGEROUS_CONTENT", 0.1261379, 0.08601588),
        ("HARM_CATEGORY_HARASSMENT", 0.083441176, 0.0355444),
        ("HARM_CATEGORY_SEXUALLY_EXPLICIT", 0.071981624, 0.08108212),
    ]
    payload = {
        "candidates": [
            {
                "content": {
                    "role": "model",
                    "parts": [
                        {"text": '[{"recipe_world": "Chocolate Chip Cookies"}]\n'}
                    ],
                },
                "finishReason": "STOP",
                "safetyRatings": [
                    {
                        "category": category,
                        "probability": "NEGLIGIBLE",
                        "probabilityScore": probability_score,
                        "severity": "HARM_SEVERITY_NEGLIGIBLE",
                        "severityScore": severity_score,
                    }
                    for category, probability_score, severity_score in rating_values
                ],
            }
        ],
        "usageMetadata": {
            "promptTokenCount": 60,
            "candidatesTokenCount": 55,
            "totalTokenCount": 115,
        },
    }
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.json.return_value = payload
    return mock_response
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, supports_response_schema",
|
||||
[
|
||||
("vertex_ai_beta/gemini-1.5-pro-001", True),
|
||||
("vertex_ai_beta/gemini-1.5-flash", False),
|
||||
],
|
||||
) # "vertex_ai",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_response",
|
||||
[True, False],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_pro_json_schema_httpx(model, supports_response_schema):
|
||||
async def test_gemini_pro_json_schema_args_sent_httpx(
|
||||
model, supports_response_schema, invalid_response
|
||||
):
|
||||
load_vertex_ai_credentials()
|
||||
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
||||
litellm.model_cost = litellm.get_model_cost_map(url="")
|
||||
|
@ -912,7 +1030,12 @@ async def test_gemini_pro_json_schema_httpx(model, supports_response_schema):
|
|||
|
||||
client = HTTPHandler()
|
||||
|
||||
with patch.object(client, "post", new=MagicMock()) as mock_call:
|
||||
httpx_response = MagicMock()
|
||||
if invalid_response is True:
|
||||
httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response
|
||||
else:
|
||||
httpx_response.side_effect = vertex_httpx_mock_post_valid_response
|
||||
with patch.object(client, "post", new=httpx_response) as mock_call:
|
||||
try:
|
||||
_ = completion(
|
||||
model=model,
|
||||
|
@ -923,8 +1046,11 @@ async def test_gemini_pro_json_schema_httpx(model, supports_response_schema):
|
|||
},
|
||||
client=client,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
if invalid_response is True:
|
||||
pytest.fail("Expected this to fail")
|
||||
except litellm.JSONSchemaValidationError as e:
|
||||
if invalid_response is False:
|
||||
pytest.fail("Expected this to pass. Got={}".format(e))
|
||||
|
||||
mock_call.assert_called_once()
|
||||
print(mock_call.call_args.kwargs)
|
||||
|
|
|
@ -48,6 +48,7 @@ from tokenizers import Tokenizer
|
|||
import litellm
|
||||
import litellm._service_logger # for storing API inputs, outputs, and metadata
|
||||
import litellm.litellm_core_utils
|
||||
import litellm.litellm_core_utils.json_validation_rule
|
||||
from litellm.caching import DualCache
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe
|
||||
|
@ -579,7 +580,7 @@ def client(original_function):
|
|||
else:
|
||||
return False
|
||||
|
||||
def post_call_processing(original_response, model):
|
||||
def post_call_processing(original_response, model, optional_params: Optional[dict]):
|
||||
try:
|
||||
if original_response is None:
|
||||
pass
|
||||
|
@ -594,11 +595,41 @@ def client(original_function):
|
|||
pass
|
||||
else:
|
||||
if isinstance(original_response, ModelResponse):
|
||||
model_response = original_response.choices[
|
||||
model_response: Optional[str] = original_response.choices[
|
||||
0
|
||||
].message.content
|
||||
].message.content # type: ignore
|
||||
if model_response is not None:
|
||||
### POST-CALL RULES ###
|
||||
rules_obj.post_call_rules(input=model_response, model=model)
|
||||
rules_obj.post_call_rules(
|
||||
input=model_response, model=model
|
||||
)
|
||||
### JSON SCHEMA VALIDATION ###
|
||||
if (
|
||||
optional_params is not None
|
||||
and "response_format" in optional_params
|
||||
and isinstance(
|
||||
optional_params["response_format"], dict
|
||||
)
|
||||
and "type" in optional_params["response_format"]
|
||||
and optional_params["response_format"]["type"]
|
||||
== "json_object"
|
||||
and "response_schema"
|
||||
in optional_params["response_format"]
|
||||
and isinstance(
|
||||
optional_params["response_format"][
|
||||
"response_schema"
|
||||
],
|
||||
dict,
|
||||
)
|
||||
):
|
||||
# schema given, json response expected
|
||||
litellm.litellm_core_utils.json_validation_rule.validate_schema(
|
||||
schema=optional_params["response_format"][
|
||||
"response_schema"
|
||||
],
|
||||
response=model_response,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
@ -867,7 +898,11 @@ def client(original_function):
|
|||
return result
|
||||
|
||||
### POST-CALL RULES ###
|
||||
post_call_processing(original_response=result, model=model or None)
|
||||
post_call_processing(
|
||||
original_response=result,
|
||||
model=model or None,
|
||||
optional_params=kwargs,
|
||||
)
|
||||
|
||||
# [OPTIONAL] ADD TO CACHE
|
||||
if (
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue