mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
Litellm dev 01 25 2025 p2 (#8003)
* fix(base_utils.py): support nested json schema passed in for anthropic calls
* refactor(base_utils.py): refactor ref parsing to prevent infinite loop
* test(test_openai_endpoints.py): refactor anthropic test to use bedrock
* fix(langfuse_prompt_management.py): add unit test for sync langfuse calls

Resolves https://github.com/BerriAI/litellm/issues/7938#issuecomment-2613293757
This commit is contained in:
parent a7b3c664d1
commit 08b124aeb6

12 changed files with 214 additions and 5 deletions
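For context on what the schema fix addresses, here is a minimal sketch (assuming pydantic v2; the EventsList/CalendarEvent models mirror the tests added in this commit) of why nested models need special handling:

# Sketch only — not part of the diff. Nested pydantic models are emitted under "$defs"
# and referenced via "#/$defs/<ModelName>"; the changes below rewrite those refs with
# ref_template="/$defs/{model}" for Anthropic calls (see issue #7755 cited in the diff).
from pydantic import BaseModel


class CalendarEvent(BaseModel):
    name: str
    date: str
    participants: list[str]


class EventsList(BaseModel):
    events: list[CalendarEvent]


schema = EventsList.model_json_schema()
print(schema["properties"]["events"]["items"]["$ref"])  # -> "#/$defs/CalendarEvent"
# The new ref_template plumbing turns this into "/$defs/CalendarEvent".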
@@ -11,6 +11,7 @@ from typing_extensions import TypeAlias

from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.prompt_management_base import PromptManagementClient
from litellm.litellm_core_utils.asyncify import run_async_function
from litellm.types.llms.openai import AllMessageValues, ChatCompletionSystemMessage
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
@@ -231,6 +232,11 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogger
            completed_messages=None,
        )

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        return run_async_function(
            self.async_log_success_event, kwargs, response_obj, start_time, end_time
        )

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        standard_callback_dynamic_params = kwargs.get(
            "standard_callback_dynamic_params"
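The new log_success_event gives this integration a synchronous logging entry point by delegating to the existing async_log_success_event through run_async_function. The diff does not include that helper; a rough sketch of the general sync-over-async pattern it stands for (an illustration only, not litellm's actual implementation in litellm.litellm_core_utils.asyncify) is:

# Illustrative only — litellm's real run_async_function may differ.
import asyncio
from typing import Any, Awaitable, Callable


def run_async_function(async_fn: Callable[..., Awaitable[Any]], *args: Any, **kwargs: Any) -> Any:
    # Drive the coroutine to completion from synchronous code.
    # (A production helper would also handle an already-running event loop.)
    return asyncio.run(async_fn(*args, **kwargs))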
@@ -98,6 +98,7 @@ class AnthropicConfig(BaseConfig):
    def get_json_schema_from_pydantic_object(
        self, response_format: Union[Any, Dict, None]
    ) -> Optional[dict]:

        return type_to_response_format_param(
            response_format, ref_template="/$defs/{model}"
        )  # Relevant issue: https://github.com/BerriAI/litellm/issues/7755
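A small usage sketch of the method above (the import path and the EventsList model are assumptions carried over from the earlier example, not shown in this diff):

# Sketch only: turn a pydantic class into a response_format dict whose refs use the
# "/$defs/{model}" template. The module path is an assumption about the repo layout.
from litellm.llms.anthropic.chat.transformation import AnthropicConfig

json_schema = AnthropicConfig().get_json_schema_from_pydantic_object(EventsList)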
@@ -1,3 +1,4 @@
import copy
from abc import ABC, abstractmethod
from typing import List, Optional, Type, Union

@@ -31,6 +32,47 @@ class BaseLLMModelInfo(ABC):
        pass


def _dict_to_response_format_helper(
    response_format: dict, ref_template: Optional[str] = None
) -> dict:
    if ref_template is not None and response_format.get("type") == "json_schema":
        # Deep copy to avoid modifying original
        modified_format = copy.deepcopy(response_format)
        schema = modified_format["json_schema"]["schema"]

        # Update all $ref values in the schema
        def update_refs(schema):
            stack = [(schema, [])]
            visited = set()

            while stack:
                obj, path = stack.pop()
                obj_id = id(obj)

                if obj_id in visited:
                    continue
                visited.add(obj_id)

                if isinstance(obj, dict):
                    if "$ref" in obj:
                        ref_path = obj["$ref"]
                        model_name = ref_path.split("/")[-1]
                        obj["$ref"] = ref_template.format(model=model_name)

                    for k, v in obj.items():
                        if isinstance(v, (dict, list)):
                            stack.append((v, path + [k]))

                elif isinstance(obj, list):
                    for i, item in enumerate(obj):
                        if isinstance(item, (dict, list)):
                            stack.append((item, path + [i]))

        update_refs(schema)
        return modified_format
    return response_format


def type_to_response_format_param(
    response_format: Optional[Union[Type[BaseModel], dict]],
    ref_template: Optional[str] = None,

@@ -44,7 +86,7 @@ def type_to_response_format_param(
        return None

    if isinstance(response_format, dict):
-       return response_format
+       return _dict_to_response_format_helper(response_format, ref_template)

    # type checkers don't narrow the negation of a `TypeGuard` as it isn't
    # a safe default behaviour but we know that at this point the `response_format`
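The iterative stack plus the visited set in update_refs is what the commit message means by "prevent infinite loop": each (sub)object is traversed at most once, so shared or self-referential structures terminate. A small sketch against a recursive model (the Category model is hypothetical; pydantic v2 assumed):

# Hypothetical recursive model whose JSON schema contains a $ref back to itself.
from typing import List

from pydantic import BaseModel

from litellm.llms.base_llm.base_utils import _dict_to_response_format_helper


class Category(BaseModel):
    name: str
    subcategories: List["Category"] = []


response_format = {
    "type": "json_schema",
    "json_schema": {"name": "Category", "schema": Category.model_json_schema()},
}
out = _dict_to_response_format_helper(response_format, ref_template="/$defs/{model}")
# The self-referential ref is rewritten once and the traversal terminates:
print(out["json_schema"]["schema"]["$defs"]["Category"]["properties"]["subcategories"]["items"]["$ref"])
# expected: "/$defs/Category"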
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -10,6 +10,9 @@ model_list:
      model: gpt-3.5-turbo
      timeout: 2
      num_retries: 0
  - model_name: anthropic-claude
    litellm_params:
      model: anthropic.claude-3-sonnet-20240229-v1:0

litellm_settings:
  callbacks: ["langsmith"]
@@ -108,6 +108,9 @@ model_list:
    litellm_params:
      model: "anthropic/*"
      api_key: os.environ/ANTHROPIC_API_KEY
  - model_name: "bedrock/*"
    litellm_params:
      model: "bedrock/*"
  - model_name: "groq/*"
    litellm_params:
      model: "groq/*"
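The new wildcard entries route any model name with the matching prefix; the "bedrock/*" entry is what the refactored proxy test at the bottom of this diff relies on. A minimal client-side call that would match it (proxy URL, key, and model string mirror that test; this snippet is illustrative, not part of the diff):

# Minimal sketch: any "bedrock/..." model name matches the wildcard route above.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
    res = await client.chat.completions.create(
        model="bedrock/us.anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": "hello"}],
    )
    print(res.choices[0].message.content)


asyncio.run(main())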
@@ -312,6 +312,56 @@ class BaseLLMChatTest(ABC):
        except litellm.InternalServerError:
            pytest.skip("Model is overloaded")

    @pytest.mark.flaky(retries=6, delay=1)
    def test_json_response_nested_json_schema(self):
        """
        PROD Test: ensure nested json schema sent to proxy works as expected.
        """
        from pydantic import BaseModel
        from litellm.utils import supports_response_schema
        from litellm.llms.base_llm.base_utils import type_to_response_format_param

        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
        litellm.model_cost = litellm.get_model_cost_map(url="")

        class CalendarEvent(BaseModel):
            name: str
            date: str
            participants: list[str]

        class EventsList(BaseModel):
            events: list[CalendarEvent]

        response_format = type_to_response_format_param(EventsList)

        messages = [
            {"role": "user", "content": "List 5 important events in the XIX century"}
        ]

        base_completion_call_args = self.get_base_completion_call_args()
        if not supports_response_schema(base_completion_call_args["model"], None):
            pytest.skip(
                f"Model={base_completion_call_args['model']} does not support response schema"
            )

        try:
            res = self.completion_function(
                **base_completion_call_args,
                messages=messages,
                response_format=response_format,
                timeout=60,
            )
            assert res is not None

            print(res.choices[0].message)

            assert res.choices[0].message.content is not None
            assert res.choices[0].message.tool_calls is None
        except litellm.Timeout:
            pytest.skip("Model took too long to respond")
        except litellm.InternalServerError:
            pytest.skip("Model is overloaded")

    @pytest.mark.flaky(retries=6, delay=1)
    def test_json_response_format_stream(self):
        """
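BaseLLMChatTest is shared across providers, so the new nested-schema test runs for any subclass that supplies base call args. A hypothetical wiring (class name and model string are illustrative; the real provider test classes live elsewhere in the repo and also provide completion_function):

# Illustrative subclass only — not part of this diff.
class TestAnthropicNestedSchema(BaseLLMChatTest):
    def get_base_completion_call_args(self) -> dict:
        return {"model": "anthropic/claude-3-5-sonnet-20240620"}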
@@ -1567,3 +1567,49 @@ async def test_wrapper_kwargs_passthrough():
        litellm_logging_obj.model_call_details["litellm_params"]["base_model"]
        == "gpt-4o-mini"
    )


def test_dict_to_response_format_helper():
    from litellm.llms.base_llm.base_utils import _dict_to_response_format_helper

    args = {
        "response_format": {
            "type": "json_schema",
            "json_schema": {
                "schema": {
                    "$defs": {
                        "CalendarEvent": {
                            "properties": {
                                "name": {"title": "Name", "type": "string"},
                                "date": {"title": "Date", "type": "string"},
                                "participants": {
                                    "items": {"type": "string"},
                                    "title": "Participants",
                                    "type": "array",
                                },
                            },
                            "required": ["name", "date", "participants"],
                            "title": "CalendarEvent",
                            "type": "object",
                            "additionalProperties": False,
                        }
                    },
                    "properties": {
                        "events": {
                            "items": {"$ref": "#/$defs/CalendarEvent"},
                            "title": "Events",
                            "type": "array",
                        }
                    },
                    "required": ["events"],
                    "title": "EventsList",
                    "type": "object",
                    "additionalProperties": False,
                },
                "name": "EventsList",
                "strict": True,
            },
        },
        "ref_template": "/$defs/{model}",
    }
    _dict_to_response_format_helper(**args)
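The unit test above only exercises the helper; assertions one could add (the expected values follow directly from the update_refs logic shown earlier, and are not part of the diff) would be:

# Possible follow-up assertions, not in the diff:
result = _dict_to_response_format_helper(**args)
assert (
    result["json_schema"]["schema"]["properties"]["events"]["items"]["$ref"]
    == "/$defs/CalendarEvent"
)
# The deepcopy inside the helper leaves the input untouched:
assert (
    args["response_format"]["json_schema"]["schema"]["properties"]["events"]["items"]["$ref"]
    == "#/$defs/CalendarEvent"
)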
@@ -14,7 +14,7 @@ from litellm.integrations.langfuse.langfuse import (
from litellm.integrations.langfuse.langfuse_handler import LangFuseHandler
from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache
from unittest.mock import Mock, patch

from respx import MockRouter
from litellm.types.utils import (
    StandardLoggingPayload,
    StandardLoggingModelInformation,
@@ -292,3 +292,31 @@ def test_get_langfuse_tags():
    mock_payload["request_tags"] = []
    result = global_langfuse_logger._get_langfuse_tags(mock_payload)
    assert result == []


def test_langfuse_e2e_sync(monkeypatch):
    from litellm import completion
    import litellm
    import respx
    import httpx
    import time

    litellm._turn_on_debug()
    monkeypatch.setattr(litellm, "success_callback", ["langfuse"])

    with respx.mock:
        # Mock Langfuse
        # Mock any Langfuse endpoint
        langfuse_mock = respx.post(
            "https://*.cloud.langfuse.com/api/public/ingestion"
        ).mock(return_value=httpx.Response(200))
        completion(
            model="openai/my-fake-endpoint",
            messages=[{"role": "user", "content": "hello from litellm"}],
            stream=False,
            mock_response="Hello from litellm 2",
        )

        time.sleep(3)

        assert langfuse_mock.called
@@ -378,6 +378,39 @@ async def test_chat_completion_streaming():
    print(f"response_str: {response_str}")


@pytest.mark.asyncio
async def test_chat_completion_anthropic_structured_output():
    """
    Ensure nested pydantic output is returned correctly
    """
    from pydantic import BaseModel

    class CalendarEvent(BaseModel):
        name: str
        date: str
        participants: list[str]

    class EventsList(BaseModel):
        events: list[CalendarEvent]

    messages = [
        {"role": "user", "content": "List 5 important events in the XIX century"}
    ]

    client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

    res = await client.beta.chat.completions.parse(
        model="bedrock/us.anthropic.claude-3-sonnet-20240229-v1:0",
        messages=messages,
        response_format=EventsList,
        timeout=60,
    )
    message = res.choices[0].message

    if message.parsed:
        print(message.parsed.events)
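client.beta.chat.completions.parse is the OpenAI SDK's structured-output helper, so message.parsed is an EventsList instance when parsing succeeds. Stricter checks one might add (assumptions about the desired behavior, not part of the diff):

# Possible stronger assertions, not in the diff:
assert message.parsed is not None
assert isinstance(message.parsed, EventsList)
assert all(isinstance(e, CalendarEvent) for e in message.parsed.events)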

@pytest.mark.asyncio
async def test_chat_completion_old_key():
    """