diff --git a/litellm/__init__.py b/litellm/__init__.py
index 9fa801318..1995a1177 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -5,6 +5,7 @@ warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*
 ### INIT VARIABLES ###
 import threading, requests, os
 from typing import Callable, List, Optional, Dict, Union, Any, Literal
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from litellm.caching import Cache
 from litellm._logging import (
     set_verbose,
@@ -212,6 +213,7 @@ add_function_to_prompt: bool = (
 )
 client_session: Optional[httpx.Client] = None
 aclient_session: Optional[httpx.AsyncClient] = None
+module_level_aclient = AsyncHTTPHandler()
 model_fallbacks: Optional[List] = None  # Deprecated for 'litellm.fallbacks'
 model_cost_map_url: str = (
     "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index 337055dc2..4eef9a258 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -1,7 +1,7 @@
 # What is this?
 ## Initial implementation of calling bedrock via httpx client (allows for async calls).
 ## V1 - covers cohere + anthropic claude-3 support
-
+from functools import partial
 import os, types
 import json
 from enum import Enum
@@ -145,6 +145,37 @@ class AmazonCohereChatConfig:
         return optional_params
 
 
+async def make_call(
+    client: Optional[AsyncHTTPHandler],
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    if client is None:
+        client = AsyncHTTPHandler()  # Create a new client if none provided
+
+    response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise BedrockError(status_code=response.status_code, message=response.text)
+
+    decoder = AWSEventStreamDecoder(model=model)
+    completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
+
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
 class BedrockLLM(BaseLLM):
     """
     Example call
@@ -968,39 +999,24 @@ class BedrockLLM(BaseLLM):
         headers={},
         client: Optional[AsyncHTTPHandler] = None,
     ) -> CustomStreamWrapper:
-        if client is None:
-            _params = {}
-            if timeout is not None:
-                if isinstance(timeout, float) or isinstance(timeout, int):
-                    timeout = httpx.Timeout(timeout)
-                _params["timeout"] = timeout
-            self.client = AsyncHTTPHandler(**_params)  # type: ignore
-        else:
-            self.client = client  # type: ignore
+        # The call is not made here; instead, we prepare the necessary objects for the stream.
-        response = await self.client.post(api_base, headers=headers, data=data, stream=True)  # type: ignore
-
-        if response.status_code != 200:
-            raise BedrockError(status_code=response.status_code, message=response.text)
-
-        decoder = AWSEventStreamDecoder(model=model)
-
-        completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
 
         streaming_response = CustomStreamWrapper(
-            completion_stream=completion_stream,
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                client=client,
+                api_base=api_base,
+                headers=headers,
+                data=data,
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
             model=model,
             custom_llm_provider="bedrock",
             logging_obj=logging_obj,
         )
-
-        ## LOGGING
-        logging_obj.post_call(
-            input=messages,
-            api_key="",
-            original_response=streaming_response,
-            additional_args={"complete_input_dict": data},
-        )
-
         return streaming_response
 
     def embedding(self, *args, **kwargs):
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index 8b5f11398..1efbb4501 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -1,4 +1,4 @@
-import httpx, asyncio
+import httpx, asyncio, traceback
 from typing import Optional, Union, Mapping, Any
 
 # https://www.python-httpx.org/advanced/timeouts
@@ -48,11 +48,17 @@ class AsyncHTTPHandler:
         headers: Optional[dict] = None,
         stream: bool = False,
     ):
-        req = self.client.build_request(
-            "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
-        )
-        response = await self.client.send(req, stream=stream)
-        return response
+        try:
+            req = self.client.build_request(
+                "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
+            )
+            response = await self.client.send(req, stream=stream)
+            response.raise_for_status()
+            return response
+        except httpx.HTTPStatusError as e:
+            raise
+        except Exception as e:
+            raise
 
     def __del__(self) -> None:
         try:
diff --git a/litellm/llms/databricks.py b/litellm/llms/databricks.py
index 7b2013710..4fe475259 100644
--- a/litellm/llms/databricks.py
+++ b/litellm/llms/databricks.py
@@ -1,5 +1,6 @@
 # What is this?
 ## Handler file for databricks API https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
+from functools import partial
 import os, types
 import json
 from enum import Enum
@@ -123,7 +124,7 @@ class DatabricksConfig:
             original_chunk = None  # this is used for function/tool calling
             chunk_data = chunk_data.replace("data:", "")
             chunk_data = chunk_data.strip()
-            if len(chunk_data) == 0:
+            if len(chunk_data) == 0 or chunk_data == "[DONE]":
                 return {
                     "text": "",
                     "is_finished": is_finished,
@@ -221,6 +222,32 @@ class DatabricksEmbeddingConfig:
         return optional_params
 
 
+async def make_call(
+    client: AsyncHTTPHandler,
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise DatabricksError(status_code=response.status_code, message=response.text)
+
+    completion_stream = response.aiter_lines()
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
 class DatabricksChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
@@ -354,29 +381,21 @@ class DatabricksChatCompletion(BaseLLM):
         litellm_params=None,
         logger_fn=None,
         headers={},
-    ):
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> CustomStreamWrapper:
+        data["stream"] = True
 
-        try:
-            response = await self.async_handler.post(
-                api_base, headers=headers, data=json.dumps(data), stream=True
-            )
-            response.raise_for_status()
-
-            completion_stream = response.aiter_lines()
-        except httpx.HTTPStatusError as e:
-            raise DatabricksError(
-                status_code=e.response.status_code, message=response.text
-            )
-        except httpx.TimeoutException as e:
-            raise DatabricksError(status_code=408, message="Timeout error occurred.")
-        except Exception as e:
-            raise DatabricksError(status_code=500, message=str(e))
-
         streamwrapper = CustomStreamWrapper(
-            completion_stream=completion_stream,
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                api_base=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
             model=model,
             custom_llm_provider="databricks",
             logging_obj=logging_obj,
@@ -475,6 +494,8 @@ class DatabricksChatCompletion(BaseLLM):
             },
         )
         if acompletion == True:
+            if client is not None and isinstance(client, HTTPHandler):
+                client = None
             if (
                 stream is not None and stream == True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
@@ -496,6 +517,7 @@ class DatabricksChatCompletion(BaseLLM):
                     litellm_params=litellm_params,
                     logger_fn=logger_fn,
                     headers=headers,
+                    client=client,
                 )
             else:
                 return self.acompletion_function(
diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py
index 1e7e1d334..bd69a4250 100644
--- a/litellm/llms/predibase.py
+++ b/litellm/llms/predibase.py
@@ -1,7 +1,7 @@
 # What is this?
 ## Controller file for Predibase Integration - https://predibase.com/
-
+from functools import partial
 import os, types
 import json
 from enum import Enum
@@ -51,6 +51,32 @@ class PredibaseError(Exception):
         )  # Call the base class constructor with the parameters it needs
 
 
+async def make_call(
+    client: AsyncHTTPHandler,
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise PredibaseError(status_code=response.status_code, message=response.text)
+
+    completion_stream = response.aiter_lines()
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
 class PredibaseConfig:
     """
     Reference: https://docs.predibase.com/user-guide/inference/rest_api
@@ -126,11 +152,17 @@ class PredibaseChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
-    def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict:
+    def _validate_environment(
+        self, api_key: Optional[str], user_headers: dict, tenant_id: Optional[str]
+    ) -> dict:
         if api_key is None:
             raise ValueError(
                 "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
             )
+        if tenant_id is None:
+            raise ValueError(
+                "Missing Predibase Tenant ID - Required for making the request. Set dynamically (e.g. `completion(..tenant_id=)`) or in env - `PREDIBASE_TENANT_ID`."
+            )
         headers = {
             "content-type": "application/json",
             "Authorization": "Bearer {}".format(api_key),
@@ -304,7 +336,7 @@ class PredibaseChatCompletion(BaseLLM):
         logger_fn=None,
         headers: dict = {},
     ) -> Union[ModelResponse, CustomStreamWrapper]:
-        headers = self._validate_environment(api_key, headers)
+        headers = self._validate_environment(api_key, headers, tenant_id=tenant_id)
         completion_url = ""
         input_text = ""
         base_url = "https://serving.app.predibase.com"
@@ -488,26 +520,19 @@ class PredibaseChatCompletion(BaseLLM):
         logger_fn=None,
         headers={},
     ) -> CustomStreamWrapper:
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
         data["stream"] = True
-        response = await self.async_handler.post(
-            url=api_base,
-            headers=headers,
-            data=json.dumps(data),
-            stream=True,
-        )
-
-        if response.status_code != 200:
-            raise PredibaseError(
-                status_code=response.status_code, message=response.text
-            )
-
-        completion_stream = response.aiter_lines()
 
         streamwrapper = CustomStreamWrapper(
-            completion_stream=completion_stream,
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                api_base=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
             model=model,
             custom_llm_provider="predibase",
             logging_obj=logging_obj,
diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py
index 386d24f59..ce62e51e9 100644
--- a/litellm/llms/replicate.py
+++ b/litellm/llms/replicate.py
@@ -251,7 +251,7 @@ async def async_handle_prediction_response(
     logs = ""
     while True and (status not in ["succeeded", "failed", "canceled"]):
         print_verbose(f"replicate: polling endpoint: {prediction_url}")
-        await asyncio.sleep(0.5)
+        await asyncio.sleep(0.5)  # prevent replicate rate limit errors
         response = await http_handler.get(prediction_url, headers=headers)
         if response.status_code == 200:
            response_data = response.json()
diff --git a/litellm/main.py b/litellm/main.py
index 525a39d68..37565ff50 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -361,6 +361,7 @@ async def acompletion(
         )  # sets the logging event loop if the user does sync streaming (e.g. on proxy for sagemaker calls)
         return response
     except Exception as e:
+        traceback.print_exc()
         custom_llm_provider = custom_llm_provider or "openai"
         raise exception_type(
             model=model,
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index f32bba50b..f281f8ec2 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -1385,8 +1385,21 @@ def test_bedrock_claude_3_streaming():
 
 
 @pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize(
+    "model",
+    [
+        "claude-3-opus-20240229",
+        "cohere.command-r-plus-v1:0",  # bedrock
+        "databricks/databricks-dbrx-instruct",  # databricks
+        "predibase/llama-3-8b-instruct",  # predibase
+        "replicate/meta/meta-llama-3-8b-instruct",  # replicate
+    ],
+)
 @pytest.mark.asyncio
-async def test_claude_3_streaming_finish_reason(sync_mode):
+async def test_parallel_streaming_requests(sync_mode, model):
+    """
+    Important prod test.
+    """
     try:
         import threading
 
@@ -1398,7 +1411,7 @@ async def test_claude_3_streaming_finish_reason(sync_mode):
 
         def sync_test_streaming():
             response: litellm.CustomStreamWrapper = litellm.acompletion(  # type: ignore
-                model="claude-3-opus-20240229",
+                model=model,
                 messages=messages,
                 stream=True,
                 max_tokens=10,
@@ -1415,7 +1428,7 @@ async def test_claude_3_streaming_finish_reason(sync_mode):
 
         async def test_streaming():
             response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
-                model="claude-3-opus-20240229",
+                model=model,
                 messages=messages,
                 stream=True,
                 max_tokens=10,
@@ -1424,8 +1437,9 @@ async def test_claude_3_streaming_finish_reason(sync_mode):
             # Add any assertions here to-check the response
             num_finish_reason = 0
             async for chunk in response:
-                print(f"chunk: {chunk}")
+                print(f"type of chunk: {type(chunk)}")
                 if isinstance(chunk, ModelResponse):
+                    print(f"OUTSIDE CHUNK: {chunk.choices[0]}")
                     if chunk.choices[0].finish_reason is not None:
                         num_finish_reason += 1
             assert num_finish_reason == 1
diff --git a/litellm/utils.py b/litellm/utils.py
index 9d2fcaec2..b02b420a7 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -32,7 +32,7 @@ from dataclasses import (
 )
 
 import litellm._service_logger  # for storing API inputs, outputs, and metadata
-from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
 from litellm.caching import DualCache
 from litellm.types.utils import CostPerToken, ProviderField, ModelInfo
 
@@ -10214,8 +10214,10 @@ class CustomStreamWrapper:
         custom_llm_provider=None,
         logging_obj=None,
         stream_options=None,
+        make_call: Optional[Callable] = None,
     ):
         self.model = model
+        self.make_call = make_call
         self.custom_llm_provider = custom_llm_provider
         self.logging_obj = logging_obj
         self.completion_stream = completion_stream
@@ -11766,8 +11768,20 @@ class CustomStreamWrapper:
                 custom_llm_provider=self.custom_llm_provider,
             )
 
+    async def fetch_stream(self):
+        if self.completion_stream is None and self.make_call is not None:
+            # Call make_call to get the completion stream
+            self.completion_stream = await self.make_call(
+                client=litellm.module_level_aclient
+            )
+            self._stream_iter = self.completion_stream.__aiter__()
+
+        return self.completion_stream
+
     async def __anext__(self):
         try:
+            if self.completion_stream is None:
+                await self.fetch_stream()
             if (
                 self.custom_llm_provider == "openai"
                 or self.custom_llm_provider == "azure"
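
Note for reviewers: the sketch below is a minimal, self-contained illustration of the pattern this patch applies to the bedrock, databricks, and predibase handlers — the stream wrapper is constructed with completion_stream=None plus a functools.partial-bound make_call, and the HTTP request is only issued on the first __anext__, through a shared module-level async client. It is not litellm code; FakeBackend, DeferredStreamWrapper, and MODULE_LEVEL_CLIENT are made-up stand-ins for AsyncHTTPHandler, CustomStreamWrapper, and litellm.module_level_aclient.

# Illustrative sketch only (names are hypothetical, not litellm APIs).
import asyncio
from functools import partial
from typing import AsyncIterator, Callable, Optional


class FakeBackend:
    # Stand-in for the shared AsyncHTTPHandler: yields chunks of a "response".
    async def stream(self, prompt: str) -> AsyncIterator[str]:
        for token in prompt.split():
            await asyncio.sleep(0)  # pretend network latency
            yield token


MODULE_LEVEL_CLIENT = FakeBackend()  # analogous to litellm.module_level_aclient


async def make_call(client: FakeBackend, prompt: str) -> AsyncIterator[str]:
    # Provider-specific helper: opens the upstream stream and returns an
    # async iterator of chunks (mirrors the make_call helpers added in the diff).
    return client.stream(prompt)


class DeferredStreamWrapper:
    # Mirrors the CustomStreamWrapper change: nothing is sent at construction
    # time; make_call is invoked lazily on the first __anext__.
    def __init__(self, make_call: Optional[Callable] = None):
        self.make_call = make_call
        self.completion_stream: Optional[AsyncIterator[str]] = None

    def __aiter__(self):
        return self

    async def fetch_stream(self):
        if self.completion_stream is None and self.make_call is not None:
            self.completion_stream = await self.make_call(client=MODULE_LEVEL_CLIENT)
        return self.completion_stream

    async def __anext__(self):
        if self.completion_stream is None:
            await self.fetch_stream()
        # StopAsyncIteration from the underlying stream ends the async-for loop.
        return await self.completion_stream.__anext__()


async def main():
    wrapper = DeferredStreamWrapper(
        make_call=partial(make_call, prompt="hello from a deferred stream")
    )
    async for chunk in wrapper:
        print(chunk)


if __name__ == "__main__":
    asyncio.run(main())

Deferring the request this way lets the wrapper be returned without awaiting anything at call time, and routing every lazy call through one shared client keeps connection reuse across the parallel streaming requests exercised by test_parallel_streaming_requests.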