Merge pull request #3586 from BerriAI/litellm_bedrock_command_r_support

feat(bedrock_httpx.py): Make Bedrock-Cohere calls Async
Krish Dholakia, 2024-05-11 21:24:51 -07:00 (committed via GitHub)
commit 94c9df969e
39 changed files with 1222 additions and 195 deletions
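
With this change, Bedrock Cohere Command-R models are served by a new httpx-based client (llms/bedrock_httpx.py), which makes true async calls possible. A minimal sketch of the async usage exercised by the new tests, assuming AWS credentials are available in the environment:

```
import asyncio
import litellm

async def main():
    # Routed through the new BedrockLLM httpx client because this is a Bedrock Cohere model
    response = await litellm.acompletion(
        model="bedrock/cohere.command-r-plus-v1:0",
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
    )
    print(response.choices[0].message.content)

asyncio.run(main())
```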

View file

@ -10,7 +10,6 @@ from litellm.caching import DualCache
from typing import Literal, Union
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
@ -19,8 +18,6 @@ import traceback
import dotenv, os
import requests
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm, uuid

View file

@ -743,6 +743,7 @@ from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig
from .llms.maritalk import MaritTalkConfig
from .llms.bedrock_httpx import AmazonCohereChatConfig
from .llms.bedrock import (
AmazonTitanConfig,
AmazonAI21Config,

View file

@ -1,8 +1,6 @@
#### What this does ####
# On success + failure, log events to aispend.io
import dotenv, os
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime

View file

@ -3,7 +3,6 @@
import dotenv, os
import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime

View file

@ -8,8 +8,6 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal, Union
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
@ -18,8 +16,6 @@ import traceback
import dotenv, os
import requests
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm, uuid

View file

@ -6,8 +6,6 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal, Union, Optional
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback

View file

@ -3,8 +3,6 @@
import dotenv, os
import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm, uuid

View file

@ -3,8 +3,6 @@
import dotenv, os
import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm, uuid

View file

@ -3,8 +3,6 @@
import dotenv, os
import requests # type: ignore
import litellm
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback

View file

@ -1,8 +1,6 @@
#### What this does ####
# On success, logs events to Langfuse
import dotenv, os
dotenv.load_dotenv() # Loading env variables using dotenv
import os
import copy
import traceback
from packaging.version import Version

View file

@ -3,8 +3,6 @@
import dotenv, os # type: ignore
import requests # type: ignore
from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import asyncio
import types

View file

@ -2,13 +2,10 @@
# On success + failure, log events to lunary.ai
from datetime import datetime, timezone
import traceback
import dotenv
import importlib
import packaging
dotenv.load_dotenv()
# convert to {completion: xx, tokens: xx}
def parse_usage(usage):
@ -79,14 +76,16 @@ class LunaryLogger:
version = importlib.metadata.version("lunary")
# if version < 0.1.43 then raise ImportError
if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
print(
print( # noqa
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
)
raise ImportError
self.lunary_client = lunary
except ImportError:
print("Lunary not installed. Please install it using 'pip install lunary'")
print( # noqa
"Lunary not installed. Please install it using 'pip install lunary'"
) # noqa
raise ImportError
def log_event(

View file

@ -3,8 +3,6 @@
import dotenv, os, json
import litellm
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler

View file

@ -4,8 +4,6 @@
import dotenv, os
import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm, uuid

View file

@ -5,8 +5,6 @@
import dotenv, os
import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm, uuid

View file

@ -3,8 +3,6 @@
import dotenv, os
import requests # type: ignore
from pydantic import BaseModel
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback

View file

@ -1,9 +1,7 @@
#### What this does ####
# On success + failure, log events to Supabase
import dotenv, os
dotenv.load_dotenv() # Loading env variables using dotenv
import os
import traceback
import datetime, subprocess, sys
import litellm, uuid

View file

@ -2,8 +2,6 @@
# Class for sending Slack Alerts #
import dotenv, os
from litellm.proxy._types import UserAPIKeyAuth
dotenv.load_dotenv() # Loading env variables using dotenv
from litellm._logging import verbose_logger, verbose_proxy_logger
import litellm, threading
from typing import List, Literal, Any, Union, Optional, Dict

View file

@ -3,8 +3,6 @@
import dotenv, os
import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm

View file

@ -21,11 +21,11 @@ try:
# contains a (known) object attribute
object: Literal["chat.completion", "edit", "text_completion"]
def __getitem__(self, key: K) -> V:
... # pragma: no cover
def __getitem__(self, key: K) -> V: ... # noqa
def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
... # pragma: no cover
def get( # noqa
self, key: K, default: Optional[V] = None
) -> Optional[V]: ... # pragma: no cover
class OpenAIRequestResponseResolver:
def __call__(
@ -173,12 +173,11 @@ except:
#### What this does ####
# On success, logs events to Langfuse
import dotenv, os
import os
import requests
import requests
from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback

View file

@ -3,7 +3,7 @@ import json
from enum import Enum
import requests, copy # type: ignore
import time
from typing import Callable, Optional, List
from typing import Callable, Optional, List, Union
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
@ -151,19 +151,135 @@ class AnthropicChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def process_streaming_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: ModelResponse,
stream: bool,
logging_obj: litellm.utils.Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
) -> CustomStreamWrapper:
"""
Return stream object for tool-calling + streaming
"""
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise AnthropicError(
message=response.text, status_code=response.status_code
)
text_content = ""
tool_calls = []
for content in completion_response["content"]:
if content["type"] == "text":
text_content += content["text"]
## TOOL CALLING
elif content["type"] == "tool_use":
tool_calls.append(
{
"id": content["id"],
"type": "function",
"function": {
"name": content["name"],
"arguments": json.dumps(content["input"]),
},
}
)
if "error" in completion_response:
raise AnthropicError(
message=str(completion_response["error"]),
status_code=response.status_code,
)
_message = litellm.Message(
tool_calls=tool_calls,
content=text_content or None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = completion_response[
"content"
] # allow user to access raw anthropic tool calling response
model_response.choices[0].finish_reason = map_finish_reason(
completion_response["stop_reason"]
)
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[ # type: ignore
0
].finish_reason
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
streaming_choice = litellm.utils.StreamingChoices()
streaming_choice.index = model_response.choices[0].index
_tool_calls = []
print_verbose(
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
)
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
if isinstance(model_response.choices[0], litellm.Choices):
if getattr(
model_response.choices[0].message, "tool_calls", None
) is not None and isinstance(
model_response.choices[0].message.tool_calls, list
):
for tool_call in model_response.choices[0].message.tool_calls:
_tool_call = {**tool_call.dict(), "index": 0}
_tool_calls.append(_tool_call)
delta_obj = litellm.utils.Delta(
content=getattr(model_response.choices[0].message, "content", None),
role=model_response.choices[0].message.role,
tool_calls=_tool_calls,
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
completion_stream = ModelResponseIterator(
model_response=streaming_model_response
)
print_verbose(
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
else:
raise AnthropicError(
status_code=422,
message="Unprocessable response object - {}".format(response.text),
)
def process_response(
self,
model,
response,
model_response,
_is_function_call,
stream,
logging_obj,
api_key,
data,
messages,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: ModelResponse,
stream: bool,
logging_obj: litellm.utils.Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
):
encoding,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
input=messages,
@ -216,51 +332,6 @@ class AnthropicChatCompletion(BaseLLM):
completion_response["stop_reason"]
)
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
if _is_function_call and stream:
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[
0
].finish_reason
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
streaming_choice = litellm.utils.StreamingChoices()
streaming_choice.index = model_response.choices[0].index
_tool_calls = []
print_verbose(
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
)
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
if isinstance(model_response.choices[0], litellm.Choices):
if getattr(
model_response.choices[0].message, "tool_calls", None
) is not None and isinstance(
model_response.choices[0].message.tool_calls, list
):
for tool_call in model_response.choices[0].message.tool_calls:
_tool_call = {**tool_call.dict(), "index": 0}
_tool_calls.append(_tool_call)
delta_obj = litellm.utils.Delta(
content=getattr(model_response.choices[0].message, "content", None),
role=model_response.choices[0].message.role,
tool_calls=_tool_calls,
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
completion_stream = ModelResponseIterator(
model_response=streaming_model_response
)
print_verbose(
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
## CALCULATING USAGE
prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"]
@ -273,7 +344,7 @@ class AnthropicChatCompletion(BaseLLM):
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
model_response.usage = usage
setattr(model_response, "usage", usage) # type: ignore
return model_response
async def acompletion_stream_function(
@ -289,7 +360,7 @@ class AnthropicChatCompletion(BaseLLM):
logging_obj,
stream,
_is_function_call,
data=None,
data: dict,
optional_params=None,
litellm_params=None,
logger_fn=None,
@ -331,29 +402,44 @@ class AnthropicChatCompletion(BaseLLM):
logging_obj,
stream,
_is_function_call,
data=None,
optional_params=None,
data: dict,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
):
) -> Union[ModelResponse, CustomStreamWrapper]:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
if stream and _is_function_call:
return self.process_streaming_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
_is_function_call=_is_function_call,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
def completion(
@ -367,7 +453,7 @@ class AnthropicChatCompletion(BaseLLM):
encoding,
api_key,
logging_obj,
optional_params=None,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
@ -526,17 +612,33 @@ class AnthropicChatCompletion(BaseLLM):
raise AnthropicError(
status_code=response.status_code, message=response.text
)
if stream and _is_function_call:
return self.process_streaming_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
_is_function_call=_is_function_call,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
def embedding(self):

View file

@ -100,7 +100,7 @@ class AnthropicTextCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def process_response(
def _process_response(
self, model_response: ModelResponse, response, encoding, prompt: str, model: str
):
## RESPONSE OBJECT
@ -171,7 +171,7 @@ class AnthropicTextCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
)
response = self.process_response(
response = self._process_response(
model_response=model_response,
response=response,
encoding=encoding,
@ -330,7 +330,7 @@ class AnthropicTextCompletion(BaseLLM):
)
print_verbose(f"raw model_response: {response.text}")
response = self.process_response(
response = self._process_response(
model_response=model_response,
response=response,
encoding=encoding,

View file

@ -1,12 +1,32 @@
## This is a template base class to be used for adding new LLM providers via API calls
import litellm
import httpx
from typing import Optional
import httpx, requests
from typing import Optional, Union
from litellm.utils import Logging
class BaseLLM:
_client_session: Optional[httpx.Client] = None
def process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: litellm.utils.ModelResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: list,
print_verbose,
encoding,
) -> litellm.utils.ModelResponse:
"""
Helper function to process the response across sync + async completion calls
"""
return model_response
def create_client_session(self):
if litellm.client_session:
_client_session = litellm.client_session

View file

@ -0,0 +1,733 @@
# What is this?
## Initial implementation of calling bedrock via httpx client (allows for async calls).
## V0 - just covers cohere command-r support
import os, types
import json
from enum import Enum
import requests, copy # type: ignore
import time
from typing import (
Callable,
Optional,
List,
Literal,
Union,
Any,
TypedDict,
Tuple,
Iterator,
AsyncIterator,
)
from litellm.utils import (
ModelResponse,
Usage,
map_finish_reason,
CustomStreamWrapper,
Message,
Choices,
get_secret,
Logging,
)
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
import httpx # type: ignore
from .bedrock import BedrockError, convert_messages_to_prompt
from litellm.types.llms.bedrock import *
class AmazonCohereChatConfig:
"""
Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
"""
documents: Optional[List[Document]] = None
search_queries_only: Optional[bool] = None
preamble: Optional[str] = None
max_tokens: Optional[int] = None
temperature: Optional[float] = None
p: Optional[float] = None
k: Optional[float] = None
prompt_truncation: Optional[str] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
seed: Optional[int] = None
return_prompt: Optional[bool] = None
stop_sequences: Optional[List[str]] = None
raw_prompting: Optional[bool] = None
def __init__(
self,
documents: Optional[List[Document]] = None,
search_queries_only: Optional[bool] = None,
preamble: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
p: Optional[float] = None,
k: Optional[float] = None,
prompt_truncation: Optional[str] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
return_prompt: Optional[bool] = None,
stop_sequences: Optional[List[str]] = None,
raw_prompting: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self) -> List[str]:
return [
"max_tokens",
"stream",
"stop",
"temperature",
"top_p",
"frequency_penalty",
"presence_penalty",
"seed",
"stop",
]
def map_openai_params(
self, non_default_params: dict, optional_params: dict
) -> dict:
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "stream":
optional_params["stream"] = value
if param == "stop":
if isinstance(value, str):
value = [value]
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["p"] = value
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if param == "presence_penalty":
optional_params["presence_penalty"] = value
if "seed":
optional_params["seed"] = value
return optional_params
class BedrockLLM(BaseLLM):
"""
Example call
```
curl --location --request POST 'https://bedrock-runtime.{aws_region_name}.amazonaws.com/model/{bedrock_model_name}/invoke' \
--header 'Content-Type: application/json' \
--header 'Accept: application/json' \
--user "$AWS_ACCESS_KEY_ID":"$AWS_SECRET_ACCESS_KEY" \
--aws-sigv4 "aws:amz:us-east-1:bedrock" \
--data-raw '{
"prompt": "Hi",
"temperature": 0,
"p": 0.9,
"max_tokens": 4096
}'
```
"""
def __init__(self) -> None:
super().__init__()
def convert_messages_to_prompt(
self, model, messages, provider, custom_prompt_dict
) -> Tuple[str, Optional[list]]:
# handle anthropic prompts and amazon titan prompts
prompt = ""
chat_history: Optional[list] = None
if provider == "anthropic" or provider == "amazon":
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "mistral":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "meta":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "cohere":
prompt, chat_history = cohere_message_pt(messages=messages)
else:
prompt = ""
for message in messages:
if "role" in message:
if message["role"] == "user":
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
return prompt, chat_history # type: ignore
def get_credentials(
self,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
):
"""
Return a boto3.Credentials object
"""
import boto3
## CHECK IF 'os.environ/' passed in
params_to_check: List[Optional[str]] = [
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith("os.environ/"):
_v = get_secret(param)
if _v is not None and isinstance(_v, str):
params_to_check[i] = _v
# Assign updated values back to parameters
(
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
) = params_to_check
### CHECK STS ###
if aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
)
sts_response = sts_client.assume_role(
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
return sts_response["Credentials"]
elif aws_profile_name is not None: ### CHECK SESSION ###
# uses auth values from AWS profile usually stored in ~/.aws/credentials
client = boto3.Session(profile_name=aws_profile_name)
return client.get_credentials()
else:
session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
return session.get_credentials()
def process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: ModelResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise BedrockError(message=response.text, status_code=422)
try:
model_response.choices[0].message.content = completion_response["text"] # type: ignore
except Exception as e:
raise BedrockError(message=response.text, status_code=422)
## CALCULATING USAGE - bedrock returns usage in the headers
prompt_tokens = int(
response.headers.get(
"x-amzn-bedrock-input-token-count",
len(encoding.encode("".join(m.get("content", "") for m in messages))),
)
)
completion_tokens = int(
response.headers.get(
"x-amzn-bedrock-output-token-count",
len(
encoding.encode(
model_response.choices[0].message.content, # type: ignore
disallowed_special=(),
)
),
)
)
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def completion(
self,
model: str,
messages: list,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
acompletion: bool,
timeout: Optional[Union[float, httpx.Timeout]],
litellm_params=None,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError as e:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## SETUP ##
stream = optional_params.pop("stream", None)
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
### SET REGION NAME ###
if aws_region_name is None:
# check env #
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_role_name=aws_role_name,
)
### SET RUNTIME ENDPOINT ###
endpoint_url = ""
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
if aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
endpoint_url = aws_bedrock_runtime_endpoint
elif env_aws_bedrock_runtime_endpoint and isinstance(
env_aws_bedrock_runtime_endpoint, str
):
endpoint_url = env_aws_bedrock_runtime_endpoint
else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
if stream is not None and stream == True:
endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream"
else:
endpoint_url = f"{endpoint_url}/model/{model}/invoke"
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
provider = model.split(".")[0]
prompt, chat_history = self.convert_messages_to_prompt(
model, messages, provider, custom_prompt_dict
)
inference_params = copy.deepcopy(optional_params)
if provider == "cohere":
if model.startswith("cohere.command-r"):
## LOAD CONFIG
config = litellm.AmazonCohereChatConfig().get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > amazon_cohere_chat_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
_data = {"message": prompt, **inference_params}
if chat_history is not None:
_data["chat_history"] = chat_history
data = json.dumps(_data)
else:
## LOAD CONFIG
config = litellm.AmazonCohereConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > amazon_cohere_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
if stream == True:
inference_params["stream"] = (
True # cohere requires stream = True in inference params
)
data = json.dumps({"prompt": prompt, **inference_params})
else:
raise Exception("UNSUPPORTED PROVIDER")
## COMPLETION CALL
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=endpoint_url, data=data, headers=headers
)
sigv4.add_auth(request)
prepped = request.prepare()
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": prepped.url,
"headers": prepped.headers,
},
)
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
if isinstance(client, HTTPHandler):
client = None
if stream:
return self.async_streaming(
model=model,
messages=messages,
data=data,
api_base=prepped.url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=True,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data,
api_base=prepped.url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=False,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
if client is None or isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
self.client = HTTPHandler(**_params) # type: ignore
else:
self.client = client
if stream is not None and stream == True:
response = self.client.post(
url=prepped.url,
headers=prepped.headers, # type: ignore
data=data,
stream=stream,
)
if response.status_code != 200:
raise BedrockError(
status_code=response.status_code, message=response.text
)
decoder = AWSEventStreamDecoder()
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
return streaming_response
response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
try:
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=response.text)
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
optional_params=optional_params,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
async def async_completion(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> ModelResponse:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
self.client = AsyncHTTPHandler(**_params) # type: ignore
else:
self.client = client # type: ignore
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
async def async_streaming(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> CustomStreamWrapper:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
self.client = AsyncHTTPHandler(**_params) # type: ignore
else:
self.client = client # type: ignore
response = await self.client.post(api_base, headers=headers, data=data, stream=True) # type: ignore
if response.status_code != 200:
raise BedrockError(status_code=response.status_code, message=response.text)
decoder = AWSEventStreamDecoder()
completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
return streaming_response
def embedding(self, *args, **kwargs):
return super().embedding(*args, **kwargs)
def get_response_stream_shape():
from botocore.model import ServiceModel
from botocore.loaders import Loader
loader = Loader()
bedrock_service_dict = loader.load_service_model("bedrock-runtime", "service-2")
bedrock_service_model = ServiceModel(bedrock_service_dict)
return bedrock_service_model.shape_for("ResponseStream")
class AWSEventStreamDecoder:
def __init__(self) -> None:
from botocore.parsers import EventStreamJSONParser
self.parser = EventStreamJSONParser()
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
"""Given an iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
event_stream_buffer = EventStreamBuffer()
for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
# sse_event = ServerSentEvent(data=message, event="completion")
_data = json.loads(message)
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
text=_data.get("text", ""),
is_finished=_data.get("is_finished", False),
finish_reason=_data.get("finish_reason", ""),
)
yield streaming_chunk
async def aiter_bytes(
self, iterator: AsyncIterator[bytes]
) -> AsyncIterator[GenericStreamingChunk]:
"""Given an async iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
event_stream_buffer = EventStreamBuffer()
async for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
_data = json.loads(message)
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
text=_data.get("text", ""),
is_finished=_data.get("is_finished", False),
finish_reason=_data.get("finish_reason", ""),
)
yield streaming_chunk
def _parse_message_from_event(self, event) -> Optional[str]:
response_dict = event.to_response_dict()
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
if response_dict["status_code"] != 200:
raise ValueError(f"Bad response code, expected 200: {response_dict}")
chunk = parsed_response.get("chunk")
if not chunk:
return None
return chunk.get("bytes").decode() # type: ignore[no-any-return]
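
The new AmazonCohereChatConfig is also exported from the package root (see the litellm/__init__.py hunk above), so its OpenAI-to-Cohere parameter mapping can be exercised directly. A minimal sketch, assuming this branch is installed:

```
import litellm

config = litellm.AmazonCohereChatConfig()
print(config.get_supported_openai_params())

# "top_p" is renamed to "p" and a string "stop" becomes "stop_sequences": [...]
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 256, "top_p": 0.9, "stop": "###"},
    optional_params={},
)
print(optional_params)
```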

View file

@ -58,16 +58,25 @@ class AsyncHTTPHandler:
class HTTPHandler:
def __init__(
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
self,
timeout: Optional[httpx.Timeout] = None,
concurrent_limit=1000,
client: Optional[httpx.Client] = None,
):
# Create a client with a connection pool
self.client = httpx.Client(
timeout=timeout,
limits=httpx.Limits(
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
),
)
if timeout is None:
timeout = _DEFAULT_TIMEOUT
if client is None:
# Create a client with a connection pool
self.client = httpx.Client(
timeout=timeout,
limits=httpx.Limits(
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
),
)
else:
self.client = client
def close(self):
# Close the client when you're done with it
@ -82,11 +91,15 @@ class HTTPHandler:
def post(
self,
url: str,
data: Optional[dict] = None,
data: Optional[Union[dict, str]] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
stream: bool = False,
):
response = self.client.post(url, data=data, params=params, headers=headers)
req = self.client.build_request(
"POST", url, data=data, params=params, headers=headers # type: ignore
)
response = self.client.send(req, stream=stream)
return response
def __del__(self) -> None:
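
The HTTPHandler changes above exist to support the Bedrock path: the signed payload is a JSON string rather than a dict, and the event stream must be read without pre-loading the body. A small sketch of the new post() usage, with a hypothetical URL:

```
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()  # falls back to the default timeout when none is given
resp = client.post(
    url="https://example.com/invoke-with-response-stream",  # hypothetical endpoint
    data='{"prompt": "Hi", "stream": true}',  # str bodies are now accepted
    headers={"Content-Type": "application/json"},
    stream=True,  # sent via build_request + send so the body can be iterated lazily
)
for chunk in resp.iter_bytes(chunk_size=1024):
    ...  # consume the raw event stream
```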

View file

@ -168,7 +168,7 @@ class PredibaseChatCompletion(BaseLLM):
logging_obj: litellm.utils.Logging,
optional_params: dict,
api_key: str,
data: dict,
data: Union[dict, str],
messages: list,
print_verbose,
encoding,
@ -185,9 +185,7 @@ class PredibaseChatCompletion(BaseLLM):
try:
completion_response = response.json()
except:
raise PredibaseError(
message=response.text, status_code=response.status_code
)
raise PredibaseError(message=response.text, status_code=422)
if "error" in completion_response:
raise PredibaseError(
message=str(completion_response["error"]),
@ -363,7 +361,7 @@ class PredibaseChatCompletion(BaseLLM):
},
)
## COMPLETION CALL
if acompletion is True:
if acompletion == True:
### ASYNC STREAMING
if stream == True:
return self.async_streaming(

View file

@ -76,6 +76,7 @@ from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
prompt_factory,
@ -105,7 +106,6 @@ from litellm.utils import (
)
####### ENVIRONMENT VARIABLES ###################
dotenv.load_dotenv() # Loading env variables using dotenv
openai_chat_completions = OpenAIChatCompletion()
openai_text_completions = OpenAITextCompletion()
anthropic_chat_completions = AnthropicChatCompletion()
@ -115,6 +115,7 @@ azure_text_completions = AzureTextCompletion()
huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
####### COMPLETION ENDPOINTS ################
@ -257,7 +258,7 @@ async def acompletion(
- If `stream` is True, the function returns an async generator that yields completion lines.
"""
loop = asyncio.get_event_loop()
custom_llm_provider = None
custom_llm_provider = kwargs.get("custom_llm_provider", None)
# Adjusted to use explicit arguments instead of *args and **kwargs
completion_kwargs = {
"model": model,
@ -289,9 +290,10 @@ async def acompletion(
"model_list": model_list,
"acompletion": True, # assuming this is a required parameter
}
_, custom_llm_provider, _, _ = get_llm_provider(
model=model, api_base=completion_kwargs.get("base_url", None)
)
if custom_llm_provider is None:
_, custom_llm_provider, _, _ = get_llm_provider(
model=model, api_base=completion_kwargs.get("base_url", None)
)
try:
# Use a partial function to pass your keyword arguments
func = partial(completion, **completion_kwargs, **kwargs)
@ -300,9 +302,6 @@ async def acompletion(
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider(
model=model, api_base=kwargs.get("api_base", None)
)
if (
custom_llm_provider == "openai"
or custom_llm_provider == "azure"
@ -324,6 +323,7 @@ async def acompletion(
or custom_llm_provider == "sagemaker"
or custom_llm_provider == "anthropic"
or custom_llm_provider == "predibase"
or (custom_llm_provider == "bedrock" and "cohere" in model)
or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
@ -1976,41 +1976,59 @@ def completion(
elif custom_llm_provider == "bedrock":
# boto3 reads keys from .env
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
response = bedrock.completion(
model=model,
messages=messages,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and not isinstance(response, CustomStreamWrapper)
):
# don't try to access stream object,
if "ai21" in model:
response = CustomStreamWrapper(
response,
model,
custom_llm_provider="bedrock",
logging_obj=logging,
)
else:
response = CustomStreamWrapper(
iter(response),
model,
custom_llm_provider="bedrock",
logging_obj=logging,
)
if "cohere" in model:
response = bedrock_chat_completion.completion(
model=model,
messages=messages,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
acompletion=acompletion,
)
else:
response = bedrock.completion(
model=model,
messages=messages,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and not isinstance(response, CustomStreamWrapper)
):
# don't try to access stream object,
if "ai21" in model:
response = CustomStreamWrapper(
response,
model,
custom_llm_provider="bedrock",
logging_obj=logging,
)
else:
response = CustomStreamWrapper(
iter(response),
model,
custom_llm_provider="bedrock",
logging_obj=logging,
)
if optional_params.get("stream", False):
## LOGGING
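
With this routing, Bedrock Cohere models are handled by bedrock_chat_completion (the new BedrockLLM client) while all other Bedrock models keep the existing boto3 path. A minimal sketch of the synchronous streaming call covered by the new streaming test, assuming AWS credentials in the environment:

```
import litellm

response = litellm.completion(
    model="bedrock/cohere.command-r-plus-v1:0",
    messages=[{"role": "user", "content": "Hey! how's it going?"}],
    max_tokens=10,
    stream=True,
)
for chunk in response:
    # OpenAI-style delta chunks produced by CustomStreamWrapper
    print(chunk.choices[0].delta.content or "", end="")
```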

View file

@ -1,10 +1,7 @@
from litellm.proxy._types import UserAPIKeyAuth, GenerateKeyRequest
from fastapi import Request
from dotenv import load_dotenv
import os
load_dotenv()
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try:

View file

@ -1507,7 +1507,6 @@ class Router:
return response
except Exception as e:
original_exception = e
"""
Retry Logic

View file

@ -8,8 +8,6 @@
import dotenv, os, requests, random # type: ignore
from typing import Optional
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger

View file

@ -1,12 +1,11 @@
#### What this does ####
# picks based on response time (for streaming, this is time to first token)
from pydantic import BaseModel, Extra, Field, root_validator
import dotenv, os, requests, random # type: ignore
import os, requests, random # type: ignore
from typing import Optional, Union, List, Dict
from datetime import datetime, timedelta
import random
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger

View file

@ -5,8 +5,6 @@ import dotenv, os, requests, random # type: ignore
from typing import Optional, Union, List, Dict
from datetime import datetime, timedelta
import random
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger

View file

@ -4,8 +4,6 @@
import dotenv, os, requests, random
from typing import Optional, Union, List, Dict
from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
from litellm import token_counter
from litellm.caching import DualCache

View file

@ -5,8 +5,6 @@ import dotenv, os, requests, random
from typing import Optional, Union, List, Dict
import datetime as datetime_og
from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback, asyncio, httpx
import litellm
from litellm import token_counter

View file

@ -2584,6 +2584,69 @@ def test_completion_chat_sagemaker_mistral():
# test_completion_chat_sagemaker_mistral()
def response_format_tests(response: litellm.ModelResponse):
assert isinstance(response.id, str)
assert response.id != ""
assert isinstance(response.object, str)
assert response.object != ""
assert isinstance(response.created, int)
assert isinstance(response.model, str)
assert response.model != ""
assert isinstance(response.choices, list)
assert len(response.choices) == 1
choice = response.choices[0]
assert isinstance(choice, litellm.Choices)
assert isinstance(choice.get("index"), int)
message = choice.get("message")
assert isinstance(message, litellm.Message)
assert isinstance(message.get("role"), str)
assert message.get("role") != ""
assert isinstance(message.get("content"), str)
assert message.get("content") != ""
assert choice.get("logprobs") is None
assert isinstance(choice.get("finish_reason"), str)
assert choice.get("finish_reason") != ""
assert isinstance(response.usage, litellm.Usage) # type: ignore
assert isinstance(response.usage.prompt_tokens, int) # type: ignore
assert isinstance(response.usage.completion_tokens, int) # type: ignore
assert isinstance(response.usage.total_tokens, int) # type: ignore
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_completion_bedrock_command_r(sync_mode):
litellm.set_verbose = True
if sync_mode:
response = completion(
model="bedrock/cohere.command-r-plus-v1:0",
messages=[{"role": "user", "content": "Hey! how's it going?"}],
)
assert isinstance(response, litellm.ModelResponse)
response_format_tests(response=response)
else:
response = await litellm.acompletion(
model="bedrock/cohere.command-r-plus-v1:0",
messages=[{"role": "user", "content": "Hey! how's it going?"}],
)
assert isinstance(response, litellm.ModelResponse)
print(f"response: {response}")
response_format_tests(response=response)
print(f"response: {response}")
def test_completion_bedrock_titan_null_response():
try:
response = completion(

View file

@ -132,12 +132,15 @@ def test_post_call_rule_streaming():
)
def test_post_call_processing_error_async_response():
response = asyncio.run(
acompletion(
@pytest.mark.asyncio
async def test_post_call_processing_error_async_response():
try:
response = await acompletion(
model="command-nightly", # Just used as an example
messages=[{"content": "Hello, how are you?", "role": "user"}],
api_base="https://openai-proxy.berriai.repl.co", # Just used as an example
custom_llm_provider="openai",
)
)
pytest.fail("This call should have failed")
except Exception as e:
pass

View file

@ -983,6 +983,64 @@ def test_vertex_ai_stream():
# pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize("sync_mode", [True])
@pytest.mark.asyncio
async def test_bedrock_cohere_command_r_streaming(sync_mode):
try:
litellm.set_verbose = True
if sync_mode:
final_chunk: Optional[litellm.ModelResponse] = None
response: litellm.CustomStreamWrapper = completion( # type: ignore
model="bedrock/cohere.command-r-plus-v1:0",
messages=messages,
max_tokens=10, # type: ignore
stream=True,
)
complete_response = ""
# Add any assertions here to check the response
has_finish_reason = False
for idx, chunk in enumerate(response):
final_chunk = chunk
chunk, finished = streaming_format_tests(idx, chunk)
if finished:
has_finish_reason = True
break
complete_response += chunk
if has_finish_reason == False:
raise Exception("finish reason not set")
if complete_response.strip() == "":
raise Exception("Empty response received")
else:
response: litellm.CustomStreamWrapper = await litellm.acompletion( # type: ignore
model="bedrock/cohere.command-r-plus-v1:0",
messages=messages,
max_tokens=100, # type: ignore
stream=True,
)
complete_response = ""
# Add any assertions here to check the response
has_finish_reason = False
idx = 0
final_chunk: Optional[litellm.ModelResponse] = None
async for chunk in response:
final_chunk = chunk
chunk, finished = streaming_format_tests(idx, chunk)
if finished:
has_finish_reason = True
break
complete_response += chunk
idx += 1
if has_finish_reason == False:
raise Exception("finish reason not set")
if complete_response.strip() == "":
raise Exception("Empty response received")
print(f"completion_response: {complete_response}\n\nFinalChunk: {final_chunk}")
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_bedrock_claude_3_streaming():
try:
litellm.set_verbose = True

View file

@ -0,0 +1,63 @@
from typing import TypedDict, Any, Union, Optional
import json
from typing_extensions import (
Self,
Protocol,
TypeGuard,
override,
get_origin,
runtime_checkable,
Required,
)
class GenericStreamingChunk(TypedDict):
text: Required[str]
is_finished: Required[bool]
finish_reason: Required[str]
class Document(TypedDict):
title: str
snippet: str
class ServerSentEvent:
def __init__(
self,
*,
event: Optional[str] = None,
data: Optional[str] = None,
id: Optional[str] = None,
retry: Optional[int] = None,
) -> None:
if data is None:
data = ""
self._id = id
self._data = data
self._event = event or None
self._retry = retry
@property
def event(self) -> Optional[str]:
return self._event
@property
def id(self) -> Optional[str]:
return self._id
@property
def retry(self) -> Optional[int]:
return self._retry
@property
def data(self) -> str:
return self._data
def json(self) -> Any:
return json.loads(self.data)
@override
def __repr__(self) -> str:
return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"

View file

@ -132,7 +132,6 @@ MAX_THREADS = 100
# Create a ThreadPoolExecutor
executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
dotenv.load_dotenv() # Loading env variables using dotenv
sentry_sdk_instance = None
capture_exception = None
add_breadcrumb = None
@ -10474,6 +10473,12 @@ class CustomStreamWrapper:
raise e
def handle_bedrock_stream(self, chunk):
if "cohere" in self.model:
return {
"text": chunk["text"],
"is_finished": chunk["is_finished"],
"finish_reason": chunk["finish_reason"],
}
if hasattr(chunk, "get"):
chunk = chunk.get("chunk")
chunk_data = json.loads(chunk.get("bytes").decode())
@ -11322,6 +11327,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "gemini"
or self.custom_llm_provider == "cached_response"
or self.custom_llm_provider == "predibase"
or (self.custom_llm_provider == "bedrock" and "cohere" in self.model)
or self.custom_llm_provider in litellm.openai_compatible_endpoints
):
async for chunk in self.completion_stream: