fixes to test

Ishaan Jaff 2025-04-17 22:08:31 -07:00
parent 4e2b04a1e0
commit e8b9b4f68b
4 changed files with 78 additions and 24 deletions

Changed file 1 of 4: Responses API to Chat Completion transformation (class LiteLLMCompletionResponsesConfig)

@@ -9,10 +9,12 @@ from litellm.types.llms.openai import (
     ChatCompletionSystemMessage,
     ChatCompletionUserMessage,
     GenericChatCompletionMessage,
+    Reasoning,
     ResponseAPIUsage,
     ResponseInputParam,
     ResponsesAPIOptionalRequestParams,
     ResponsesAPIResponse,
+    ResponseTextConfig,
 )
 from litellm.types.responses.main import GenericResponseOutputItem, OutputText
 from litellm.types.utils import Choices, Message, ModelResponse, Usage
@@ -31,7 +33,7 @@ class LiteLLMCompletionResponsesConfig:
         """
         Transform a Responses API request into a Chat Completion request
         """
-        return {
+        litellm_completion_request: dict = {
             "messages": LiteLLMCompletionResponsesConfig.transform_responses_api_input_to_messages(
                 input=input,
                 responses_api_request=responses_api_request,
@@ -45,10 +47,17 @@ class LiteLLMCompletionResponsesConfig:
             "parallel_tool_calls": responses_api_request.get("parallel_tool_calls"),
             "max_tokens": responses_api_request.get("max_output_tokens"),
             "stream": kwargs.get("stream", False),
-            "metadata": kwargs.get("metadata", {}),
-            "service_tier": kwargs.get("service_tier", ""),
+            "metadata": kwargs.get("metadata"),
+            "service_tier": kwargs.get("service_tier"),
         }
+        # only pass non-None values
+        litellm_completion_request = {
+            k: v for k, v in litellm_completion_request.items() if v is not None
+        }
+
+        return litellm_completion_request
+
 
     @staticmethod
     def transform_responses_api_input_to_messages(
         input: Union[str, ResponseInputParam],
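
The request builder above stops defaulting `metadata` to `{}` and `service_tier` to `""`; anything the caller did not set stays `None` and is stripped before the dict is handed to the chat-completions call. A minimal sketch of that filtering step, using made-up request values (the helper name `build_request` is illustrative, not part of litellm):

def build_request(**kwargs: object) -> dict:
    # Mirrors the pattern above: collect possibly-None values, then drop the None
    # entries so unset options are never forwarded to the chat-completions call.
    litellm_completion_request: dict = {
        "stream": kwargs.get("stream", False),
        "metadata": kwargs.get("metadata"),
        "service_tier": kwargs.get("service_tier"),
    }
    return {k: v for k, v in litellm_completion_request.items() if v is not None}


print(build_request(metadata={"trace_id": "abc"}))
# {'stream': False, 'metadata': {'trace_id': 'abc'}}  -- service_tier was None, so it is dropped
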
@@ -148,7 +157,7 @@
                 chat_completion_response, "incomplete_details", None
             ),
             instructions=getattr(chat_completion_response, "instructions", None),
-            metadata=getattr(chat_completion_response, "metadata", None),
+            metadata=getattr(chat_completion_response, "metadata", {}),
             output=LiteLLMCompletionResponsesConfig._transform_chat_completion_choices_to_responses_output(
                 chat_completion_response=chat_completion_response,
                 choices=getattr(chat_completion_response, "choices", []),
@@ -156,7 +165,7 @@
             parallel_tool_calls=getattr(
                 chat_completion_response, "parallel_tool_calls", False
             ),
-            temperature=getattr(chat_completion_response, "temperature", None),
+            temperature=getattr(chat_completion_response, "temperature", 0),
             tool_choice=getattr(chat_completion_response, "tool_choice", "auto"),
             tools=getattr(chat_completion_response, "tools", []),
             top_p=getattr(chat_completion_response, "top_p", None),
@@ -166,11 +175,13 @@
             previous_response_id=getattr(
                 chat_completion_response, "previous_response_id", None
             ),
-            reasoning=getattr(chat_completion_response, "reasoning", None),
-            status=getattr(chat_completion_response, "status", None),
-            text=getattr(chat_completion_response, "text", None),
+            reasoning=Reasoning(),
+            status=getattr(chat_completion_response, "status", "completed"),
+            text=ResponseTextConfig(),
             truncation=getattr(chat_completion_response, "truncation", None),
-            usage=getattr(chat_completion_response, "usage", None),
+            usage=LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(
+                chat_completion_response=chat_completion_response
+            ),
             user=getattr(chat_completion_response, "user", None),
         )
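
In this hunk, `reasoning` and `text` are no longer read off the chat-completion response (which has no such attributes) but are filled with default objects, and `status` falls back to "completed". A small sanity sketch, assuming only what the diff shows, namely that both types can be constructed with no arguments:

from litellm.types.llms.openai import Reasoning, ResponseTextConfig

# Chat-completion responses carry no reasoning or text-format config,
# so the transformation falls back to default-constructed objects.
default_reasoning = Reasoning()
default_text = ResponseTextConfig()
print(default_reasoning, default_text)
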
@@ -206,8 +217,15 @@
     @staticmethod
     def _transform_chat_completion_usage_to_responses_usage(
-        usage: Usage,
+        chat_completion_response: ModelResponse,
     ) -> ResponseAPIUsage:
+        usage: Optional[Usage] = getattr(chat_completion_response, "usage", None)
+        if usage is None:
+            return ResponseAPIUsage(
+                input_tokens=0,
+                output_tokens=0,
+                total_tokens=0,
+            )
         return ResponseAPIUsage(
             input_tokens=usage.prompt_tokens,
             output_tokens=usage.completion_tokens,
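
The usage transform now takes the whole `ModelResponse` and tolerates a missing `usage` block. A standalone sketch of the same mapping; the hunk above is cut off before the `total_tokens` line, so that mapping is assumed here:

from typing import Optional

from litellm.types.llms.openai import ResponseAPIUsage
from litellm.types.utils import ModelResponse, Usage


def usage_from_chat_completion(chat_completion_response: ModelResponse) -> ResponseAPIUsage:
    # Responses without a usage block map to an all-zero ResponseAPIUsage instead of raising.
    usage: Optional[Usage] = getattr(chat_completion_response, "usage", None)
    if usage is None:
        return ResponseAPIUsage(input_tokens=0, output_tokens=0, total_tokens=0)
    return ResponseAPIUsage(
        input_tokens=usage.prompt_tokens,
        output_tokens=usage.completion_tokens,
        total_tokens=usage.total_tokens,  # assumed continuation of the truncated hunk
    )
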

Changed file 2 of 4: the responses() entry point

@@ -10,6 +10,9 @@ from litellm.constants import request_timeout
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
 from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
+from litellm.responses.litellm_completion_transformation.handler import (
+    LiteLLMCompletionTransformationHandler,
+)
 from litellm.responses.utils import ResponsesAPIRequestUtils
 from litellm.types.llms.openai import (
     Reasoning,
@@ -29,6 +32,7 @@ from .streaming_iterator import BaseResponsesAPIStreamingIterator
 ####### ENVIRONMENT VARIABLES ###################
 # Initialize any necessary instances or variables here
 base_llm_http_handler = BaseLLMHTTPHandler()
+litellm_completion_transformation_handler = LiteLLMCompletionTransformationHandler()
 #################################################
@@ -178,19 +182,12 @@
     )
     # get provider config
-    responses_api_provider_config: Optional[
-        BaseResponsesAPIConfig
-    ] = ProviderConfigManager.get_provider_responses_api_config(
-        model=model,
-        provider=litellm.LlmProviders(custom_llm_provider),
-    )
-
-    if responses_api_provider_config is None:
-        raise litellm.BadRequestError(
-            model=model,
-            llm_provider=custom_llm_provider,
-            message=f"Responses API not available for custom_llm_provider={custom_llm_provider}, model: {model}",
-        )
+    responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
+        ProviderConfigManager.get_provider_responses_api_config(
+            model=model,
+            provider=litellm.LlmProviders(custom_llm_provider),
+        )
+    )
     local_vars.update(kwargs)
     # Get ResponsesAPIOptionalRequestParams with only valid parameters
@@ -200,6 +197,16 @@
         )
     )
+    if responses_api_provider_config is None:
+        return litellm_completion_transformation_handler.response_api_handler(
+            model=model,
+            input=input,
+            responses_api_request=ResponsesAPIOptionalRequestParams(),
+            custom_llm_provider=custom_llm_provider,
+            _is_async=_is_async,
+            **kwargs,
+        )
+
     # Get optional parameters for the responses API
     responses_api_request_params: Dict = (
         ResponsesAPIRequestUtils.get_optional_params_responses_api(
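
With the `BadRequestError` gone, a provider that has no native Responses API config now falls through to the new handler instead of failing. A hedged usage sketch of that path; it assumes a valid ANTHROPIC_API_KEY in the environment, and the returned object's fields follow the transformation in file 1:

import litellm

# Anthropic has no provider-specific Responses API config, so this call is expected
# to be bridged through LiteLLMCompletionTransformationHandler rather than rejected.
response = litellm.responses(
    model="anthropic/claude-3-5-sonnet-latest",
    input="Say hello in one short sentence.",
)
print(response.output)
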

Changed file 3 of 4: the validate_responses_api_response test helper

@@ -68,16 +68,16 @@ def validate_responses_api_response(response, final_chunk: bool = False):
         "metadata": dict,
         "model": str,
         "object": str,
-        "temperature": (int, float),
+        "temperature": (int, float, type(None)),
         "tool_choice": (dict, str),
         "tools": list,
-        "top_p": (int, float),
+        "top_p": (int, float, type(None)),
         "max_output_tokens": (int, type(None)),
         "previous_response_id": (str, type(None)),
         "reasoning": dict,
         "status": str,
         "text": ResponseTextConfig,
-        "truncation": str,
+        "truncation": (str, type(None)),
         "usage": ResponseAPIUsage,
         "user": (str, type(None)),
     }
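
The validator's expectations are loosened so that `temperature`, `top_p`, and `truncation` may also be `None`. A hypothetical illustration of how a field-to-allowed-types map like this drives the check; the real loop inside `validate_responses_api_response` is not shown in this diff:

expected_types = {
    "temperature": (int, float, type(None)),
    "top_p": (int, float, type(None)),
    "truncation": (str, type(None)),
}

# A response produced via the chat-completions bridge may leave these unset.
response_fields = {"temperature": None, "top_p": None, "truncation": None}

for field, allowed in expected_types.items():
    value = response_fields[field]
    assert isinstance(value, allowed), f"{field}={value!r} is not one of {allowed}"
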

Changed file 4 of 4: new Anthropic Responses API test (new file)

@@ -0,0 +1,29 @@
+import os
+import sys
+import pytest
+import asyncio
+from typing import Optional
+from unittest.mock import patch, AsyncMock
+
+sys.path.insert(0, os.path.abspath("../.."))
+import litellm
+from litellm.integrations.custom_logger import CustomLogger
+import json
+from litellm.types.utils import StandardLoggingPayload
+from litellm.types.llms.openai import (
+    ResponseCompletedEvent,
+    ResponsesAPIResponse,
+    ResponseTextConfig,
+    ResponseAPIUsage,
+    IncompleteDetails,
+)
+import litellm
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from base_responses_api import BaseResponsesAPITest
+
+class TestAnthropicResponsesAPITest(BaseResponsesAPITest):
+    def get_base_completion_call_args(self):
+        # litellm._turn_on_debug()
+        return {
+            "model": "anthropic/claude-3-5-sonnet-latest",
+        }
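
The new test file only supplies provider-specific call arguments; the shared test cases presumably live in `BaseResponsesAPITest` (not shown in this diff), which is assumed to call `get_base_completion_call_args()` for each scenario. A hypothetical sibling class for another provider would follow the same pattern, assuming it sits in the same test directory as `base_responses_api.py`:

from base_responses_api import BaseResponsesAPITest


class TestOpenAIResponsesAPITest(BaseResponsesAPITest):
    # Hypothetical example of reusing the same base class for a different provider;
    # the model name here is illustrative only.
    def get_base_completion_call_args(self):
        return {
            "model": "openai/gpt-4o-mini",
        }
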