forked from phoenix/litellm-mirror
feat(bedrock_httpx.py): working bedrock converse api streaming
parent a995a0b172
commit 51ba5652a0

6 changed files with 165 additions and 25 deletions
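As a quick illustration of what this commit enables, here is a minimal sketch of the synchronous Converse streaming path. The model id and the AWS credentials in the environment are assumptions for the example, not part of the diff.

    # Sketch only: assumes AWS credentials are configured and that this commit's
    # sync streaming branch is in place. Model id is an example.
    import litellm

    response = litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        stream=True,
    )
    for chunk in response:
        print(chunk.choices[0].delta.content or "", end="")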
@@ -5,7 +5,7 @@ warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*
 ### INIT VARIABLES ###
 import threading, requests, os
 from typing import Callable, List, Optional, Dict, Union, Any, Literal
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.caching import Cache
 from litellm._logging import (
     set_verbose,
@@ -232,6 +232,7 @@ max_end_user_budget: Optional[float] = None
 #### RELIABILITY ####
 request_timeout: float = 6000
 module_level_aclient = AsyncHTTPHandler(timeout=request_timeout)
+module_level_client = HTTPHandler(timeout=request_timeout)
 num_retries: Optional[int] = None  # per model endpoint
 default_fallbacks: Optional[List] = None
 fallbacks: Optional[List] = None
@@ -185,6 +185,37 @@ async def make_call(
     return completion_stream
+
+
+def make_sync_call(
+    client: Optional[HTTPHandler],
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    if client is None:
+        client = HTTPHandler()  # Create a new client if none provided
+
+    response = client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise BedrockError(status_code=response.status_code, message=response.read())
+
+    decoder = AWSEventStreamDecoder(model=model)
+    completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
+
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
 class BedrockLLM(BaseLLM):
     """
     Example call
@@ -1081,6 +1112,7 @@ class BedrockLLM(BaseLLM):
 class AmazonConverseConfig:
     """
     Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
+    #2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
     """

     maxTokens: Optional[int]
@@ -1118,30 +1150,32 @@ class AmazonConverseConfig:
             and v is not None
         }

-    def get_supported_openai_params(self) -> List[str]:
-        return [
+    def get_supported_openai_params(self, model: str) -> List[str]:
+        supported_params = [
             "max_tokens",
             "stream",
             "stream_options",
             "stop",
             "temperature",
             "top_p",
-            "tools",
-            "tool_choice",
         ]

+        if (
+            model.startswith("anthropic")
+            or model.startswith("mistral")
+            or model.startswith("cohere")
+        ):
+            supported_params.append("tools")
+
+        if model.startswith("anthropic") or model.startswith("mistral"):
+            # only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
+            supported_params.append("tool_choice")
+
+        return supported_params
+
     def map_tool_choice_values(
         self, model: str, tool_choice: Union[str, dict], drop_params: bool
     ) -> Optional[ToolChoiceValuesBlock]:
-        if not model.startswith("anthropic") and not model.startswith("mistral"):
-            # only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
-            if drop_params == True or litellm.drop_params == True:
-                return None
-            else:
-                raise litellm.utils.UnsupportedParamsError(
-                    message="Only Anthropic and Mistral on Bedrock support 'tool_choice'. To drop it from the call, set `litellm.drop_params = True.`",
-                    status_code=400,
-                )
         if tool_choice == "none":
             if litellm.drop_params is True or drop_params is True:
                 return None
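A small sketch of how the now model-aware parameter list behaves. The model ids are illustrative examples, and it is assumed AmazonConverseConfig is exported on the litellm package namespace as it is used elsewhere in this diff.

    import litellm

    config = litellm.AmazonConverseConfig()
    # Anthropic models are expected to report both "tools" and "tool_choice".
    print(config.get_supported_openai_params(model="anthropic.claude-3-sonnet-20240229-v1:0"))
    # Cohere models are expected to report "tools" but not "tool_choice".
    print(config.get_supported_openai_params(model="cohere.command-r-v1:0"))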
@@ -1197,7 +1231,7 @@ class AmazonConverseConfig:
                 optional_params["tools"] = value
             if param == "tool_choice":
                 _tool_choice_value = self.map_tool_choice_values(
-                    model=model, tool_choice=value, drop_params=drop_params
+                    model=model, tool_choice=value, drop_params=drop_params  # type: ignore
                 )
                 if _tool_choice_value is not None:
                     optional_params["tool_choice"] = _tool_choice_value
@@ -1539,7 +1573,7 @@ class BedrockConverseLLM(BaseLLM):
         else:
             endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"

-        if (stream is not None and stream == True) and provider != "ai21":
+        if (stream is not None and stream is True) and provider != "ai21":
             endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
         else:
             endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
@@ -1561,7 +1595,7 @@ class BedrockConverseLLM(BaseLLM):
         inference_params = copy.deepcopy(optional_params)
         additional_request_keys = []
         additional_request_params = {}
-        supported_converse_params = AmazonConverseConfig().get_config().keys()
+        supported_converse_params = AmazonConverseConfig.__annotations__.keys()
         supported_tool_call_params = ["tools", "tool_choice"]
         ## TRANSFORMATION ##
         # send all model-specific params in 'additional_request_params'
@@ -1596,6 +1630,7 @@ class BedrockConverseLLM(BaseLLM):
             "messages": bedrock_messages,
             "additionalModelRequestFields": additional_request_params,
             "system": system_content_blocks,
+            "inferenceConfig": InferenceConfig(**inference_params),
         }
         if bedrock_tool_config is not None:
             _data["toolConfig"] = bedrock_tool_config
@@ -1623,7 +1658,35 @@ class BedrockConverseLLM(BaseLLM):
         )

         ### ROUTING (ASYNC, STREAMING, SYNC)
+        if (stream is not None and stream is True) and provider != "ai21":
+
+            streaming_response = CustomStreamWrapper(
+                completion_stream=None,
+                make_call=partial(
+                    make_sync_call,
+                    client=None,
+                    api_base=prepped.url,
+                    headers=prepped.headers,
+                    data=data,
+                    model=model,
+                    messages=messages,
+                    logging_obj=logging_obj,
+                ),
+                model=model,
+                custom_llm_provider="bedrock",
+                logging_obj=logging_obj,
+            )
+
+            ## LOGGING
+            logging_obj.post_call(
+                input=messages,
+                api_key="",
+                original_response=streaming_response,
+                additional_args={"complete_input_dict": data},
+            )
+
+            return streaming_response
         ### COMPLETION
         if client is None or isinstance(client, AsyncHTTPHandler):
             _params = {}
             if timeout is not None:
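The routing branch above does not open the HTTP stream eagerly: CustomStreamWrapper is handed completion_stream=None plus a functools.partial over make_sync_call, and the request is only issued when the wrapper is first iterated (see fetch_sync_stream further down). A self-contained sketch of that pattern, with a stand-in generator instead of the real Bedrock call:

    from functools import partial

    def open_stream(client=None):
        # Stand-in for make_sync_call: a real call would POST to the
        # /converse-stream endpoint and decode AWS event-stream bytes.
        yield from ({"text": piece, "is_finished": False} for piece in ["hel", "lo"])

    make_call = partial(open_stream)   # nothing is requested yet
    completion_stream = None           # the wrapper starts without a stream
    completion_stream = make_call(client=None)  # opened lazily on first iteration
    print("".join(chunk["text"] for chunk in completion_stream))  # -> "hello"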
@@ -1675,6 +1738,31 @@ class AWSEventStreamDecoder:
         self.parser = EventStreamJSONParser()

     def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+        text = ""
+        tool_str = ""
+        is_finished = False
+        finish_reason = ""
+        usage: Optional[ConverseTokenUsageBlock] = None
+        if "delta" in chunk_data:
+            delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
+            if "text" in delta_obj:
+                text = delta_obj["text"]
+            elif "toolUse" in delta_obj:
+                tool_str = delta_obj["toolUse"]["input"]
+        elif "stopReason" in chunk_data:
+            finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
+        elif "usage" in chunk_data:
+            usage = ConverseTokenUsageBlock(**chunk_data["usage"])
+        response = GenericStreamingChunk(
+            text=text,
+            tool_str=tool_str,
+            is_finished=is_finished,
+            finish_reason=finish_reason,
+            usage=usage,
+        )
+        return response
+
+    def _old_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
         text = ""
         is_finished = False
         finish_reason = ""
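A sketch of what the new _chunk_parser is expected to produce for typical Converse stream events. The decoded event payloads below are hand-written examples, the module path is assumed from this commit, and botocore must be installed for AWSEventStreamDecoder to initialise.

    from litellm.llms.bedrock_httpx import AWSEventStreamDecoder

    decoder = AWSEventStreamDecoder(model="anthropic.claude-3-sonnet-20240229-v1:0")

    # A text delta maps onto the "text" field of GenericStreamingChunk.
    print(decoder._chunk_parser({"delta": {"text": "Hello"}}))

    # A usage event fills the optional "usage" block instead.
    print(decoder._chunk_parser(
        {"usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}}
    ))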
@@ -1763,12 +1851,11 @@ class AWSEventStreamDecoder:

     def _parse_message_from_event(self, event) -> Optional[str]:
         response_dict = event.to_response_dict()
-        parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
         if response_dict["status_code"] != 200:
             raise ValueError(f"Bad response code, expected 200: {response_dict}")

-        chunk = parsed_response.get("chunk")
+        chunk = response_dict.get("body")
         if not chunk:
             return None

-        return chunk.get("bytes").decode()  # type: ignore[no-any-return]
+        return chunk.decode()  # type: ignore[no-any-return]
@@ -168,6 +168,7 @@ class HTTPHandler:
         return response

     def __del__(self) -> None:
+        traceback.print_stack()
         try:
             self.close()
         except Exception:
@@ -1284,7 +1284,7 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
 # pytest.fail(f"Error occurred: {e}")


-@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize("sync_mode", [True])  # False
 @pytest.mark.parametrize(
     "model",
     [
@@ -1324,6 +1324,8 @@ async def test_bedrock_httpx_streaming(sync_mode, model):
                 raise Exception("finish reason not set")
             if complete_response.strip() == "":
                 raise Exception("Empty response received")
+
+            assert False
         else:
             response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
                 model=model,
@@ -107,10 +107,30 @@ class ToolConfigBlock(TypedDict, total=False):
     toolChoice: Union[str, ToolChoiceValuesBlock]


+class InferenceConfig(TypedDict, total=False):
+    maxTokens: int
+    stopSequences: List[str]
+    temperature: float
+    topP: float
+
+
+class ToolBlockDeltaEvent(TypedDict):
+    input: str
+
+
+class ContentBlockDeltaEvent(TypedDict, total=False):
+    """
+    Either 'text' or 'toolUse' will be specified for Converse API streaming response.
+    """
+
+    text: str
+    toolUse: ToolBlockDeltaEvent
+
+
 class RequestObject(TypedDict, total=False):
     additionalModelRequestFields: dict
     additionalModelResponseFieldPaths: List[str]
-    inferenceConfig: dict
+    inferenceConfig: InferenceConfig
     messages: Required[List[MessageBlock]]
     system: List[SystemContentBlock]
     toolConfig: ToolConfigBlock
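A minimal sketch of the new typed request pieces, assuming the types are importable from litellm.types.llms.bedrock as laid out in this diff; the field values are illustrative.

    from litellm.types.llms.bedrock import ContentBlockDeltaEvent, InferenceConfig

    # Inference settings are now typed rather than a bare dict on RequestObject.
    inference_config = InferenceConfig(maxTokens=256, temperature=0.2, topP=0.9)

    # A streaming delta carries either "text" or "toolUse", never both.
    text_delta: ContentBlockDeltaEvent = {"text": "partial output"}
    tool_delta: ContentBlockDeltaEvent = {"toolUse": {"input": '{"location": "Boston"}'}}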
@@ -118,8 +138,10 @@ class RequestObject(TypedDict, total=False):

 class GenericStreamingChunk(TypedDict):
     text: Required[str]
+    tool_str: Required[str]
     is_finished: Required[bool]
     finish_reason: Required[str]
+    usage: Optional[ConverseTokenUsageBlock]


 class Document(TypedDict):
@@ -239,6 +239,8 @@ def map_finish_reason(
         return "length"
     elif finish_reason == "tool_use":  # anthropic
         return "tool_calls"
+    elif finish_reason == "content_filtered":
+        return "content_filter"
     return finish_reason


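A quick sketch of the new mapping, assuming map_finish_reason is importable from litellm.utils as defined in this file.

    from litellm.utils import map_finish_reason

    assert map_finish_reason("content_filtered") == "content_filter"
    assert map_finish_reason("tool_use") == "tool_calls"            # anthropic
    assert map_finish_reason("some_other_reason") == "some_other_reason"  # passed through unchanged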
@@ -6330,7 +6332,7 @@ def get_supported_openai_params(
     - None if unmapped
     """
     if custom_llm_provider == "bedrock":
-        return litellm.AmazonConverseConfig().get_supported_openai_params()
+        return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "ollama":
         return litellm.OllamaConfig().get_supported_openai_params()
     elif custom_llm_provider == "ollama_chat":
@@ -11242,12 +11244,27 @@ class CustomStreamWrapper:
                     if response_obj["is_finished"]:
                         self.received_finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider == "bedrock":
+                    from litellm.types.llms.bedrock import GenericStreamingChunk
+
                     if self.received_finish_reason is not None:
                         raise StopIteration
-                    response_obj = self.handle_bedrock_stream(chunk)
+                    response_obj: GenericStreamingChunk = chunk
                     completion_obj["content"] = response_obj["text"]
+
                     if response_obj["is_finished"]:
                         self.received_finish_reason = response_obj["finish_reason"]
+
+                    if (
+                        self.stream_options
+                        and self.stream_options.get("include_usage", False) is True
+                        and response_obj["usage"] is not None
+                    ):
+                        self.sent_stream_usage = True
+                        model_response.usage = litellm.Usage(
+                            prompt_tokens=response_obj["usage"]["inputTokens"],
+                            completion_tokens=response_obj["usage"]["outputTokens"],
+                            total_tokens=response_obj["usage"]["totalTokens"],
+                        )
                 elif self.custom_llm_provider == "sagemaker":
                     print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
                     response_obj = self.handle_sagemaker_stream(chunk)
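With the usage block now surfaced on Bedrock chunks, a sketch of how a caller could request token counts. Whether stream_options is honoured end-to-end for every Bedrock model is an assumption here, not something the diff asserts, and the model id is an example.

    # Sketch only: assumes AWS credentials and an example model id.
    import litellm

    response = litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    for chunk in response:
        usage = getattr(chunk, "usage", None)
        if usage is not None:
            print(usage)  # inputTokens/outputTokens mapped onto prompt/completion tokens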
@@ -11509,7 +11526,7 @@ class CustomStreamWrapper:
                 and hasattr(model_response, "usage")
                 and hasattr(model_response.usage, "prompt_tokens")
             ):
-                if self.sent_first_chunk == False:
+                if self.sent_first_chunk is False:
                     completion_obj["role"] = "assistant"
                     self.sent_first_chunk = True
                 model_response.choices[0].delta = Delta(**completion_obj)
@@ -11677,6 +11694,8 @@ class CustomStreamWrapper:

     def __next__(self):
         try:
+            if self.completion_stream is None:
+                self.fetch_sync_stream()
             while True:
                 if (
                     isinstance(self.completion_stream, str)
@@ -11751,6 +11770,14 @@ class CustomStreamWrapper:
                 custom_llm_provider=self.custom_llm_provider,
             )

+    def fetch_sync_stream(self):
+        if self.completion_stream is None and self.make_call is not None:
+            # Call make_call to get the completion stream
+            self.completion_stream = self.make_call(client=litellm.module_level_client)
+            self._stream_iter = self.completion_stream.__iter__()
+
+        return self.completion_stream
+
     async def fetch_stream(self):
         if self.completion_stream is None and self.make_call is not None:
             # Call make_call to get the completion stream