diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py
index ba8cb167c8..597b03ac29 100644
--- a/litellm/litellm_core_utils/streaming_handler.py
+++ b/litellm/litellm_core_utils/streaming_handler.py
@@ -471,6 +471,7 @@ class CustomStreamWrapper:
             finish_reason = None
             logprobs = None
             usage = None
+
             if str_line and str_line.choices and len(str_line.choices) > 0:
                 if (
                     str_line.choices[0].delta is not None
@@ -750,6 +751,7 @@ class CustomStreamWrapper:
                     "function_call" in completion_obj
                    and completion_obj["function_call"] is not None
                )
+                or (model_response.choices[0].delta.provider_specific_fields is not None)
                or (
                    "provider_specific_fields" in response_obj
                    and response_obj["provider_specific_fields"] is not None
diff --git a/litellm/main.py b/litellm/main.py
index 82fa65eefa..0056f4751d 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -383,6 +383,10 @@ async def acompletion(
         - If `stream` is True, the function returns an async generator that yields completion lines.
     """
     fallbacks = kwargs.get("fallbacks", None)
+    mock_timeout = kwargs.get("mock_timeout", None)
+
+    if mock_timeout is True:
+        await _handle_mock_timeout_async(mock_timeout, timeout, model)
 
     loop = asyncio.get_event_loop()
     custom_llm_provider = kwargs.get("custom_llm_provider", None)
@@ -565,12 +569,7 @@ def _handle_mock_timeout(
     model: str,
 ):
     if mock_timeout is True and timeout is not None:
-        if isinstance(timeout, float):
-            time.sleep(timeout)
-        elif isinstance(timeout, str):
-            time.sleep(float(timeout))
-        elif isinstance(timeout, httpx.Timeout) and timeout.connect is not None:
-            time.sleep(timeout.connect)
+        _sleep_for_timeout(timeout)
         raise litellm.Timeout(
             message="This is a mock timeout error",
             llm_provider="openai",
@@ -578,6 +577,38 @@
         )
 
 
+async def _handle_mock_timeout_async(
+    mock_timeout: Optional[bool],
+    timeout: Optional[Union[float, str, httpx.Timeout]],
+    model: str,
+):
+    if mock_timeout is True and timeout is not None:
+        await _sleep_for_timeout_async(timeout)
+        raise litellm.Timeout(
+            message="This is a mock timeout error",
+            llm_provider="openai",
+            model=model,
+        )
+
+
+def _sleep_for_timeout(timeout: Union[float, str, httpx.Timeout]):
+    if isinstance(timeout, float):
+        time.sleep(timeout)
+    elif isinstance(timeout, str):
+        time.sleep(float(timeout))
+    elif isinstance(timeout, httpx.Timeout) and timeout.connect is not None:
+        time.sleep(timeout.connect)
+
+
+async def _sleep_for_timeout_async(timeout: Union[float, str, httpx.Timeout]):
+    if isinstance(timeout, float):
+        await asyncio.sleep(timeout)
+    elif isinstance(timeout, str):
+        await asyncio.sleep(float(timeout))
+    elif isinstance(timeout, httpx.Timeout) and timeout.connect is not None:
+        await asyncio.sleep(timeout.connect)
+
+
 def mock_completion(
     model: str,
     messages: List,
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 527b395168..983525f495 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -15,4 +15,4 @@ model_list:
       model: anthropic.claude-3-sonnet-20240229-v1:0
 
 litellm_settings:
-  callbacks: ["langsmith"]
+  callbacks: ["langsmith"]
\ No newline at end of file
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 09f0f864d2..b2b198a4ff 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -487,6 +487,8 @@ class Message(OpenAIObject):
 
         if provider_specific_fields:  # set if provider_specific_fields is not empty
             self.provider_specific_fields = provider_specific_fields
+            for k, v in provider_specific_fields.items():
+                setattr(self, k, v)
 
     def get(self, key, default=None):
         # Custom .get() method to access attributes with a default value if the attribute doesn't exist
@@ -522,18 +524,18 @@ class Delta(OpenAIObject):
         audio: Optional[ChatCompletionAudioResponse] = None,
         **params,
     ):
+        super(Delta, self).__init__(**params)
         provider_specific_fields: Dict[str, Any] = {}
         if "reasoning_content" in params:
             provider_specific_fields["reasoning_content"] = params["reasoning_content"]
-            del params["reasoning_content"]
-        super(Delta, self).__init__(**params)
+            setattr(self, "reasoning_content", params["reasoning_content"])
         self.content = content
         self.role = role
-        self.provider_specific_fields = provider_specific_fields
 
         # Set default values and correct types
         self.function_call: Optional[Union[FunctionCall, Any]] = None
         self.tool_calls: Optional[List[Union[ChatCompletionDeltaToolCall, Any]]] = None
         self.audio: Optional[ChatCompletionAudioResponse] = None
+
         if provider_specific_fields:  # set if provider_specific_fields is not empty
             self.provider_specific_fields = provider_specific_fields
@@ -801,6 +803,7 @@ class StreamingChatCompletionChunk(OpenAIChatCompletionChunk):
             new_choice = StreamingChoices(**choice).model_dump()
             new_choices.append(new_choice)
         kwargs["choices"] = new_choices
+
         super().__init__(**kwargs)
 
 
diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml
index bb178ed97c..e9188482f7 100644
--- a/proxy_server_config.yaml
+++ b/proxy_server_config.yaml
@@ -97,6 +97,14 @@ model_list:
       rpm: 1000
     model_info:
       health_check_timeout: 1
+  - model_name: good-model
+    litellm_params:
+      model: openai/bad-model
+      api_key: os.environ/OPENAI_API_KEY
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      rpm: 1000
+    model_info:
+      health_check_timeout: 1
   - model_name: "*"
     litellm_params:
       model: openai/*
diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py
index ef90d56f70..1a3f90a0bc 100644
--- a/tests/local_testing/test_completion.py
+++ b/tests/local_testing/test_completion.py
@@ -4546,10 +4546,7 @@ def test_deepseek_reasoning_content_completion():
         messages=[{"role": "user", "content": "Tell me a joke."}],
     )
 
-    assert (
-        resp.choices[0].message.provider_specific_fields["reasoning_content"]
-        is not None
-    )
+    assert resp.choices[0].message.reasoning_content is not None
 
 
 @pytest.mark.parametrize(
diff --git a/tests/local_testing/test_streaming.py b/tests/local_testing/test_streaming.py
index 06e2b9156d..527742b325 100644
--- a/tests/local_testing/test_streaming.py
+++ b/tests/local_testing/test_streaming.py
@@ -4066,7 +4066,7 @@ def test_mock_response_iterator_tool_use():
 
 
 def test_deepseek_reasoning_content_completion():
-    litellm.set_verbose = True
+    # litellm.set_verbose = True
     resp = litellm.completion(
         model="deepseek/deepseek-reasoner",
         messages=[{"role": "user", "content": "Tell me a joke."}],
@@ -4076,8 +4076,7 @@ def test_deepseek_reasoning_content_completion():
     reasoning_content_exists = False
     for chunk in resp:
         print(f"chunk: {chunk}")
-        if chunk.choices[0].delta.content is not None:
-            if "reasoning_content" in chunk.choices[0].delta.provider_specific_fields:
-                reasoning_content_exists = True
-                break
+        if chunk.choices[0].delta.reasoning_content is not None:
+            reasoning_content_exists = True
+            break
     assert reasoning_content_exists
diff --git a/tests/test_fallbacks.py b/tests/test_fallbacks.py
index 2f39d5e985..91c90448b3 100644
--- a/tests/test_fallbacks.py
+++ b/tests/test_fallbacks.py
@@ -228,3 +228,52 @@ async def test_chat_completion_client_fallbacks_with_custom_message(has_access):
     except Exception as e:
         if has_access:
             pytest.fail("Expected this to work: {}".format(str(e)))
+
+
+import asyncio
+from openai import AsyncOpenAI
+from typing import List
+import time
+
+
+async def make_request(client: AsyncOpenAI, model: str) -> bool:
+    try:
+        await client.chat.completions.create(
+            model=model,
+            messages=[{"role": "user", "content": "Who was Alexander?"}],
+        )
+        return True
+    except Exception as e:
+        print(f"Error with {model}: {str(e)}")
+        return False
+
+
+async def run_good_model_test(client: AsyncOpenAI, num_requests: int) -> bool:
+    tasks = [make_request(client, "good-model") for _ in range(num_requests)]
+    good_results = await asyncio.gather(*tasks)
+    return all(good_results)
+
+
+@pytest.mark.asyncio
+async def test_chat_completion_bad_and_good_model():
+    """
+    Prod test - ensure even if bad model is down, good model is still working.
+    """
+    client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
+    num_requests = 100
+    num_iterations = 3
+
+    for iteration in range(num_iterations):
+        print(f"\nIteration {iteration + 1}/{num_iterations}")
+        start_time = time.time()
+
+        # Fire and forget bad model requests
+        for _ in range(num_requests):
+            asyncio.create_task(make_request(client, "bad-model"))
+
+        # Wait only for good model requests
+        success = await run_good_model_test(client, num_requests)
+        print(
+            f"Iteration {iteration + 1}: {'✓' if success else '✗'} ({time.time() - start_time:.2f}s)"
+        )
+        assert success, "Not all good model requests succeeded"
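
Not part of the patch: a minimal usage sketch of the two behaviors this diff touches, assuming a litellm build with these changes installed and, for the second call, a DeepSeek API key configured in the environment; model names and values below are illustrative only.

import asyncio

import litellm


async def demo():
    # New async mock-timeout path in acompletion(): with mock_timeout=True the call
    # sleeps for the given timeout and then raises litellm.Timeout without contacting a provider.
    try:
        await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hi"}],
            mock_timeout=True,
            timeout=0.1,
        )
    except litellm.Timeout as e:
        print(f"mock timeout raised as expected: {e}")

    # reasoning_content is now surfaced directly on the message object,
    # mirroring the updated assertion in test_completion.py.
    resp = litellm.completion(
        model="deepseek/deepseek-reasoner",
        messages=[{"role": "user", "content": "Tell me a joke."}],
    )
    print(resp.choices[0].message.reasoning_content)


asyncio.run(demo())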