forked from phoenix/litellm-mirror

Merge pull request #3586 from BerriAI/litellm_bedrock_command_r_support

feat(bedrock_httpx.py): Make Bedrock-Cohere calls Async

commit 94c9df969e: 39 changed files with 1222 additions and 195 deletions
@@ -10,7 +10,6 @@ from litellm.caching import DualCache

 from typing import Literal, Union

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback

@@ -19,8 +18,6 @@ import traceback

 import dotenv, os
 import requests
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid

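Across the integration modules below, the per-module dotenv.load_dotenv() calls are dropped. A minimal sketch of the pattern this points to, loading environment variables once at the application entry point (assumes python-dotenv is installed and a local .env file; not part of the diff):

    import os

    import dotenv

    dotenv.load_dotenv()  # load .env once, at startup

    import litellm  # imported after the environment is populated

    print(os.environ.get("LANGFUSE_PUBLIC_KEY", "<unset>"))
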
@@ -743,6 +743,7 @@ from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
 from .llms.maritalk import MaritTalkConfig
+from .llms.bedrock_httpx import AmazonCohereChatConfig
 from .llms.bedrock import (
     AmazonTitanConfig,
     AmazonAI21Config,

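The new AmazonCohereChatConfig (defined in litellm/llms/bedrock_httpx.py further down) exposes the supported OpenAI-style params and maps them to Cohere command-r params on Bedrock. A quick sketch of exercising the mapping directly (illustrative, not part of the diff):

    import litellm

    config = litellm.AmazonCohereChatConfig()
    print(config.get_supported_openai_params())
    # ['max_tokens', 'stream', 'stop', 'temperature', 'top_p', ...]

    mapped = config.map_openai_params(
        non_default_params={"temperature": 0.2, "top_p": 0.9, "stop": "END"},
        optional_params={},
    )
    print(mapped)  # temperature -> temperature, top_p -> p, stop -> stop_sequences
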
@@ -1,8 +1,6 @@
 #### What this does ####
 # On success + failure, log events to aispend.io
 import dotenv, os
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime

@@ -3,7 +3,6 @@
 import dotenv, os
 import requests  # type: ignore

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime

@@ -8,8 +8,6 @@ from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache

 from typing import Literal, Union
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback

@@ -18,8 +16,6 @@ import traceback

 import dotenv, os
 import requests
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid

@@ -6,8 +6,6 @@ from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache

 from typing import Literal, Union, Optional
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback

@@ -3,8 +3,6 @@

 import dotenv, os
 import requests  # type: ignore
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid

@@ -3,8 +3,6 @@

 import dotenv, os
 import requests  # type: ignore
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid

@@ -3,8 +3,6 @@
 import dotenv, os
 import requests  # type: ignore
 import litellm
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback

@@ -1,8 +1,6 @@
 #### What this does ####
 # On success, logs events to Langfuse
-import dotenv, os
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
+import os
 import copy
 import traceback
 from packaging.version import Version

@@ -3,8 +3,6 @@
 import dotenv, os  # type: ignore
 import requests  # type: ignore
 from datetime import datetime
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import asyncio
 import types

@@ -2,13 +2,10 @@
 # On success + failure, log events to lunary.ai
 from datetime import datetime, timezone
 import traceback
-import dotenv
 import importlib

 import packaging

-dotenv.load_dotenv()
-

 # convert to {completion: xx, tokens: xx}
 def parse_usage(usage):

@@ -79,14 +76,16 @@ class LunaryLogger:
             version = importlib.metadata.version("lunary")
             # if version < 0.1.43 then raise ImportError
             if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
-                print(
+                print(  # noqa
                     "Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
                 )
                 raise ImportError

             self.lunary_client = lunary
         except ImportError:
-            print("Lunary not installed. Please install it using 'pip install lunary'")
+            print(  # noqa
+                "Lunary not installed. Please install it using 'pip install lunary'"
+            )  # noqa
             raise ImportError

     def log_event(

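The lunary hunk keeps the packaging-based version gate. A standalone sketch of that check (assumes the lunary package is installed; not part of the diff):

    import importlib.metadata

    import packaging.version

    MIN_LUNARY = "0.1.43"
    installed = importlib.metadata.version("lunary")
    if packaging.version.Version(installed) < packaging.version.Version(MIN_LUNARY):
        raise ImportError(f"lunary {installed} is older than the required {MIN_LUNARY}")
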
@@ -3,8 +3,6 @@

 import dotenv, os, json
 import litellm
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler

@@ -4,8 +4,6 @@

 import dotenv, os
 import requests  # type: ignore
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid

@@ -5,8 +5,6 @@

 import dotenv, os
 import requests  # type: ignore
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid

@@ -3,8 +3,6 @@
 import dotenv, os
 import requests  # type: ignore
 from pydantic import BaseModel
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback

@@ -1,9 +1,7 @@
 #### What this does ####
 # On success + failure, log events to Supabase

-import dotenv, os
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
+import os
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid

@@ -2,8 +2,6 @@
 # Class for sending Slack Alerts #
 import dotenv, os
 from litellm.proxy._types import UserAPIKeyAuth
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 from litellm._logging import verbose_logger, verbose_proxy_logger
 import litellm, threading
 from typing import List, Literal, Any, Union, Optional, Dict

@@ -3,8 +3,6 @@

 import dotenv, os
 import requests  # type: ignore
-
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm

@@ -21,11 +21,11 @@ try:
         # contains a (known) object attribute
         object: Literal["chat.completion", "edit", "text_completion"]

-        def __getitem__(self, key: K) -> V:
-            ...  # pragma: no cover
+        def __getitem__(self, key: K) -> V: ...  # noqa

-        def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
-            ...  # pragma: no cover
+        def get(  # noqa
+            self, key: K, default: Optional[V] = None
+        ) -> Optional[V]: ...  # pragma: no cover

     class OpenAIRequestResponseResolver:
         def __call__(

@@ -173,12 +173,11 @@ except:

 #### What this does ####
 # On success, logs events to Langfuse
-import dotenv, os
+import os
 import requests
 import requests
 from datetime import datetime

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback

@@ -3,7 +3,7 @@ import json
 from enum import Enum
 import requests, copy  # type: ignore
 import time
-from typing import Callable, Optional, List
+from typing import Callable, Optional, List, Union
 from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt

@@ -151,19 +151,135 @@ class AnthropicChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()

+    def process_streaming_response(
+        self,
+        model: str,
+        response: Union[requests.Response, httpx.Response],
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: litellm.utils.Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: List,
+        print_verbose,
+        encoding,
+    ) -> CustomStreamWrapper:
+        """
+        Return stream object for tool-calling + streaming
+        """
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
+        ## RESPONSE OBJECT
+        try:
+            completion_response = response.json()
+        except:
+            raise AnthropicError(
+                message=response.text, status_code=response.status_code
+            )
+        text_content = ""
+        tool_calls = []
+        for content in completion_response["content"]:
+            if content["type"] == "text":
+                text_content += content["text"]
+            ## TOOL CALLING
+            elif content["type"] == "tool_use":
+                tool_calls.append(
+                    {
+                        "id": content["id"],
+                        "type": "function",
+                        "function": {
+                            "name": content["name"],
+                            "arguments": json.dumps(content["input"]),
+                        },
+                    }
+                )
+        if "error" in completion_response:
+            raise AnthropicError(
+                message=str(completion_response["error"]),
+                status_code=response.status_code,
+            )
+        _message = litellm.Message(
+            tool_calls=tool_calls,
+            content=text_content or None,
+        )
+        model_response.choices[0].message = _message  # type: ignore
+        model_response._hidden_params["original_response"] = completion_response[
+            "content"
+        ]  # allow user to access raw anthropic tool calling response
+
+        model_response.choices[0].finish_reason = map_finish_reason(
+            completion_response["stop_reason"]
+        )
+
+        print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
+        # return an iterator
+        streaming_model_response = ModelResponse(stream=True)
+        streaming_model_response.choices[0].finish_reason = model_response.choices[  # type: ignore
+            0
+        ].finish_reason
+        # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+        streaming_choice = litellm.utils.StreamingChoices()
+        streaming_choice.index = model_response.choices[0].index
+        _tool_calls = []
+        print_verbose(
+            f"type of model_response.choices[0]: {type(model_response.choices[0])}"
+        )
+        print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
+        if isinstance(model_response.choices[0], litellm.Choices):
+            if getattr(
+                model_response.choices[0].message, "tool_calls", None
+            ) is not None and isinstance(
+                model_response.choices[0].message.tool_calls, list
+            ):
+                for tool_call in model_response.choices[0].message.tool_calls:
+                    _tool_call = {**tool_call.dict(), "index": 0}
+                    _tool_calls.append(_tool_call)
+            delta_obj = litellm.utils.Delta(
+                content=getattr(model_response.choices[0].message, "content", None),
+                role=model_response.choices[0].message.role,
+                tool_calls=_tool_calls,
+            )
+            streaming_choice.delta = delta_obj
+            streaming_model_response.choices = [streaming_choice]
+            completion_stream = ModelResponseIterator(
+                model_response=streaming_model_response
+            )
+            print_verbose(
+                "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
+            )
+            return CustomStreamWrapper(
+                completion_stream=completion_stream,
+                model=model,
+                custom_llm_provider="cached_response",
+                logging_obj=logging_obj,
+            )
+        else:
+            raise AnthropicError(
+                status_code=422,
+                message="Unprocessable response object - {}".format(response.text),
+            )
+
     def process_response(
         self,
-        model,
-        response,
-        model_response,
-        _is_function_call,
-        stream,
-        logging_obj,
-        api_key,
-        data,
-        messages,
+        model: str,
+        response: Union[requests.Response, httpx.Response],
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: litellm.utils.Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: List,
         print_verbose,
-    ):
+        encoding,
+    ) -> ModelResponse:
         ## LOGGING
         logging_obj.post_call(
             input=messages,

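With the new process_streaming_response, a tool-calling Anthropic request made with stream=True is buffered and replayed through CustomStreamWrapper (custom_llm_provider="cached_response") as a single delta chunk. A usage sketch; the model name and tool schema are illustrative and an ANTHROPIC_API_KEY is assumed:

    import litellm

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ]

    response = litellm.completion(
        model="claude-3-opus-20240229",
        messages=[{"role": "user", "content": "What's the weather in Paris?"}],
        tools=tools,
        stream=True,
    )
    for chunk in response:
        delta = chunk.choices[0].delta
        if getattr(delta, "tool_calls", None):
            print(delta.tool_calls)
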
@@ -216,51 +332,6 @@ class AnthropicChatCompletion(BaseLLM):
             completion_response["stop_reason"]
         )

-        print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
-        if _is_function_call and stream:
-            print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
-            # return an iterator
-            streaming_model_response = ModelResponse(stream=True)
-            streaming_model_response.choices[0].finish_reason = model_response.choices[
-                0
-            ].finish_reason
-            # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
-            streaming_choice = litellm.utils.StreamingChoices()
-            streaming_choice.index = model_response.choices[0].index
-            _tool_calls = []
-            print_verbose(
-                f"type of model_response.choices[0]: {type(model_response.choices[0])}"
-            )
-            print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
-            if isinstance(model_response.choices[0], litellm.Choices):
-                if getattr(
-                    model_response.choices[0].message, "tool_calls", None
-                ) is not None and isinstance(
-                    model_response.choices[0].message.tool_calls, list
-                ):
-                    for tool_call in model_response.choices[0].message.tool_calls:
-                        _tool_call = {**tool_call.dict(), "index": 0}
-                        _tool_calls.append(_tool_call)
-                delta_obj = litellm.utils.Delta(
-                    content=getattr(model_response.choices[0].message, "content", None),
-                    role=model_response.choices[0].message.role,
-                    tool_calls=_tool_calls,
-                )
-                streaming_choice.delta = delta_obj
-                streaming_model_response.choices = [streaming_choice]
-                completion_stream = ModelResponseIterator(
-                    model_response=streaming_model_response
-                )
-                print_verbose(
-                    "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
-                )
-                return CustomStreamWrapper(
-                    completion_stream=completion_stream,
-                    model=model,
-                    custom_llm_provider="cached_response",
-                    logging_obj=logging_obj,
-                )
-
         ## CALCULATING USAGE
         prompt_tokens = completion_response["usage"]["input_tokens"]
         completion_tokens = completion_response["usage"]["output_tokens"]

@@ -273,7 +344,7 @@ class AnthropicChatCompletion(BaseLLM):
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
         )
-        model_response.usage = usage
+        setattr(model_response, "usage", usage)  # type: ignore
         return model_response

     async def acompletion_stream_function(

@@ -289,7 +360,7 @@ class AnthropicChatCompletion(BaseLLM):
         logging_obj,
         stream,
         _is_function_call,
-        data=None,
+        data: dict,
         optional_params=None,
         litellm_params=None,
         logger_fn=None,

@@ -331,29 +402,44 @@ class AnthropicChatCompletion(BaseLLM):
         logging_obj,
         stream,
         _is_function_call,
-        data=None,
-        optional_params=None,
+        data: dict,
+        optional_params: dict,
         litellm_params=None,
         logger_fn=None,
         headers={},
-    ):
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
         self.async_handler = AsyncHTTPHandler(
             timeout=httpx.Timeout(timeout=600.0, connect=5.0)
         )
         response = await self.async_handler.post(
             api_base, headers=headers, data=json.dumps(data)
         )
-        return self.process_response(
+        if stream and _is_function_call:
+            return self.process_streaming_response(
                 model=model,
                 response=response,
                 model_response=model_response,
-                _is_function_call=_is_function_call,
                 stream=stream,
                 logging_obj=logging_obj,
                 api_key=api_key,
                 data=data,
                 messages=messages,
                 print_verbose=print_verbose,
+                optional_params=optional_params,
+                encoding=encoding,
+            )
+        return self.process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=stream,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
         )

     def completion(

@@ -367,7 +453,7 @@ class AnthropicChatCompletion(BaseLLM):
         encoding,
         api_key,
         logging_obj,
-        optional_params=None,
+        optional_params: dict,
         acompletion=None,
         litellm_params=None,
         logger_fn=None,

@@ -526,17 +612,33 @@ class AnthropicChatCompletion(BaseLLM):
             raise AnthropicError(
                 status_code=response.status_code, message=response.text
             )
-        return self.process_response(
+
+        if stream and _is_function_call:
+            return self.process_streaming_response(
                 model=model,
                 response=response,
                 model_response=model_response,
-                _is_function_call=_is_function_call,
                 stream=stream,
                 logging_obj=logging_obj,
                 api_key=api_key,
                 data=data,
                 messages=messages,
                 print_verbose=print_verbose,
+                optional_params=optional_params,
+                encoding=encoding,
+            )
+        return self.process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=stream,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
         )

     def embedding(self):

@@ -100,7 +100,7 @@ class AnthropicTextCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()

-    def process_response(
+    def _process_response(
         self, model_response: ModelResponse, response, encoding, prompt: str, model: str
     ):
         ## RESPONSE OBJECT

@@ -171,7 +171,7 @@ class AnthropicTextCompletion(BaseLLM):
             additional_args={"complete_input_dict": data},
         )

-        response = self.process_response(
+        response = self._process_response(
             model_response=model_response,
             response=response,
             encoding=encoding,

@@ -330,7 +330,7 @@ class AnthropicTextCompletion(BaseLLM):
         )
         print_verbose(f"raw model_response: {response.text}")

-        response = self.process_response(
+        response = self._process_response(
             model_response=model_response,
             response=response,
             encoding=encoding,

@@ -1,12 +1,32 @@
 ## This is a template base class to be used for adding new LLM providers via API calls
 import litellm
-import httpx
-from typing import Optional
+import httpx, requests
+from typing import Optional, Union
+from litellm.utils import Logging


 class BaseLLM:
     _client_session: Optional[httpx.Client] = None

+    def process_response(
+        self,
+        model: str,
+        response: Union[requests.Response, httpx.Response],
+        model_response: litellm.utils.ModelResponse,
+        stream: bool,
+        logging_obj: Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: list,
+        print_verbose,
+        encoding,
+    ) -> litellm.utils.ModelResponse:
+        """
+        Helper function to process the response across sync + async completion calls
+        """
+        return model_response
+
     def create_client_session(self):
         if litellm.client_session:
             _client_session = litellm.client_session

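BaseLLM now defines a shared process_response signature that providers (Anthropic, Predibase, and the new BedrockLLM in this diff) override. A minimal sketch of a hypothetical provider following the same contract; MyProviderLLM and its response shape are illustrative, not part of the diff:

    import litellm
    from litellm.llms.base import BaseLLM


    class MyProviderLLM(BaseLLM):
        def process_response(
            self,
            model,
            response,
            model_response,
            stream,
            logging_obj,
            optional_params,
            api_key,
            data,
            messages,
            print_verbose,
            encoding,
        ) -> litellm.utils.ModelResponse:
            # assumes the provider returns {"text": "..."} JSON
            completion_response = response.json()
            model_response.choices[0].message.content = completion_response.get("text", "")
            model_response["model"] = model
            return model_response
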
litellm/llms/bedrock_httpx.py (new file, 733 lines)
@@ -0,0 +1,733 @@
|
# What is this?
|
||||||
|
## Initial implementation of calling bedrock via httpx client (allows for async calls).
|
||||||
|
## V0 - just covers cohere command-r support
|
||||||
|
|
||||||
|
import os, types
|
||||||
|
import json
|
||||||
|
from enum import Enum
|
||||||
|
import requests, copy # type: ignore
|
||||||
|
import time
|
||||||
|
from typing import (
|
||||||
|
Callable,
|
||||||
|
Optional,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Union,
|
||||||
|
Any,
|
||||||
|
TypedDict,
|
||||||
|
Tuple,
|
||||||
|
Iterator,
|
||||||
|
AsyncIterator,
|
||||||
|
)
|
||||||
|
from litellm.utils import (
|
||||||
|
ModelResponse,
|
||||||
|
Usage,
|
||||||
|
map_finish_reason,
|
||||||
|
CustomStreamWrapper,
|
||||||
|
Message,
|
||||||
|
Choices,
|
||||||
|
get_secret,
|
||||||
|
Logging,
|
||||||
|
)
|
||||||
|
import litellm
|
||||||
|
from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
|
from .base import BaseLLM
|
||||||
|
import httpx # type: ignore
|
||||||
|
from .bedrock import BedrockError, convert_messages_to_prompt
|
||||||
|
from litellm.types.llms.bedrock import *
|
||||||
|
|
||||||
|
|
||||||
|
class AmazonCohereChatConfig:
|
||||||
|
"""
|
||||||
|
Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
|
||||||
|
"""
|
||||||
|
|
||||||
|
documents: Optional[List[Document]] = None
|
||||||
|
search_queries_only: Optional[bool] = None
|
||||||
|
preamble: Optional[str] = None
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
p: Optional[float] = None
|
||||||
|
k: Optional[float] = None
|
||||||
|
prompt_truncation: Optional[str] = None
|
||||||
|
frequency_penalty: Optional[float] = None
|
||||||
|
presence_penalty: Optional[float] = None
|
||||||
|
seed: Optional[int] = None
|
||||||
|
return_prompt: Optional[bool] = None
|
||||||
|
stop_sequences: Optional[List[str]] = None
|
||||||
|
raw_prompting: Optional[bool] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
documents: Optional[List[Document]] = None,
|
||||||
|
search_queries_only: Optional[bool] = None,
|
||||||
|
preamble: Optional[str] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
p: Optional[float] = None,
|
||||||
|
k: Optional[float] = None,
|
||||||
|
prompt_truncation: Optional[str] = None,
|
||||||
|
frequency_penalty: Optional[float] = None,
|
||||||
|
presence_penalty: Optional[float] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_prompt: Optional[bool] = None,
|
||||||
|
stop_sequences: Optional[str] = None,
|
||||||
|
raw_prompting: Optional[bool] = None,
|
||||||
|
) -> None:
|
||||||
|
locals_ = locals()
|
||||||
|
for key, value in locals_.items():
|
||||||
|
if key != "self" and value is not None:
|
||||||
|
setattr(self.__class__, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls):
|
||||||
|
return {
|
||||||
|
k: v
|
||||||
|
for k, v in cls.__dict__.items()
|
||||||
|
if not k.startswith("__")
|
||||||
|
and not isinstance(
|
||||||
|
v,
|
||||||
|
(
|
||||||
|
types.FunctionType,
|
||||||
|
types.BuiltinFunctionType,
|
||||||
|
classmethod,
|
||||||
|
staticmethod,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
and v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_supported_openai_params(self) -> List[str]:
|
||||||
|
return [
|
||||||
|
"max_tokens",
|
||||||
|
"stream",
|
||||||
|
"stop",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"frequency_penalty",
|
||||||
|
"presence_penalty",
|
||||||
|
"seed",
|
||||||
|
"stop",
|
||||||
|
]
|
||||||
|
|
||||||
|
def map_openai_params(
|
||||||
|
self, non_default_params: dict, optional_params: dict
|
||||||
|
) -> dict:
|
||||||
|
for param, value in non_default_params.items():
|
||||||
|
if param == "max_tokens":
|
||||||
|
optional_params["max_tokens"] = value
|
||||||
|
if param == "stream":
|
||||||
|
optional_params["stream"] = value
|
||||||
|
if param == "stop":
|
||||||
|
if isinstance(value, str):
|
||||||
|
value = [value]
|
||||||
|
optional_params["stop_sequences"] = value
|
||||||
|
if param == "temperature":
|
||||||
|
optional_params["temperature"] = value
|
||||||
|
if param == "top_p":
|
||||||
|
optional_params["p"] = value
|
||||||
|
if param == "frequency_penalty":
|
||||||
|
optional_params["frequency_penalty"] = value
|
||||||
|
if param == "presence_penalty":
|
||||||
|
optional_params["presence_penalty"] = value
|
||||||
|
if "seed":
|
||||||
|
optional_params["seed"] = value
|
||||||
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockLLM(BaseLLM):
|
||||||
|
"""
|
||||||
|
Example call
|
||||||
|
|
||||||
|
```
|
||||||
|
curl --location --request POST 'https://bedrock-runtime.{aws_region_name}.amazonaws.com/model/{bedrock_model_name}/invoke' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--header 'Accept: application/json' \
|
||||||
|
--user "$AWS_ACCESS_KEY_ID":"$AWS_SECRET_ACCESS_KEY" \
|
||||||
|
--aws-sigv4 "aws:amz:us-east-1:bedrock" \
|
||||||
|
--data-raw '{
|
||||||
|
"prompt": "Hi",
|
||||||
|
"temperature": 0,
|
||||||
|
"p": 0.9,
|
||||||
|
"max_tokens": 4096
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def convert_messages_to_prompt(
|
||||||
|
self, model, messages, provider, custom_prompt_dict
|
||||||
|
) -> Tuple[str, Optional[list]]:
|
||||||
|
# handle anthropic prompts and amazon titan prompts
|
||||||
|
prompt = ""
|
||||||
|
chat_history: Optional[list] = None
|
||||||
|
if provider == "anthropic" or provider == "amazon":
|
||||||
|
if model in custom_prompt_dict:
|
||||||
|
# check if the model has a registered custom prompt
|
||||||
|
model_prompt_details = custom_prompt_dict[model]
|
||||||
|
prompt = custom_prompt(
|
||||||
|
role_dict=model_prompt_details["roles"],
|
||||||
|
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||||
|
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = prompt_factory(
|
||||||
|
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||||
|
)
|
||||||
|
elif provider == "mistral":
|
||||||
|
prompt = prompt_factory(
|
||||||
|
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||||
|
)
|
||||||
|
elif provider == "meta":
|
||||||
|
prompt = prompt_factory(
|
||||||
|
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||||
|
)
|
||||||
|
elif provider == "cohere":
|
||||||
|
prompt, chat_history = cohere_message_pt(messages=messages)
|
||||||
|
else:
|
||||||
|
prompt = ""
|
||||||
|
for message in messages:
|
||||||
|
if "role" in message:
|
||||||
|
if message["role"] == "user":
|
||||||
|
prompt += f"{message['content']}"
|
||||||
|
else:
|
||||||
|
prompt += f"{message['content']}"
|
||||||
|
else:
|
||||||
|
prompt += f"{message['content']}"
|
||||||
|
return prompt, chat_history # type: ignore
|
||||||
|
|
||||||
|
def get_credentials(
|
||||||
|
self,
|
||||||
|
aws_access_key_id: Optional[str] = None,
|
||||||
|
aws_secret_access_key: Optional[str] = None,
|
||||||
|
aws_region_name: Optional[str] = None,
|
||||||
|
aws_session_name: Optional[str] = None,
|
||||||
|
aws_profile_name: Optional[str] = None,
|
||||||
|
aws_role_name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Return a boto3.Credentials object
|
||||||
|
"""
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
## CHECK IS 'os.environ/' passed in
|
||||||
|
params_to_check: List[Optional[str]] = [
|
||||||
|
aws_access_key_id,
|
||||||
|
aws_secret_access_key,
|
||||||
|
aws_region_name,
|
||||||
|
aws_session_name,
|
||||||
|
aws_profile_name,
|
||||||
|
aws_role_name,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Iterate over parameters and update if needed
|
||||||
|
for i, param in enumerate(params_to_check):
|
||||||
|
if param and param.startswith("os.environ/"):
|
||||||
|
_v = get_secret(param)
|
||||||
|
if _v is not None and isinstance(_v, str):
|
||||||
|
params_to_check[i] = _v
|
||||||
|
# Assign updated values back to parameters
|
||||||
|
(
|
||||||
|
aws_access_key_id,
|
||||||
|
aws_secret_access_key,
|
||||||
|
aws_region_name,
|
||||||
|
aws_session_name,
|
||||||
|
aws_profile_name,
|
||||||
|
aws_role_name,
|
||||||
|
) = params_to_check
|
||||||
|
|
||||||
|
### CHECK STS ###
|
||||||
|
if aws_role_name is not None and aws_session_name is not None:
|
||||||
|
sts_client = boto3.client(
|
||||||
|
"sts",
|
||||||
|
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
|
||||||
|
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
|
||||||
|
)
|
||||||
|
|
||||||
|
sts_response = sts_client.assume_role(
|
||||||
|
RoleArn=aws_role_name, RoleSessionName=aws_session_name
|
||||||
|
)
|
||||||
|
|
||||||
|
return sts_response["Credentials"]
|
||||||
|
elif aws_profile_name is not None: ### CHECK SESSION ###
|
||||||
|
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
||||||
|
client = boto3.Session(profile_name=aws_profile_name)
|
||||||
|
|
||||||
|
return client.get_credentials()
|
||||||
|
else:
|
||||||
|
session = boto3.Session(
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
region_name=aws_region_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
return session.get_credentials()
|
||||||
|
|
||||||
|
def process_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
response: Union[requests.Response, httpx.Response],
|
||||||
|
model_response: ModelResponse,
|
||||||
|
stream: bool,
|
||||||
|
logging_obj: Logging,
|
||||||
|
optional_params: dict,
|
||||||
|
api_key: str,
|
||||||
|
data: Union[dict, str],
|
||||||
|
messages: List,
|
||||||
|
print_verbose,
|
||||||
|
encoding,
|
||||||
|
) -> ModelResponse:
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=response.text,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
print_verbose(f"raw model_response: {response.text}")
|
||||||
|
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
try:
|
||||||
|
completion_response = response.json()
|
||||||
|
except:
|
||||||
|
raise BedrockError(message=response.text, status_code=422)
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_response.choices[0].message.content = completion_response["text"] # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise BedrockError(message=response.text, status_code=422)
|
||||||
|
|
||||||
|
## CALCULATING USAGE - bedrock returns usage in the headers
|
||||||
|
prompt_tokens = int(
|
||||||
|
response.headers.get(
|
||||||
|
"x-amzn-bedrock-input-token-count",
|
||||||
|
len(encoding.encode("".join(m.get("content", "") for m in messages))),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
completion_tokens = int(
|
||||||
|
response.headers.get(
|
||||||
|
"x-amzn-bedrock-output-token-count",
|
||||||
|
len(
|
||||||
|
encoding.encode(
|
||||||
|
model_response.choices[0].message.content, # type: ignore
|
||||||
|
disallowed_special=(),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model_response["created"] = int(time.time())
|
||||||
|
model_response["model"] = model
|
||||||
|
usage = Usage(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=prompt_tokens + completion_tokens,
|
||||||
|
)
|
||||||
|
setattr(model_response, "usage", usage)
|
||||||
|
|
||||||
|
return model_response
|
||||||
|
|
||||||
|
def completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
custom_prompt_dict: dict,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
optional_params: dict,
|
||||||
|
acompletion: bool,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
extra_headers: Optional[dict] = None,
|
||||||
|
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||||
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from botocore.auth import SigV4Auth
|
||||||
|
from botocore.awsrequest import AWSRequest
|
||||||
|
from botocore.credentials import Credentials
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||||
|
|
||||||
|
## SETUP ##
|
||||||
|
stream = optional_params.pop("stream", None)
|
||||||
|
|
||||||
|
## CREDENTIALS ##
|
||||||
|
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||||
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||||
|
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||||
|
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||||
|
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||||
|
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||||
|
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||||
|
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||||
|
"aws_bedrock_runtime_endpoint", None
|
||||||
|
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||||
|
|
||||||
|
### SET REGION NAME ###
|
||||||
|
if aws_region_name is None:
|
||||||
|
# check env #
|
||||||
|
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||||
|
|
||||||
|
if litellm_aws_region_name is not None and isinstance(
|
||||||
|
litellm_aws_region_name, str
|
||||||
|
):
|
||||||
|
aws_region_name = litellm_aws_region_name
|
||||||
|
|
||||||
|
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||||
|
if standard_aws_region_name is not None and isinstance(
|
||||||
|
standard_aws_region_name, str
|
||||||
|
):
|
||||||
|
aws_region_name = standard_aws_region_name
|
||||||
|
|
||||||
|
if aws_region_name is None:
|
||||||
|
aws_region_name = "us-west-2"
|
||||||
|
|
||||||
|
credentials: Credentials = self.get_credentials(
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
aws_region_name=aws_region_name,
|
||||||
|
aws_session_name=aws_session_name,
|
||||||
|
aws_profile_name=aws_profile_name,
|
||||||
|
aws_role_name=aws_role_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
### SET RUNTIME ENDPOINT ###
|
||||||
|
endpoint_url = ""
|
||||||
|
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||||
|
if aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||||
|
aws_bedrock_runtime_endpoint, str
|
||||||
|
):
|
||||||
|
endpoint_url = aws_bedrock_runtime_endpoint
|
||||||
|
elif env_aws_bedrock_runtime_endpoint and isinstance(
|
||||||
|
env_aws_bedrock_runtime_endpoint, str
|
||||||
|
):
|
||||||
|
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||||
|
else:
|
||||||
|
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||||
|
|
||||||
|
if stream is not None and stream == True:
|
||||||
|
endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream"
|
||||||
|
else:
|
||||||
|
endpoint_url = f"{endpoint_url}/model/{model}/invoke"
|
||||||
|
|
||||||
|
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||||
|
|
||||||
|
provider = model.split(".")[0]
|
||||||
|
prompt, chat_history = self.convert_messages_to_prompt(
|
||||||
|
model, messages, provider, custom_prompt_dict
|
||||||
|
)
|
||||||
|
inference_params = copy.deepcopy(optional_params)
|
||||||
|
|
||||||
|
if provider == "cohere":
|
||||||
|
if model.startswith("cohere.command-r"):
|
||||||
|
## LOAD CONFIG
|
||||||
|
config = litellm.AmazonCohereChatConfig().get_config()
|
||||||
|
for k, v in config.items():
|
||||||
|
if (
|
||||||
|
k not in inference_params
|
||||||
|
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||||
|
inference_params[k] = v
|
||||||
|
_data = {"message": prompt, **inference_params}
|
||||||
|
if chat_history is not None:
|
||||||
|
_data["chat_history"] = chat_history
|
||||||
|
data = json.dumps(_data)
|
||||||
|
else:
|
||||||
|
## LOAD CONFIG
|
||||||
|
config = litellm.AmazonCohereConfig.get_config()
|
||||||
|
for k, v in config.items():
|
||||||
|
if (
|
||||||
|
k not in inference_params
|
||||||
|
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||||
|
inference_params[k] = v
|
||||||
|
if stream == True:
|
||||||
|
inference_params["stream"] = (
|
||||||
|
True # cohere requires stream = True in inference params
|
||||||
|
)
|
||||||
|
data = json.dumps({"prompt": prompt, **inference_params})
|
||||||
|
else:
|
||||||
|
raise Exception("UNSUPPORTED PROVIDER")
|
||||||
|
|
||||||
|
## COMPLETION CALL
|
||||||
|
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
if extra_headers is not None:
|
||||||
|
headers = {"Content-Type": "application/json", **extra_headers}
|
||||||
|
request = AWSRequest(
|
||||||
|
method="POST", url=endpoint_url, data=data, headers=headers
|
||||||
|
)
|
||||||
|
sigv4.add_auth(request)
|
||||||
|
prepped = request.prepare()
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.pre_call(
|
||||||
|
input=messages,
|
||||||
|
api_key="",
|
||||||
|
additional_args={
|
||||||
|
"complete_input_dict": data,
|
||||||
|
"api_base": prepped.url,
|
||||||
|
"headers": prepped.headers,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||||
|
if acompletion:
|
||||||
|
if isinstance(client, HTTPHandler):
|
||||||
|
client = None
|
||||||
|
if stream:
|
||||||
|
return self.async_streaming(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
data=data,
|
||||||
|
api_base=prepped.url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
stream=True,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=prepped.headers,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
) # type: ignore
|
||||||
|
### ASYNC COMPLETION
|
||||||
|
return self.async_completion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
data=data,
|
||||||
|
api_base=prepped.url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
stream=False,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=prepped.headers,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||||
|
_params = {}
|
||||||
|
if timeout is not None:
|
||||||
|
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||||
|
timeout = httpx.Timeout(timeout)
|
||||||
|
_params["timeout"] = timeout
|
||||||
|
self.client = HTTPHandler(**_params) # type: ignore
|
||||||
|
else:
|
||||||
|
self.client = client
|
||||||
|
if stream is not None and stream == True:
|
||||||
|
response = self.client.post(
|
||||||
|
url=prepped.url,
|
||||||
|
headers=prepped.headers, # type: ignore
|
||||||
|
data=data,
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise BedrockError(
|
||||||
|
status_code=response.status_code, message=response.text
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder = AWSEventStreamDecoder()
|
||||||
|
|
||||||
|
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
|
||||||
|
streaming_response = CustomStreamWrapper(
|
||||||
|
completion_stream=completion_stream,
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="bedrock",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
return streaming_response
|
||||||
|
|
||||||
|
response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
|
||||||
|
|
||||||
|
try:
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as err:
|
||||||
|
error_code = err.response.status_code
|
||||||
|
raise BedrockError(status_code=error_code, message=response.text)
|
||||||
|
|
||||||
|
return self.process_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
stream=stream,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
api_key="",
|
||||||
|
data=data,
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
api_base: str,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
data: str,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
stream,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
headers={},
|
||||||
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
) -> ModelResponse:
|
||||||
|
if client is None:
|
||||||
|
_params = {}
|
||||||
|
if timeout is not None:
|
||||||
|
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||||
|
timeout = httpx.Timeout(timeout)
|
||||||
|
_params["timeout"] = timeout
|
||||||
|
self.client = AsyncHTTPHandler(**_params) # type: ignore
|
||||||
|
else:
|
||||||
|
self.client = client # type: ignore
|
||||||
|
|
||||||
|
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
|
||||||
|
return self.process_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
stream=stream,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key="",
|
||||||
|
data=data,
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_streaming(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list,
|
||||||
|
api_base: str,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
data: str,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
stream,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
headers={},
|
||||||
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
) -> CustomStreamWrapper:
|
||||||
|
if client is None:
|
||||||
|
_params = {}
|
||||||
|
if timeout is not None:
|
||||||
|
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||||
|
timeout = httpx.Timeout(timeout)
|
||||||
|
_params["timeout"] = timeout
|
||||||
|
self.client = AsyncHTTPHandler(**_params) # type: ignore
|
||||||
|
else:
|
||||||
|
self.client = client # type: ignore
|
||||||
|
|
||||||
|
response = await self.client.post(api_base, headers=headers, data=data, stream=True) # type: ignore
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise BedrockError(status_code=response.status_code, message=response.text)
|
||||||
|
|
||||||
|
decoder = AWSEventStreamDecoder()
|
||||||
|
|
||||||
|
completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
|
||||||
|
streaming_response = CustomStreamWrapper(
|
||||||
|
completion_stream=completion_stream,
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="bedrock",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
return streaming_response
|
||||||
|
|
||||||
|
def embedding(self, *args, **kwargs):
|
||||||
|
return super().embedding(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_response_stream_shape():
|
||||||
|
from botocore.model import ServiceModel
|
||||||
|
from botocore.loaders import Loader
|
||||||
|
|
||||||
|
loader = Loader()
|
||||||
|
bedrock_service_dict = loader.load_service_model("bedrock-runtime", "service-2")
|
||||||
|
bedrock_service_model = ServiceModel(bedrock_service_dict)
|
||||||
|
return bedrock_service_model.shape_for("ResponseStream")
|
||||||
|
|
||||||
|
|
||||||
|
class AWSEventStreamDecoder:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
from botocore.parsers import EventStreamJSONParser
|
||||||
|
|
||||||
|
self.parser = EventStreamJSONParser()
|
||||||
|
|
||||||
|
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
|
||||||
|
"""Given an iterator that yields lines, iterate over it & yield every event encountered"""
|
||||||
|
from botocore.eventstream import EventStreamBuffer
|
||||||
|
|
||||||
|
event_stream_buffer = EventStreamBuffer()
|
||||||
|
for chunk in iterator:
|
||||||
|
event_stream_buffer.add_data(chunk)
|
||||||
|
for event in event_stream_buffer:
|
||||||
|
message = self._parse_message_from_event(event)
|
||||||
|
if message:
|
||||||
|
# sse_event = ServerSentEvent(data=message, event="completion")
|
||||||
|
_data = json.loads(message)
|
||||||
|
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
|
||||||
|
text=_data.get("text", ""),
|
||||||
|
is_finished=_data.get("is_finished", False),
|
||||||
|
finish_reason=_data.get("finish_reason", ""),
|
||||||
|
)
|
||||||
|
yield streaming_chunk
|
||||||
|
|
||||||
|
async def aiter_bytes(
|
||||||
|
self, iterator: AsyncIterator[bytes]
|
||||||
|
) -> AsyncIterator[GenericStreamingChunk]:
|
||||||
|
"""Given an async iterator that yields lines, iterate over it & yield every event encountered"""
|
||||||
|
from botocore.eventstream import EventStreamBuffer
|
||||||
|
|
||||||
|
event_stream_buffer = EventStreamBuffer()
|
||||||
|
async for chunk in iterator:
|
||||||
|
event_stream_buffer.add_data(chunk)
|
||||||
|
for event in event_stream_buffer:
|
||||||
|
message = self._parse_message_from_event(event)
|
||||||
|
if message:
|
||||||
|
_data = json.loads(message)
|
||||||
|
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
|
||||||
|
text=_data.get("text", ""),
|
||||||
|
is_finished=_data.get("is_finished", False),
|
||||||
|
finish_reason=_data.get("finish_reason", ""),
|
||||||
|
)
|
||||||
|
yield streaming_chunk
|
||||||
|
|
||||||
|
def _parse_message_from_event(self, event) -> Optional[str]:
|
||||||
|
response_dict = event.to_response_dict()
|
||||||
|
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
|
||||||
|
if response_dict["status_code"] != 200:
|
||||||
|
raise ValueError(f"Bad response code, expected 200: {response_dict}")
|
||||||
|
|
||||||
|
chunk = parsed_response.get("chunk")
|
||||||
|
if not chunk:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return chunk.get("bytes").decode() # type: ignore[no-any-return]
|
|
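End of the new bedrock_httpx.py module. A usage sketch for the async path it enables; the model id, region, and credentials are illustrative, and boto3 must be installed:

    import asyncio

    import litellm


    async def main():
        response = await litellm.acompletion(
            model="bedrock/cohere.command-r-plus-v1:0",
            messages=[{"role": "user", "content": "Hey! how's it going?"}],
            aws_region_name="us-west-2",
            max_tokens=100,
        )
        print(response.choices[0].message.content)


    asyncio.run(main())
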
@@ -58,8 +58,15 @@ class AsyncHTTPHandler:


 class HTTPHandler:
     def __init__(
-        self, timeout: httpx.Timeout = _DEFAULT_TIMEOUT, concurrent_limit=1000
+        self,
+        timeout: Optional[httpx.Timeout] = None,
+        concurrent_limit=1000,
+        client: Optional[httpx.Client] = None,
     ):
+        if timeout is None:
+            timeout = _DEFAULT_TIMEOUT
+
+        if client is None:
             # Create a client with a connection pool
             self.client = httpx.Client(
                 timeout=timeout,
@@ -68,6 +75,8 @@ class HTTPHandler:
                     max_keepalive_connections=concurrent_limit,
                 ),
             )
+        else:
+            self.client = client

     def close(self):
         # Close the client when you're done with it
@@ -82,11 +91,15 @@ class HTTPHandler:
     def post(
         self,
         url: str,
-        data: Optional[dict] = None,
+        data: Optional[Union[dict, str]] = None,
         params: Optional[dict] = None,
         headers: Optional[dict] = None,
+        stream: bool = False,
     ):
-        response = self.client.post(url, data=data, params=params, headers=headers)
+        req = self.client.build_request(
+            "POST", url, data=data, params=params, headers=headers  # type: ignore
+        )
+        response = self.client.send(req, stream=stream)
         return response

     def __del__(self) -> None:
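For orientation, a hedged usage sketch of the updated handler, exercising the injected client and the new stream flag; the endpoint and payload below are placeholders, not part of this diff:

import httpx

handler = HTTPHandler(client=httpx.Client(timeout=httpx.Timeout(30.0)))
response = handler.post(
    "https://example.invalid/v1/chat",       # placeholder endpoint
    data='{"message": "hello"}',             # str payloads are now accepted alongside dicts
    headers={"Content-Type": "application/json"},
    stream=True,                             # body can then be consumed incrementally
)
for raw_bytes in response.iter_bytes():
    pass
response.close()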
@@ -168,7 +168,7 @@ class PredibaseChatCompletion(BaseLLM):
         logging_obj: litellm.utils.Logging,
         optional_params: dict,
         api_key: str,
-        data: dict,
+        data: Union[dict, str],
         messages: list,
         print_verbose,
         encoding,
@@ -185,9 +185,7 @@ class PredibaseChatCompletion(BaseLLM):
         try:
             completion_response = response.json()
         except:
-            raise PredibaseError(
-                message=response.text, status_code=response.status_code
-            )
+            raise PredibaseError(message=response.text, status_code=422)
         if "error" in completion_response:
             raise PredibaseError(
                 message=str(completion_response["error"]),
@@ -363,7 +361,7 @@ class PredibaseChatCompletion(BaseLLM):
             },
         )
         ## COMPLETION CALL
-        if acompletion is True:
+        if acompletion == True:
             ### ASYNC STREAMING
             if stream == True:
                 return self.async_streaming(
@@ -76,6 +76,7 @@ from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.predibase import PredibaseChatCompletion
+from .llms.bedrock_httpx import BedrockLLM
 from .llms.triton import TritonChatCompletion
 from .llms.prompt_templates.factory import (
     prompt_factory,
@@ -105,7 +106,6 @@ from litellm.utils import (
 )

 ####### ENVIRONMENT VARIABLES ###################
-dotenv.load_dotenv()  # Loading env variables using dotenv
 openai_chat_completions = OpenAIChatCompletion()
 openai_text_completions = OpenAITextCompletion()
 anthropic_chat_completions = AnthropicChatCompletion()
@@ -115,6 +115,7 @@ azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
 predibase_chat_completions = PredibaseChatCompletion()
 triton_chat_completions = TritonChatCompletion()
+bedrock_chat_completion = BedrockLLM()
 ####### COMPLETION ENDPOINTS ################
@@ -257,7 +258,7 @@ async def acompletion(
         - If `stream` is True, the function returns an async generator that yields completion lines.
     """
     loop = asyncio.get_event_loop()
-    custom_llm_provider = None
+    custom_llm_provider = kwargs.get("custom_llm_provider", None)
     # Adjusted to use explicit arguments instead of *args and **kwargs
     completion_kwargs = {
         "model": model,
@@ -289,6 +290,7 @@ async def acompletion(
         "model_list": model_list,
         "acompletion": True,  # assuming this is a required parameter
     }
+    if custom_llm_provider is None:
         _, custom_llm_provider, _, _ = get_llm_provider(
             model=model, api_base=completion_kwargs.get("base_url", None)
         )
@@ -300,9 +302,6 @@ async def acompletion(
         ctx = contextvars.copy_context()
         func_with_context = partial(ctx.run, func)

-        _, custom_llm_provider, _, _ = get_llm_provider(
-            model=model, api_base=kwargs.get("api_base", None)
-        )
         if (
             custom_llm_provider == "openai"
             or custom_llm_provider == "azure"
@@ -324,6 +323,7 @@ async def acompletion(
             or custom_llm_provider == "sagemaker"
             or custom_llm_provider == "anthropic"
             or custom_llm_provider == "predibase"
+            or (custom_llm_provider == "bedrock" and "cohere" in model)
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
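A hedged example of the adjusted acompletion behaviour: a custom_llm_provider passed explicitly is used as-is, otherwise it is inferred once via get_llm_provider (the model id is the one used in this PR's own tests):

import asyncio
import litellm


async def main():
    response = await litellm.acompletion(
        model="bedrock/cohere.command-r-plus-v1:0",
        custom_llm_provider="bedrock",  # optional; inferred from the model name when omitted
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
    )
    print(response)


asyncio.run(main())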
@@ -1976,6 +1976,24 @@ def completion(
     elif custom_llm_provider == "bedrock":
         # boto3 reads keys from .env
         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict

+        if "cohere" in model:
+            response = bedrock_chat_completion.completion(
+                model=model,
+                messages=messages,
+                custom_prompt_dict=litellm.custom_prompt_dict,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                encoding=encoding,
+                logging_obj=logging,
+                extra_headers=extra_headers,
+                timeout=timeout,
+                acompletion=acompletion,
+            )
+        else:
             response = bedrock.completion(
                 model=model,
                 messages=messages,
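With this routing in place, a Bedrock Command R call goes through the new httpx-based BedrockLLM client while other Bedrock models keep the existing boto3 path; a minimal sketch (model id taken from the PR's tests, credentials still read from the environment by the Bedrock client):

import litellm

response = litellm.completion(
    model="bedrock/cohere.command-r-plus-v1:0",
    messages=[{"role": "user", "content": "Hey! how's it going?"}],
)
print(response.choices[0].message.content)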
@@ -1,10 +1,7 @@
 from litellm.proxy._types import UserAPIKeyAuth, GenerateKeyRequest
 from fastapi import Request
-from dotenv import load_dotenv
 import os

-load_dotenv()


 async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
     try:
@@ -1507,7 +1507,6 @@ class Router:
             return response
         except Exception as e:
             original_exception = e
             """
             Retry Logic
@@ -8,8 +8,6 @@
 import dotenv, os, requests, random  # type: ignore
 from typing import Optional

-dotenv.load_dotenv()  # Loading env variables using dotenv
-
 import traceback
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
@@ -1,12 +1,11 @@
 #### What this does ####
 # picks based on response time (for streaming, this is time to first token)
 from pydantic import BaseModel, Extra, Field, root_validator
-import dotenv, os, requests, random  # type: ignore
+import os, requests, random  # type: ignore
 from typing import Optional, Union, List, Dict
 from datetime import datetime, timedelta
 import random

-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
@@ -5,8 +5,6 @@ import dotenv, os, requests, random  # type: ignore
 from typing import Optional, Union, List, Dict
 from datetime import datetime, timedelta
 import random

-dotenv.load_dotenv()  # Loading env variables using dotenv
-
 import traceback
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
@@ -4,8 +4,6 @@
 import dotenv, os, requests, random
 from typing import Optional, Union, List, Dict
 from datetime import datetime

-dotenv.load_dotenv()  # Loading env variables using dotenv
-
 import traceback
 from litellm import token_counter
 from litellm.caching import DualCache
@@ -5,8 +5,6 @@ import dotenv, os, requests, random
 from typing import Optional, Union, List, Dict
 import datetime as datetime_og
 from datetime import datetime

-dotenv.load_dotenv()  # Loading env variables using dotenv
-
 import traceback, asyncio, httpx
 import litellm
 from litellm import token_counter
@@ -2584,6 +2584,69 @@ def test_completion_chat_sagemaker_mistral():
 # test_completion_chat_sagemaker_mistral()


+def response_format_tests(response: litellm.ModelResponse):
+    assert isinstance(response.id, str)
+    assert response.id != ""
+
+    assert isinstance(response.object, str)
+    assert response.object != ""
+
+    assert isinstance(response.created, int)
+
+    assert isinstance(response.model, str)
+    assert response.model != ""
+
+    assert isinstance(response.choices, list)
+    assert len(response.choices) == 1
+    choice = response.choices[0]
+    assert isinstance(choice, litellm.Choices)
+    assert isinstance(choice.get("index"), int)
+
+    message = choice.get("message")
+    assert isinstance(message, litellm.Message)
+    assert isinstance(message.get("role"), str)
+    assert message.get("role") != ""
+    assert isinstance(message.get("content"), str)
+    assert message.get("content") != ""
+
+    assert choice.get("logprobs") is None
+    assert isinstance(choice.get("finish_reason"), str)
+    assert choice.get("finish_reason") != ""
+
+    assert isinstance(response.usage, litellm.Usage)  # type: ignore
+    assert isinstance(response.usage.prompt_tokens, int)  # type: ignore
+    assert isinstance(response.usage.completion_tokens, int)  # type: ignore
+    assert isinstance(response.usage.total_tokens, int)  # type: ignore
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_completion_bedrock_command_r(sync_mode):
+    litellm.set_verbose = True
+
+    if sync_mode:
+        response = completion(
+            model="bedrock/cohere.command-r-plus-v1:0",
+            messages=[{"role": "user", "content": "Hey! how's it going?"}],
+        )
+
+        assert isinstance(response, litellm.ModelResponse)
+
+        response_format_tests(response=response)
+    else:
+        response = await litellm.acompletion(
+            model="bedrock/cohere.command-r-plus-v1:0",
+            messages=[{"role": "user", "content": "Hey! how's it going?"}],
+        )
+
+        assert isinstance(response, litellm.ModelResponse)
+
+        print(f"response: {response}")
+        response_format_tests(response=response)
+
+    print(f"response: {response}")
+
+
 def test_completion_bedrock_titan_null_response():
     try:
         response = completion(
@@ -132,12 +132,15 @@ def test_post_call_rule_streaming():
     )


-def test_post_call_processing_error_async_response():
-    response = asyncio.run(
-        acompletion(
+@pytest.mark.asyncio
+async def test_post_call_processing_error_async_response():
+    try:
+        response = await acompletion(
             model="command-nightly",  # Just used as an example
             messages=[{"content": "Hello, how are you?", "role": "user"}],
             api_base="https://openai-proxy.berriai.repl.co",  # Just used as an example
             custom_llm_provider="openai",
         )
-    )
+        pytest.fail("This call should have failed")
+    except Exception as e:
+        pass
@@ -983,6 +983,64 @@ def test_vertex_ai_stream():
 #         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.parametrize("sync_mode", [True])
+@pytest.mark.asyncio
+async def test_bedrock_cohere_command_r_streaming(sync_mode):
+    try:
+        litellm.set_verbose = True
+        if sync_mode:
+            final_chunk: Optional[litellm.ModelResponse] = None
+            response: litellm.CustomStreamWrapper = completion(  # type: ignore
+                model="bedrock/cohere.command-r-plus-v1:0",
+                messages=messages,
+                max_tokens=10,  # type: ignore
+                stream=True,
+            )
+            complete_response = ""
+            # Add any assertions here to check the response
+            has_finish_reason = False
+            for idx, chunk in enumerate(response):
+                final_chunk = chunk
+                chunk, finished = streaming_format_tests(idx, chunk)
+                if finished:
+                    has_finish_reason = True
+                    break
+                complete_response += chunk
+            if has_finish_reason == False:
+                raise Exception("finish reason not set")
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+        else:
+            response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
+                model="bedrock/cohere.command-r-plus-v1:0",
+                messages=messages,
+                max_tokens=100,  # type: ignore
+                stream=True,
+            )
+            complete_response = ""
+            # Add any assertions here to check the response
+            has_finish_reason = False
+            idx = 0
+            final_chunk: Optional[litellm.ModelResponse] = None
+            async for chunk in response:
+                final_chunk = chunk
+                chunk, finished = streaming_format_tests(idx, chunk)
+                if finished:
+                    has_finish_reason = True
+                    break
+                complete_response += chunk
+                idx += 1
+            if has_finish_reason == False:
+                raise Exception("finish reason not set")
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+        print(f"completion_response: {complete_response}\n\nFinalChunk: {final_chunk}")
+    except RateLimitError:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_bedrock_claude_3_streaming():
     try:
         litellm.set_verbose = True
63  litellm/types/llms/bedrock.py  Normal file
@@ -0,0 +1,63 @@
from typing import TypedDict, Any, Union, Optional
import json
from typing_extensions import (
    Self,
    Protocol,
    TypeGuard,
    override,
    get_origin,
    runtime_checkable,
    Required,
)


class GenericStreamingChunk(TypedDict):
    text: Required[str]
    is_finished: Required[bool]
    finish_reason: Required[str]


class Document(TypedDict):
    title: str
    snippet: str


class ServerSentEvent:
    def __init__(
        self,
        *,
        event: Optional[str] = None,
        data: Optional[str] = None,
        id: Optional[str] = None,
        retry: Optional[int] = None,
    ) -> None:
        if data is None:
            data = ""

        self._id = id
        self._data = data
        self._event = event or None
        self._retry = retry

    @property
    def event(self) -> Optional[str]:
        return self._event

    @property
    def id(self) -> Optional[str]:
        return self._id

    @property
    def retry(self) -> Optional[int]:
        return self._retry

    @property
    def data(self) -> str:
        return self._data

    def json(self) -> Any:
        return json.loads(self.data)

    @override
    def __repr__(self) -> str:
        return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
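Since GenericStreamingChunk is a plain TypedDict and ServerSentEvent a small value object, both can be exercised directly; a hedged sketch with illustrative values only:

from litellm.types.llms.bedrock import GenericStreamingChunk, ServerSentEvent

chunk: GenericStreamingChunk = {
    "text": "Hello",
    "is_finished": True,
    "finish_reason": "stop",
}
assert chunk["is_finished"] is True

sse = ServerSentEvent(event="completion", data='{"text": "Hello"}')
print(sse.json()["text"])  # parses the data payload as JSON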
@@ -132,7 +132,6 @@ MAX_THREADS = 100

 # Create a ThreadPoolExecutor
 executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
-dotenv.load_dotenv()  # Loading env variables using dotenv
 sentry_sdk_instance = None
 capture_exception = None
 add_breadcrumb = None
@@ -10474,6 +10473,12 @@ class CustomStreamWrapper:
             raise e

     def handle_bedrock_stream(self, chunk):
+        if "cohere" in self.model:
+            return {
+                "text": chunk["text"],
+                "is_finished": chunk["is_finished"],
+                "finish_reason": chunk["finish_reason"],
+            }
         if hasattr(chunk, "get"):
             chunk = chunk.get("chunk")
             chunk_data = json.loads(chunk.get("bytes").decode())
@@ -11322,6 +11327,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "gemini"
                 or self.custom_llm_provider == "cached_response"
                 or self.custom_llm_provider == "predibase"
+                or (self.custom_llm_provider == "bedrock" and "cohere" in self.model)
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
                 async for chunk in self.completion_stream:
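Putting the streaming pieces together, a hedged end-to-end sketch mirroring the PR's streaming test (sync variant shown; the token limit is illustrative):

import litellm

stream = litellm.completion(
    model="bedrock/cohere.command-r-plus-v1:0",
    messages=[{"role": "user", "content": "Hey! how's it going?"}],
    stream=True,
    max_tokens=10,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")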