forked from phoenix/litellm-mirror
Merge pull request #3586 from BerriAI/litellm_bedrock_command_r_support
feat(bedrock_httpx.py): Make Bedrock-Cohere calls Async
This commit is contained in:
commit
94c9df969e
39 changed files with 1222 additions and 195 deletions
|
@ -10,7 +10,6 @@ from litellm.caching import DualCache
|
|||
|
||||
from typing import Literal, Union
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
@ -19,8 +18,6 @@ import traceback
|
|||
|
||||
import dotenv, os
|
||||
import requests
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -743,6 +743,7 @@ from .llms.sagemaker import SagemakerConfig
|
|||
from .llms.ollama import OllamaConfig
|
||||
from .llms.ollama_chat import OllamaChatConfig
|
||||
from .llms.maritalk import MaritTalkConfig
|
||||
from .llms.bedrock_httpx import AmazonCohereChatConfig
|
||||
from .llms.bedrock import (
|
||||
AmazonTitanConfig,
|
||||
AmazonAI21Config,
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
#### What this does ####
|
||||
# On success + failure, log events to aispend.io
|
||||
import dotenv, os
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime
|
||||
|
||||
|
|
|
@ -8,8 +8,6 @@ from litellm.proxy._types import UserAPIKeyAuth
|
|||
from litellm.caching import DualCache
|
||||
|
||||
from typing import Literal, Union
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
@ -18,8 +16,6 @@ import traceback
|
|||
|
||||
import dotenv, os
|
||||
import requests
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -6,8 +6,6 @@ from litellm.proxy._types import UserAPIKeyAuth
|
|||
from litellm.caching import DualCache
|
||||
|
||||
from typing import Literal, Union, Optional
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
import litellm
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
#### What this does ####
|
||||
# On success, logs events to Langfuse
|
||||
import dotenv, os
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import os
|
||||
import copy
|
||||
import traceback
|
||||
from packaging.version import Version
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
import dotenv, os # type: ignore
|
||||
import requests # type: ignore
|
||||
from datetime import datetime
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import asyncio
|
||||
import types
|
||||
|
|
|
@ -2,13 +2,10 @@
|
|||
# On success + failure, log events to lunary.ai
|
||||
from datetime import datetime, timezone
|
||||
import traceback
|
||||
import dotenv
|
||||
import importlib
|
||||
|
||||
import packaging
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
# convert to {completion: xx, tokens: xx}
|
||||
def parse_usage(usage):
|
||||
|
@ -79,14 +76,16 @@ class LunaryLogger:
|
|||
version = importlib.metadata.version("lunary")
|
||||
# if version < 0.1.43 then raise ImportError
|
||||
if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
|
||||
print(
|
||||
print( # noqa
|
||||
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
|
||||
)
|
||||
raise ImportError
|
||||
|
||||
self.lunary_client = lunary
|
||||
except ImportError:
|
||||
print("Lunary not installed. Please install it using 'pip install lunary'")
|
||||
print( # noqa
|
||||
"Lunary not installed. Please install it using 'pip install lunary'"
|
||||
) # noqa
|
||||
raise ImportError
|
||||
|
||||
def log_event(
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
|
||||
import dotenv, os, json
|
||||
import litellm
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
|
|
|
@ -4,8 +4,6 @@
|
|||
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -5,8 +5,6 @@
|
|||
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
from pydantic import BaseModel
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
#### What this does ####
|
||||
# On success + failure, log events to Supabase
|
||||
|
||||
import dotenv, os
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import os
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -2,8 +2,6 @@
|
|||
# Class for sending Slack Alerts #
|
||||
import dotenv, os
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
import litellm, threading
|
||||
from typing import List, Literal, Any, Union, Optional, Dict
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm
|
||||
|
|
|
@ -21,11 +21,11 @@ try:
|
|||
# contains a (known) object attribute
|
||||
object: Literal["chat.completion", "edit", "text_completion"]
|
||||
|
||||
def __getitem__(self, key: K) -> V:
|
||||
... # pragma: no cover
|
||||
def __getitem__(self, key: K) -> V: ... # noqa
|
||||
|
||||
def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
|
||||
... # pragma: no cover
|
||||
def get( # noqa
|
||||
self, key: K, default: Optional[V] = None
|
||||
) -> Optional[V]: ... # pragma: no cover
|
||||
|
||||
class OpenAIRequestResponseResolver:
|
||||
def __call__(
|
||||
|
@ -173,12 +173,11 @@ except:
|
|||
|
||||
#### What this does ####
|
||||
# On success, logs events to Langfuse
|
||||
import dotenv, os
|
||||
import os
|
||||
import requests
|
||||
import requests
|
||||
from datetime import datetime
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import json
|
|||
from enum import Enum
|
||||
import requests, copy # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional, List
|
||||
from typing import Callable, Optional, List, Union
|
||||
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
||||
import litellm
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
|
@ -151,19 +151,135 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def process_streaming_response(
|
||||
self,
|
||||
model: str,
|
||||
response: Union[requests.Response, httpx.Response],
|
||||
model_response: ModelResponse,
|
||||
stream: bool,
|
||||
logging_obj: litellm.utils.Logging,
|
||||
optional_params: dict,
|
||||
api_key: str,
|
||||
data: Union[dict, str],
|
||||
messages: List,
|
||||
print_verbose,
|
||||
encoding,
|
||||
) -> CustomStreamWrapper:
|
||||
"""
|
||||
Return stream object for tool-calling + streaming
|
||||
"""
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = response.json()
|
||||
except:
|
||||
raise AnthropicError(
|
||||
message=response.text, status_code=response.status_code
|
||||
)
|
||||
text_content = ""
|
||||
tool_calls = []
|
||||
for content in completion_response["content"]:
|
||||
if content["type"] == "text":
|
||||
text_content += content["text"]
|
||||
## TOOL CALLING
|
||||
elif content["type"] == "tool_use":
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": content["id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": content["name"],
|
||||
"arguments": json.dumps(content["input"]),
|
||||
},
|
||||
}
|
||||
)
|
||||
if "error" in completion_response:
|
||||
raise AnthropicError(
|
||||
message=str(completion_response["error"]),
|
||||
status_code=response.status_code,
|
||||
)
|
||||
_message = litellm.Message(
|
||||
tool_calls=tool_calls,
|
||||
content=text_content or None,
|
||||
)
|
||||
model_response.choices[0].message = _message # type: ignore
|
||||
model_response._hidden_params["original_response"] = completion_response[
|
||||
"content"
|
||||
] # allow user to access raw anthropic tool calling response
|
||||
|
||||
model_response.choices[0].finish_reason = map_finish_reason(
|
||||
completion_response["stop_reason"]
|
||||
)
|
||||
|
||||
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
|
||||
# return an iterator
|
||||
streaming_model_response = ModelResponse(stream=True)
|
||||
streaming_model_response.choices[0].finish_reason = model_response.choices[ # type: ignore
|
||||
0
|
||||
].finish_reason
|
||||
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
|
||||
streaming_choice = litellm.utils.StreamingChoices()
|
||||
streaming_choice.index = model_response.choices[0].index
|
||||
_tool_calls = []
|
||||
print_verbose(
|
||||
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
|
||||
)
|
||||
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
|
||||
if isinstance(model_response.choices[0], litellm.Choices):
|
||||
if getattr(
|
||||
model_response.choices[0].message, "tool_calls", None
|
||||
) is not None and isinstance(
|
||||
model_response.choices[0].message.tool_calls, list
|
||||
):
|
||||
for tool_call in model_response.choices[0].message.tool_calls:
|
||||
_tool_call = {**tool_call.dict(), "index": 0}
|
||||
_tool_calls.append(_tool_call)
|
||||
delta_obj = litellm.utils.Delta(
|
||||
content=getattr(model_response.choices[0].message, "content", None),
|
||||
role=model_response.choices[0].message.role,
|
||||
tool_calls=_tool_calls,
|
||||
)
|
||||
streaming_choice.delta = delta_obj
|
||||
streaming_model_response.choices = [streaming_choice]
|
||||
completion_stream = ModelResponseIterator(
|
||||
model_response=streaming_model_response
|
||||
)
|
||||
print_verbose(
|
||||
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
|
||||
)
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="cached_response",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
else:
|
||||
raise AnthropicError(
|
||||
status_code=422,
|
||||
message="Unprocessable response object - {}".format(response.text),
|
||||
)
|
||||
|
||||
def process_response(
|
||||
self,
|
||||
model,
|
||||
response,
|
||||
model_response,
|
||||
_is_function_call,
|
||||
stream,
|
||||
logging_obj,
|
||||
api_key,
|
||||
data,
|
||||
messages,
|
||||
model: str,
|
||||
response: Union[requests.Response, httpx.Response],
|
||||
model_response: ModelResponse,
|
||||
stream: bool,
|
||||
logging_obj: litellm.utils.Logging,
|
||||
optional_params: dict,
|
||||
api_key: str,
|
||||
data: Union[dict, str],
|
||||
messages: List,
|
||||
print_verbose,
|
||||
):
|
||||
encoding,
|
||||
) -> ModelResponse:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
|
@ -216,51 +332,6 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
completion_response["stop_reason"]
|
||||
)
|
||||
|
||||
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
|
||||
if _is_function_call and stream:
|
||||
print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
|
||||
# return an iterator
|
||||
streaming_model_response = ModelResponse(stream=True)
|
||||
streaming_model_response.choices[0].finish_reason = model_response.choices[
|
||||
0
|
||||
].finish_reason
|
||||
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
|
||||
streaming_choice = litellm.utils.StreamingChoices()
|
||||
streaming_choice.index = model_response.choices[0].index
|
||||
_tool_calls = []
|
||||
print_verbose(
|
||||
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
|
||||
)
|
||||
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
|
||||
if isinstance(model_response.choices[0], litellm.Choices):
|
||||
if getattr(
|
||||
model_response.choices[0].message, "tool_calls", None
|
||||
) is not None and isinstance(
|
||||
model_response.choices[0].message.tool_calls, list
|
||||
):
|
||||
for tool_call in model_response.choices[0].message.tool_calls:
|
||||
_tool_call = {**tool_call.dict(), "index": 0}
|
||||
_tool_calls.append(_tool_call)
|
||||
delta_obj = litellm.utils.Delta(
|
||||
content=getattr(model_response.choices[0].message, "content", None),
|
||||
role=model_response.choices[0].message.role,
|
||||
tool_calls=_tool_calls,
|
||||
)
|
||||
streaming_choice.delta = delta_obj
|
||||
streaming_model_response.choices = [streaming_choice]
|
||||
completion_stream = ModelResponseIterator(
|
||||
model_response=streaming_model_response
|
||||
)
|
||||
print_verbose(
|
||||
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
|
||||
)
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="cached_response",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
## CALCULATING USAGE
|
||||
prompt_tokens = completion_response["usage"]["input_tokens"]
|
||||
completion_tokens = completion_response["usage"]["output_tokens"]
|
||||
|
@ -273,7 +344,7 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
model_response.usage = usage
|
||||
setattr(model_response, "usage", usage) # type: ignore
|
||||
return model_response
|
||||
|
||||
async def acompletion_stream_function(
|
||||
|
@ -289,7 +360,7 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
logging_obj,
|
||||
stream,
|
||||
_is_function_call,
|
||||
data=None,
|
||||
data: dict,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
|
@ -331,29 +402,44 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
logging_obj,
|
||||
stream,
|
||||
_is_function_call,
|
||||
data=None,
|
||||
optional_params=None,
|
||||
data: dict,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
):
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
self.async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
)
|
||||
response = await self.async_handler.post(
|
||||
api_base, headers=headers, data=json.dumps(data)
|
||||
)
|
||||
if stream and _is_function_call:
|
||||
return self.process_streaming_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
_is_function_call=_is_function_call,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
def completion(
|
||||
|
@ -367,7 +453,7 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
|
@ -526,17 +612,33 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
raise AnthropicError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
if stream and _is_function_call:
|
||||
return self.process_streaming_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
_is_function_call=_is_function_call,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
def embedding(self):
|
||||
|
|
|
@ -100,7 +100,7 @@ class AnthropicTextCompletion(BaseLLM):
|
|||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def process_response(
|
||||
def _process_response(
|
||||
self, model_response: ModelResponse, response, encoding, prompt: str, model: str
|
||||
):
|
||||
## RESPONSE OBJECT
|
||||
|
@ -171,7 +171,7 @@ class AnthropicTextCompletion(BaseLLM):
|
|||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
response = self.process_response(
|
||||
response = self._process_response(
|
||||
model_response=model_response,
|
||||
response=response,
|
||||
encoding=encoding,
|
||||
|
@ -330,7 +330,7 @@ class AnthropicTextCompletion(BaseLLM):
|
|||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
|
||||
response = self.process_response(
|
||||
response = self._process_response(
|
||||
model_response=model_response,
|
||||
response=response,
|
||||
encoding=encoding,
|
||||
|
|
|
@ -1,12 +1,32 @@
|
|||
## This is a template base class to be used for adding new LLM providers via API calls
|
||||
import litellm
|
||||
import httpx
|
||||
from typing import Optional
|
||||
import httpx, requests
|
||||
from typing import Optional, Union
|
||||
from litellm.utils import Logging
|
||||
|
||||
|
||||
class BaseLLM:
|
||||
_client_session: Optional[httpx.Client] = None
|
||||
|
||||
def process_response(
|
||||
self,
|
||||
model: str,
|
||||
response: Union[requests.Response, httpx.Response],
|
||||
model_response: litellm.utils.ModelResponse,
|
||||
stream: bool,
|
||||
logging_obj: Logging,
|
||||
optional_params: dict,
|
||||
api_key: str,
|
||||
data: Union[dict, str],
|
||||
messages: list,
|
||||
print_verbose,
|
||||
encoding,
|
||||
) -> litellm.utils.ModelResponse:
|
||||
"""
|
||||
Helper function to process the response across sync + async completion calls
|
||||
"""
|
||||
return model_response
|
||||
|
||||
def create_client_session(self):
|
||||
if litellm.client_session:
|
||||
_client_session = litellm.client_session
|
||||
|
|
733
litellm/llms/bedrock_httpx.py
Normal file
733
litellm/llms/bedrock_httpx.py
Normal file
|
@ -0,0 +1,733 @@
|
|||
# What is this?
|
||||
## Initial implementation of calling bedrock via httpx client (allows for async calls).
|
||||
## V0 - just covers cohere command-r support
|
||||
|
||||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests, copy # type: ignore
|
||||
import time
|
||||
from typing import (
|
||||
Callable,
|
||||
Optional,
|
||||
List,
|
||||
Literal,
|
||||
Union,
|
||||
Any,
|
||||
TypedDict,
|
||||
Tuple,
|
||||
Iterator,
|
||||
AsyncIterator,
|
||||
)
|
||||
from litellm.utils import (
|
||||
ModelResponse,
|
||||
Usage,
|
||||
map_finish_reason,
|
||||
CustomStreamWrapper,
|
||||
Message,
|
||||
Choices,
|
||||
get_secret,
|
||||
Logging,
|
||||
)
|
||||
import litellm
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from .base import BaseLLM
|
||||
import httpx # type: ignore
|
||||
from .bedrock import BedrockError, convert_messages_to_prompt
|
||||
from litellm.types.llms.bedrock import *
|
||||
|
||||
|
||||
class AmazonCohereChatConfig:
|
||||
"""
|
||||
Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
|
||||
"""
|
||||
|
||||
documents: Optional[List[Document]] = None
|
||||
search_queries_only: Optional[bool] = None
|
||||
preamble: Optional[str] = None
|
||||
max_tokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
p: Optional[float] = None
|
||||
k: Optional[float] = None
|
||||
prompt_truncation: Optional[str] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
seed: Optional[int] = None
|
||||
return_prompt: Optional[bool] = None
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
raw_prompting: Optional[bool] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
documents: Optional[List[Document]] = None,
|
||||
search_queries_only: Optional[bool] = None,
|
||||
preamble: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
p: Optional[float] = None,
|
||||
k: Optional[float] = None,
|
||||
prompt_truncation: Optional[str] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
seed: Optional[int] = None,
|
||||
return_prompt: Optional[bool] = None,
|
||||
stop_sequences: Optional[str] = None,
|
||||
raw_prompting: Optional[bool] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self) -> List[str]:
|
||||
return [
|
||||
"max_tokens",
|
||||
"stream",
|
||||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"seed",
|
||||
"stop",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self, non_default_params: dict, optional_params: dict
|
||||
) -> dict:
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
if param == "stop":
|
||||
if isinstance(value, str):
|
||||
value = [value]
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["p"] = value
|
||||
if param == "frequency_penalty":
|
||||
optional_params["frequency_penalty"] = value
|
||||
if param == "presence_penalty":
|
||||
optional_params["presence_penalty"] = value
|
||||
if "seed":
|
||||
optional_params["seed"] = value
|
||||
return optional_params
|
||||
|
||||
|
||||
class BedrockLLM(BaseLLM):
|
||||
"""
|
||||
Example call
|
||||
|
||||
```
|
||||
curl --location --request POST 'https://bedrock-runtime.{aws_region_name}.amazonaws.com/model/{bedrock_model_name}/invoke' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Accept: application/json' \
|
||||
--user "$AWS_ACCESS_KEY_ID":"$AWS_SECRET_ACCESS_KEY" \
|
||||
--aws-sigv4 "aws:amz:us-east-1:bedrock" \
|
||||
--data-raw '{
|
||||
"prompt": "Hi",
|
||||
"temperature": 0,
|
||||
"p": 0.9,
|
||||
"max_tokens": 4096
|
||||
}'
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def convert_messages_to_prompt(
|
||||
self, model, messages, provider, custom_prompt_dict
|
||||
) -> Tuple[str, Optional[list]]:
|
||||
# handle anthropic prompts and amazon titan prompts
|
||||
prompt = ""
|
||||
chat_history: Optional[list] = None
|
||||
if provider == "anthropic" or provider == "amazon":
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||
)
|
||||
elif provider == "mistral":
|
||||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||
)
|
||||
elif provider == "meta":
|
||||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||
)
|
||||
elif provider == "cohere":
|
||||
prompt, chat_history = cohere_message_pt(messages=messages)
|
||||
else:
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
if "role" in message:
|
||||
if message["role"] == "user":
|
||||
prompt += f"{message['content']}"
|
||||
else:
|
||||
prompt += f"{message['content']}"
|
||||
else:
|
||||
prompt += f"{message['content']}"
|
||||
return prompt, chat_history # type: ignore
|
||||
|
||||
def get_credentials(
|
||||
self,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_secret_access_key: Optional[str] = None,
|
||||
aws_region_name: Optional[str] = None,
|
||||
aws_session_name: Optional[str] = None,
|
||||
aws_profile_name: Optional[str] = None,
|
||||
aws_role_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Return a boto3.Credentials object
|
||||
"""
|
||||
import boto3
|
||||
|
||||
## CHECK IS 'os.environ/' passed in
|
||||
params_to_check: List[Optional[str]] = [
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_region_name,
|
||||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
]
|
||||
|
||||
# Iterate over parameters and update if needed
|
||||
for i, param in enumerate(params_to_check):
|
||||
if param and param.startswith("os.environ/"):
|
||||
_v = get_secret(param)
|
||||
if _v is not None and isinstance(_v, str):
|
||||
params_to_check[i] = _v
|
||||
# Assign updated values back to parameters
|
||||
(
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_region_name,
|
||||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
) = params_to_check
|
||||
|
||||
### CHECK STS ###
|
||||
if aws_role_name is not None and aws_session_name is not None:
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
|
||||
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
|
||||
)
|
||||
|
||||
sts_response = sts_client.assume_role(
|
||||
RoleArn=aws_role_name, RoleSessionName=aws_session_name
|
||||
)
|
||||
|
||||
return sts_response["Credentials"]
|
||||
elif aws_profile_name is not None: ### CHECK SESSION ###
|
||||
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
||||
client = boto3.Session(profile_name=aws_profile_name)
|
||||
|
||||
return client.get_credentials()
|
||||
else:
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
|
||||
return session.get_credentials()
|
||||
|
||||
def process_response(
|
||||
self,
|
||||
model: str,
|
||||
response: Union[requests.Response, httpx.Response],
|
||||
model_response: ModelResponse,
|
||||
stream: bool,
|
||||
logging_obj: Logging,
|
||||
optional_params: dict,
|
||||
api_key: str,
|
||||
data: Union[dict, str],
|
||||
messages: List,
|
||||
print_verbose,
|
||||
encoding,
|
||||
) -> ModelResponse:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = response.json()
|
||||
except:
|
||||
raise BedrockError(message=response.text, status_code=422)
|
||||
|
||||
try:
|
||||
model_response.choices[0].message.content = completion_response["text"] # type: ignore
|
||||
except Exception as e:
|
||||
raise BedrockError(message=response.text, status_code=422)
|
||||
|
||||
## CALCULATING USAGE - bedrock returns usage in the headers
|
||||
prompt_tokens = int(
|
||||
response.headers.get(
|
||||
"x-amzn-bedrock-input-token-count",
|
||||
len(encoding.encode("".join(m.get("content", "") for m in messages))),
|
||||
)
|
||||
)
|
||||
completion_tokens = int(
|
||||
response.headers.get(
|
||||
"x-amzn-bedrock-output-token-count",
|
||||
len(
|
||||
encoding.encode(
|
||||
model_response.choices[0].message.content, # type: ignore
|
||||
disallowed_special=(),
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
|
||||
return model_response
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion: bool,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
try:
|
||||
import boto3
|
||||
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError as e:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
|
||||
## SETUP ##
|
||||
stream = optional_params.pop("stream", None)
|
||||
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||
|
||||
### SET REGION NAME ###
|
||||
if aws_region_name is None:
|
||||
# check env #
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
|
||||
if litellm_aws_region_name is not None and isinstance(
|
||||
litellm_aws_region_name, str
|
||||
):
|
||||
aws_region_name = litellm_aws_region_name
|
||||
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
if standard_aws_region_name is not None and isinstance(
|
||||
standard_aws_region_name, str
|
||||
):
|
||||
aws_region_name = standard_aws_region_name
|
||||
|
||||
if aws_region_name is None:
|
||||
aws_region_name = "us-west-2"
|
||||
|
||||
credentials: Credentials = self.get_credentials(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_session_name=aws_session_name,
|
||||
aws_profile_name=aws_profile_name,
|
||||
aws_role_name=aws_role_name,
|
||||
)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
endpoint_url = ""
|
||||
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||
if aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||
aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = aws_bedrock_runtime_endpoint
|
||||
elif env_aws_bedrock_runtime_endpoint and isinstance(
|
||||
env_aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||
else:
|
||||
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||
|
||||
if stream is not None and stream == True:
|
||||
endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream"
|
||||
else:
|
||||
endpoint_url = f"{endpoint_url}/model/{model}/invoke"
|
||||
|
||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||
|
||||
provider = model.split(".")[0]
|
||||
prompt, chat_history = self.convert_messages_to_prompt(
|
||||
model, messages, provider, custom_prompt_dict
|
||||
)
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
|
||||
if provider == "cohere":
|
||||
if model.startswith("cohere.command-r"):
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonCohereChatConfig().get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
_data = {"message": prompt, **inference_params}
|
||||
if chat_history is not None:
|
||||
_data["chat_history"] = chat_history
|
||||
data = json.dumps(_data)
|
||||
else:
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonCohereConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
if stream == True:
|
||||
inference_params["stream"] = (
|
||||
True # cohere requires stream = True in inference params
|
||||
)
|
||||
data = json.dumps({"prompt": prompt, **inference_params})
|
||||
else:
|
||||
raise Exception("UNSUPPORTED PROVIDER")
|
||||
|
||||
## COMPLETION CALL
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
request = AWSRequest(
|
||||
method="POST", url=endpoint_url, data=data, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
prepped = request.prepare()
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": prepped.url,
|
||||
"headers": prepped.headers,
|
||||
},
|
||||
)
|
||||
|
||||
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||
if acompletion:
|
||||
if isinstance(client, HTTPHandler):
|
||||
client = None
|
||||
if stream:
|
||||
return self.async_streaming(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=prepped.url,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=True,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=prepped.headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
) # type: ignore
|
||||
### ASYNC COMPLETION
|
||||
return self.async_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=prepped.url,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=False,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=prepped.headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
) # type: ignore
|
||||
|
||||
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
self.client = HTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
self.client = client
|
||||
if stream is not None and stream == True:
|
||||
response = self.client.post(
|
||||
url=prepped.url,
|
||||
headers=prepped.headers, # type: ignore
|
||||
data=data,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BedrockError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
decoder = AWSEventStreamDecoder()
|
||||
|
||||
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streaming_response
|
||||
|
||||
response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=response.text)
|
||||
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
api_key="",
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
async def async_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
data: str,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
encoding,
|
||||
logging_obj,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> ModelResponse:
|
||||
if client is None:
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
self.client = AsyncHTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
self.client = client # type: ignore
|
||||
|
||||
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
api_key="",
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
async def async_streaming(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
data: str,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
encoding,
|
||||
logging_obj,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> CustomStreamWrapper:
|
||||
if client is None:
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
self.client = AsyncHTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
self.client = client # type: ignore
|
||||
|
||||
response = await self.client.post(api_base, headers=headers, data=data, stream=True) # type: ignore
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BedrockError(status_code=response.status_code, message=response.text)
|
||||
|
||||
decoder = AWSEventStreamDecoder()
|
||||
|
||||
completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streaming_response
|
||||
|
||||
def embedding(self, *args, **kwargs):
|
||||
return super().embedding(*args, **kwargs)
|
||||
|
||||
|
||||
def get_response_stream_shape():
|
||||
from botocore.model import ServiceModel
|
||||
from botocore.loaders import Loader
|
||||
|
||||
loader = Loader()
|
||||
bedrock_service_dict = loader.load_service_model("bedrock-runtime", "service-2")
|
||||
bedrock_service_model = ServiceModel(bedrock_service_dict)
|
||||
return bedrock_service_model.shape_for("ResponseStream")
|
||||
|
||||
|
||||
class AWSEventStreamDecoder:
|
||||
def __init__(self) -> None:
|
||||
from botocore.parsers import EventStreamJSONParser
|
||||
|
||||
self.parser = EventStreamJSONParser()
|
||||
|
||||
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
|
||||
"""Given an iterator that yields lines, iterate over it & yield every event encountered"""
|
||||
from botocore.eventstream import EventStreamBuffer
|
||||
|
||||
event_stream_buffer = EventStreamBuffer()
|
||||
for chunk in iterator:
|
||||
event_stream_buffer.add_data(chunk)
|
||||
for event in event_stream_buffer:
|
||||
message = self._parse_message_from_event(event)
|
||||
if message:
|
||||
# sse_event = ServerSentEvent(data=message, event="completion")
|
||||
_data = json.loads(message)
|
||||
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
|
||||
text=_data.get("text", ""),
|
||||
is_finished=_data.get("is_finished", False),
|
||||
finish_reason=_data.get("finish_reason", ""),
|
||||
)
|
||||
yield streaming_chunk
|
||||
|
||||
async def aiter_bytes(
|
||||
self, iterator: AsyncIterator[bytes]
|
||||
) -> AsyncIterator[GenericStreamingChunk]:
|
||||
"""Given an async iterator that yields lines, iterate over it & yield every event encountered"""
|
||||
from botocore.eventstream import EventStreamBuffer
|
||||
|
||||
event_stream_buffer = EventStreamBuffer()
|
||||
async for chunk in iterator:
|
||||
event_stream_buffer.add_data(chunk)
|
||||
for event in event_stream_buffer:
|
||||
message = self._parse_message_from_event(event)
|
||||
if message:
|
||||
_data = json.loads(message)
|
||||
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
|
||||
text=_data.get("text", ""),
|
||||
is_finished=_data.get("is_finished", False),
|
||||
finish_reason=_data.get("finish_reason", ""),
|
||||
)
|
||||
yield streaming_chunk
|
||||
|
||||
def _parse_message_from_event(self, event) -> Optional[str]:
|
||||
response_dict = event.to_response_dict()
|
||||
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
|
||||
if response_dict["status_code"] != 200:
|
||||
raise ValueError(f"Bad response code, expected 200: {response_dict}")
|
||||
|
||||
chunk = parsed_response.get("chunk")
|
||||
if not chunk:
|
||||
return None
|
||||
|
||||
return chunk.get("bytes").decode() # type: ignore[no-any-return]
|
|
@ -58,16 +58,25 @@ class AsyncHTTPHandler:
|
|||
|
||||
class HTTPHandler:
|
||||
def __init__(
|
||||
self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
|
||||
self,
|
||||
timeout: Optional[httpx.Timeout] = None,
|
||||
concurrent_limit=1000,
|
||||
client: Optional[httpx.Client] = None,
|
||||
):
|
||||
# Create a client with a connection pool
|
||||
self.client = httpx.Client(
|
||||
timeout=timeout,
|
||||
limits=httpx.Limits(
|
||||
max_connections=concurrent_limit,
|
||||
max_keepalive_connections=concurrent_limit,
|
||||
),
|
||||
)
|
||||
if timeout is None:
|
||||
timeout = _DEFAULT_TIMEOUT
|
||||
|
||||
if client is None:
|
||||
# Create a client with a connection pool
|
||||
self.client = httpx.Client(
|
||||
timeout=timeout,
|
||||
limits=httpx.Limits(
|
||||
max_connections=concurrent_limit,
|
||||
max_keepalive_connections=concurrent_limit,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.client = client
|
||||
|
||||
def close(self):
|
||||
# Close the client when you're done with it
|
||||
|
@ -82,11 +91,15 @@ class HTTPHandler:
|
|||
def post(
|
||||
self,
|
||||
url: str,
|
||||
data: Optional[dict] = None,
|
||||
data: Optional[Union[dict, str]] = None,
|
||||
params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
stream: bool = False,
|
||||
):
|
||||
response = self.client.post(url, data=data, params=params, headers=headers)
|
||||
req = self.client.build_request(
|
||||
"POST", url, data=data, params=params, headers=headers # type: ignore
|
||||
)
|
||||
response = self.client.send(req, stream=stream)
|
||||
return response
|
||||
|
||||
def __del__(self) -> None:
|
||||
|
|
|
@ -168,7 +168,7 @@ class PredibaseChatCompletion(BaseLLM):
|
|||
logging_obj: litellm.utils.Logging,
|
||||
optional_params: dict,
|
||||
api_key: str,
|
||||
data: dict,
|
||||
data: Union[dict, str],
|
||||
messages: list,
|
||||
print_verbose,
|
||||
encoding,
|
||||
|
@ -185,9 +185,7 @@ class PredibaseChatCompletion(BaseLLM):
|
|||
try:
|
||||
completion_response = response.json()
|
||||
except:
|
||||
raise PredibaseError(
|
||||
message=response.text, status_code=response.status_code
|
||||
)
|
||||
raise PredibaseError(message=response.text, status_code=422)
|
||||
if "error" in completion_response:
|
||||
raise PredibaseError(
|
||||
message=str(completion_response["error"]),
|
||||
|
@ -363,7 +361,7 @@ class PredibaseChatCompletion(BaseLLM):
|
|||
},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
if acompletion is True:
|
||||
if acompletion == True:
|
||||
### ASYNC STREAMING
|
||||
if stream == True:
|
||||
return self.async_streaming(
|
||||
|
|
102
litellm/main.py
102
litellm/main.py
|
@ -76,6 +76,7 @@ from .llms.anthropic import AnthropicChatCompletion
|
|||
from .llms.anthropic_text import AnthropicTextCompletion
|
||||
from .llms.huggingface_restapi import Huggingface
|
||||
from .llms.predibase import PredibaseChatCompletion
|
||||
from .llms.bedrock_httpx import BedrockLLM
|
||||
from .llms.triton import TritonChatCompletion
|
||||
from .llms.prompt_templates.factory import (
|
||||
prompt_factory,
|
||||
|
@ -105,7 +106,6 @@ from litellm.utils import (
|
|||
)
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
openai_chat_completions = OpenAIChatCompletion()
|
||||
openai_text_completions = OpenAITextCompletion()
|
||||
anthropic_chat_completions = AnthropicChatCompletion()
|
||||
|
@ -115,6 +115,7 @@ azure_text_completions = AzureTextCompletion()
|
|||
huggingface = Huggingface()
|
||||
predibase_chat_completions = PredibaseChatCompletion()
|
||||
triton_chat_completions = TritonChatCompletion()
|
||||
bedrock_chat_completion = BedrockLLM()
|
||||
####### COMPLETION ENDPOINTS ################
|
||||
|
||||
|
||||
|
@ -257,7 +258,7 @@ async def acompletion(
|
|||
- If `stream` is True, the function returns an async generator that yields completion lines.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
custom_llm_provider = None
|
||||
custom_llm_provider = kwargs.get("custom_llm_provider", None)
|
||||
# Adjusted to use explicit arguments instead of *args and **kwargs
|
||||
completion_kwargs = {
|
||||
"model": model,
|
||||
|
@ -289,9 +290,10 @@ async def acompletion(
|
|||
"model_list": model_list,
|
||||
"acompletion": True, # assuming this is a required parameter
|
||||
}
|
||||
_, custom_llm_provider, _, _ = get_llm_provider(
|
||||
model=model, api_base=completion_kwargs.get("base_url", None)
|
||||
)
|
||||
if custom_llm_provider is None:
|
||||
_, custom_llm_provider, _, _ = get_llm_provider(
|
||||
model=model, api_base=completion_kwargs.get("base_url", None)
|
||||
)
|
||||
try:
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(completion, **completion_kwargs, **kwargs)
|
||||
|
@ -300,9 +302,6 @@ async def acompletion(
|
|||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
|
||||
_, custom_llm_provider, _, _ = get_llm_provider(
|
||||
model=model, api_base=kwargs.get("api_base", None)
|
||||
)
|
||||
if (
|
||||
custom_llm_provider == "openai"
|
||||
or custom_llm_provider == "azure"
|
||||
|
@ -324,6 +323,7 @@ async def acompletion(
|
|||
or custom_llm_provider == "sagemaker"
|
||||
or custom_llm_provider == "anthropic"
|
||||
or custom_llm_provider == "predibase"
|
||||
or (custom_llm_provider == "bedrock" and "cohere" in model)
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
@ -1976,41 +1976,59 @@ def completion(
|
|||
elif custom_llm_provider == "bedrock":
|
||||
# boto3 reads keys from .env
|
||||
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
||||
response = bedrock.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=litellm.custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if (
|
||||
"stream" in optional_params
|
||||
and optional_params["stream"] == True
|
||||
and not isinstance(response, CustomStreamWrapper)
|
||||
):
|
||||
# don't try to access stream object,
|
||||
if "ai21" in model:
|
||||
response = CustomStreamWrapper(
|
||||
response,
|
||||
model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging,
|
||||
)
|
||||
else:
|
||||
response = CustomStreamWrapper(
|
||||
iter(response),
|
||||
model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging,
|
||||
)
|
||||
if "cohere" in model:
|
||||
response = bedrock_chat_completion.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=litellm.custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
acompletion=acompletion,
|
||||
)
|
||||
else:
|
||||
response = bedrock.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=litellm.custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if (
|
||||
"stream" in optional_params
|
||||
and optional_params["stream"] == True
|
||||
and not isinstance(response, CustomStreamWrapper)
|
||||
):
|
||||
# don't try to access stream object,
|
||||
if "ai21" in model:
|
||||
response = CustomStreamWrapper(
|
||||
response,
|
||||
model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging,
|
||||
)
|
||||
else:
|
||||
response = CustomStreamWrapper(
|
||||
iter(response),
|
||||
model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging,
|
||||
)
|
||||
|
||||
if optional_params.get("stream", False):
|
||||
## LOGGING
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
from litellm.proxy._types import UserAPIKeyAuth, GenerateKeyRequest
|
||||
from fastapi import Request
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
|
||||
try:
|
||||
|
|
|
@ -1507,7 +1507,6 @@ class Router:
|
|||
return response
|
||||
except Exception as e:
|
||||
original_exception = e
|
||||
|
||||
"""
|
||||
Retry Logic
|
||||
|
||||
|
|
|
@ -8,8 +8,6 @@
|
|||
|
||||
import dotenv, os, requests, random # type: ignore
|
||||
from typing import Optional
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
#### What this does ####
|
||||
# picks based on response time (for streaming, this is time to first token)
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
import dotenv, os, requests, random # type: ignore
|
||||
import os, requests, random # type: ignore
|
||||
from typing import Optional, Union, List, Dict
|
||||
from datetime import datetime, timedelta
|
||||
import random
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
|
|
@ -5,8 +5,6 @@ import dotenv, os, requests, random # type: ignore
|
|||
from typing import Optional, Union, List, Dict
|
||||
from datetime import datetime, timedelta
|
||||
import random
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
|
|
@ -4,8 +4,6 @@
|
|||
import dotenv, os, requests, random
|
||||
from typing import Optional, Union, List, Dict
|
||||
from datetime import datetime
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
from litellm import token_counter
|
||||
from litellm.caching import DualCache
|
||||
|
|
|
@ -5,8 +5,6 @@ import dotenv, os, requests, random
|
|||
from typing import Optional, Union, List, Dict
|
||||
import datetime as datetime_og
|
||||
from datetime import datetime
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback, asyncio, httpx
|
||||
import litellm
|
||||
from litellm import token_counter
|
||||
|
|
|
@ -2584,6 +2584,69 @@ def test_completion_chat_sagemaker_mistral():
|
|||
# test_completion_chat_sagemaker_mistral()
|
||||
|
||||
|
||||
def response_format_tests(response: litellm.ModelResponse):
|
||||
assert isinstance(response.id, str)
|
||||
assert response.id != ""
|
||||
|
||||
assert isinstance(response.object, str)
|
||||
assert response.object != ""
|
||||
|
||||
assert isinstance(response.created, int)
|
||||
|
||||
assert isinstance(response.model, str)
|
||||
assert response.model != ""
|
||||
|
||||
assert isinstance(response.choices, list)
|
||||
assert len(response.choices) == 1
|
||||
choice = response.choices[0]
|
||||
assert isinstance(choice, litellm.Choices)
|
||||
assert isinstance(choice.get("index"), int)
|
||||
|
||||
message = choice.get("message")
|
||||
assert isinstance(message, litellm.Message)
|
||||
assert isinstance(message.get("role"), str)
|
||||
assert message.get("role") != ""
|
||||
assert isinstance(message.get("content"), str)
|
||||
assert message.get("content") != ""
|
||||
|
||||
assert choice.get("logprobs") is None
|
||||
assert isinstance(choice.get("finish_reason"), str)
|
||||
assert choice.get("finish_reason") != ""
|
||||
|
||||
assert isinstance(response.usage, litellm.Usage) # type: ignore
|
||||
assert isinstance(response.usage.prompt_tokens, int) # type: ignore
|
||||
assert isinstance(response.usage.completion_tokens, int) # type: ignore
|
||||
assert isinstance(response.usage.total_tokens, int) # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_bedrock_command_r(sync_mode):
|
||||
litellm.set_verbose = True
|
||||
|
||||
if sync_mode:
|
||||
response = completion(
|
||||
model="bedrock/cohere.command-r-plus-v1:0",
|
||||
messages=[{"role": "user", "content": "Hey! how's it going?"}],
|
||||
)
|
||||
|
||||
assert isinstance(response, litellm.ModelResponse)
|
||||
|
||||
response_format_tests(response=response)
|
||||
else:
|
||||
response = await litellm.acompletion(
|
||||
model="bedrock/cohere.command-r-plus-v1:0",
|
||||
messages=[{"role": "user", "content": "Hey! how's it going?"}],
|
||||
)
|
||||
|
||||
assert isinstance(response, litellm.ModelResponse)
|
||||
|
||||
print(f"response: {response}")
|
||||
response_format_tests(response=response)
|
||||
|
||||
print(f"response: {response}")
|
||||
|
||||
|
||||
def test_completion_bedrock_titan_null_response():
|
||||
try:
|
||||
response = completion(
|
||||
|
|
|
@ -132,12 +132,15 @@ def test_post_call_rule_streaming():
|
|||
)
|
||||
|
||||
|
||||
def test_post_call_processing_error_async_response():
|
||||
response = asyncio.run(
|
||||
acompletion(
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_call_processing_error_async_response():
|
||||
try:
|
||||
response = await acompletion(
|
||||
model="command-nightly", # Just used as an example
|
||||
messages=[{"content": "Hello, how are you?", "role": "user"}],
|
||||
api_base="https://openai-proxy.berriai.repl.co", # Just used as an example
|
||||
custom_llm_provider="openai",
|
||||
)
|
||||
)
|
||||
pytest.fail("This call should have failed")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
|
|
@ -983,6 +983,64 @@ def test_vertex_ai_stream():
|
|||
# pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True])
|
||||
@pytest.mark.asyncio
|
||||
async def test_bedrock_cohere_command_r_streaming(sync_mode):
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
if sync_mode:
|
||||
final_chunk: Optional[litellm.ModelResponse] = None
|
||||
response: litellm.CustomStreamWrapper = completion( # type: ignore
|
||||
model="bedrock/cohere.command-r-plus-v1:0",
|
||||
messages=messages,
|
||||
max_tokens=10, # type: ignore
|
||||
stream=True,
|
||||
)
|
||||
complete_response = ""
|
||||
# Add any assertions here to check the response
|
||||
has_finish_reason = False
|
||||
for idx, chunk in enumerate(response):
|
||||
final_chunk = chunk
|
||||
chunk, finished = streaming_format_tests(idx, chunk)
|
||||
if finished:
|
||||
has_finish_reason = True
|
||||
break
|
||||
complete_response += chunk
|
||||
if has_finish_reason == False:
|
||||
raise Exception("finish reason not set")
|
||||
if complete_response.strip() == "":
|
||||
raise Exception("Empty response received")
|
||||
else:
|
||||
response: litellm.CustomStreamWrapper = await litellm.acompletion( # type: ignore
|
||||
model="bedrock/cohere.command-r-plus-v1:0",
|
||||
messages=messages,
|
||||
max_tokens=100, # type: ignore
|
||||
stream=True,
|
||||
)
|
||||
complete_response = ""
|
||||
# Add any assertions here to check the response
|
||||
has_finish_reason = False
|
||||
idx = 0
|
||||
final_chunk: Optional[litellm.ModelResponse] = None
|
||||
async for chunk in response:
|
||||
final_chunk = chunk
|
||||
chunk, finished = streaming_format_tests(idx, chunk)
|
||||
if finished:
|
||||
has_finish_reason = True
|
||||
break
|
||||
complete_response += chunk
|
||||
idx += 1
|
||||
if has_finish_reason == False:
|
||||
raise Exception("finish reason not set")
|
||||
if complete_response.strip() == "":
|
||||
raise Exception("Empty response received")
|
||||
print(f"completion_response: {complete_response}\n\nFinalChunk: {final_chunk}")
|
||||
except RateLimitError:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_bedrock_claude_3_streaming():
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
|
|
63
litellm/types/llms/bedrock.py
Normal file
63
litellm/types/llms/bedrock.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
from typing import TypedDict, Any, Union, Optional
|
||||
import json
|
||||
from typing_extensions import (
|
||||
Self,
|
||||
Protocol,
|
||||
TypeGuard,
|
||||
override,
|
||||
get_origin,
|
||||
runtime_checkable,
|
||||
Required,
|
||||
)
|
||||
|
||||
|
||||
class GenericStreamingChunk(TypedDict):
|
||||
text: Required[str]
|
||||
is_finished: Required[bool]
|
||||
finish_reason: Required[str]
|
||||
|
||||
|
||||
class Document(TypedDict):
|
||||
title: str
|
||||
snippet: str
|
||||
|
||||
|
||||
class ServerSentEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
event: Optional[str] = None,
|
||||
data: Optional[str] = None,
|
||||
id: Optional[str] = None,
|
||||
retry: Optional[int] = None,
|
||||
) -> None:
|
||||
if data is None:
|
||||
data = ""
|
||||
|
||||
self._id = id
|
||||
self._data = data
|
||||
self._event = event or None
|
||||
self._retry = retry
|
||||
|
||||
@property
|
||||
def event(self) -> Optional[str]:
|
||||
return self._event
|
||||
|
||||
@property
|
||||
def id(self) -> Optional[str]:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def retry(self) -> Optional[int]:
|
||||
return self._retry
|
||||
|
||||
@property
|
||||
def data(self) -> str:
|
||||
return self._data
|
||||
|
||||
def json(self) -> Any:
|
||||
return json.loads(self.data)
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
|
|
@ -132,7 +132,6 @@ MAX_THREADS = 100
|
|||
|
||||
# Create a ThreadPoolExecutor
|
||||
executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
sentry_sdk_instance = None
|
||||
capture_exception = None
|
||||
add_breadcrumb = None
|
||||
|
@ -10474,6 +10473,12 @@ class CustomStreamWrapper:
|
|||
raise e
|
||||
|
||||
def handle_bedrock_stream(self, chunk):
|
||||
if "cohere" in self.model:
|
||||
return {
|
||||
"text": chunk["text"],
|
||||
"is_finished": chunk["is_finished"],
|
||||
"finish_reason": chunk["finish_reason"],
|
||||
}
|
||||
if hasattr(chunk, "get"):
|
||||
chunk = chunk.get("chunk")
|
||||
chunk_data = json.loads(chunk.get("bytes").decode())
|
||||
|
@ -11322,6 +11327,7 @@ class CustomStreamWrapper:
|
|||
or self.custom_llm_provider == "gemini"
|
||||
or self.custom_llm_provider == "cached_response"
|
||||
or self.custom_llm_provider == "predibase"
|
||||
or (self.custom_llm_provider == "bedrock" and "cohere" in self.model)
|
||||
or self.custom_llm_provider in litellm.openai_compatible_endpoints
|
||||
):
|
||||
async for chunk in self.completion_stream:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue