fix(bedrock_httpx.py): working async bedrock command r calls

This commit is contained in:
  parent 59c8c0adff
  commit 49ab1a1d3f

6 changed files with 374 additions and 78 deletions
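The change wires Bedrock Command R calls through the new httpx-based handler for both the sync and async paths. A minimal usage sketch, mirroring the test added in this commit (model id and message are taken from that test; AWS credentials are assumed to be configured in the environment):

    import asyncio
    import litellm

    # synchronous call
    sync_response = litellm.completion(
        model="bedrock/cohere.command-r-plus-v1:0",
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
    )

    # asynchronous call - the path this commit fixes
    async_response = asyncio.run(
        litellm.acompletion(
            model="bedrock/cohere.command-r-plus-v1:0",
            messages=[{"role": "user", "content": "Hey! how's it going?"}],
        )
    )

Both calls are expected to return a litellm.ModelResponse, with usage populated from the Bedrock response headers as shown in the bedrock_httpx.py hunks below.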
@@ -151,19 +151,120 @@ class AnthropicChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()

+    def process_streaming_response(
+        self,
+        model: str,
+        response: Union[requests.Response, httpx.Response],
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: litellm.utils.Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: List,
+        print_verbose,
+        encoding,
+    ) -> CustomStreamWrapper:
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
+        ## RESPONSE OBJECT
+        try:
+            completion_response = response.json()
+        except:
+            raise AnthropicError(
+                message=response.text, status_code=response.status_code
+            )
+        text_content = ""
+        tool_calls = []
+        for content in completion_response["content"]:
+            if content["type"] == "text":
+                text_content += content["text"]
+            ## TOOL CALLING
+            elif content["type"] == "tool_use":
+                tool_calls.append(
+                    {
+                        "id": content["id"],
+                        "type": "function",
+                        "function": {
+                            "name": content["name"],
+                            "arguments": json.dumps(content["input"]),
+                        },
+                    }
+                )
+        if "error" in completion_response:
+            raise AnthropicError(
+                message=str(completion_response["error"]),
+                status_code=response.status_code,
+            )
+
+        print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
+        # return an iterator
+        streaming_model_response = ModelResponse(stream=True)
+        streaming_model_response.choices[0].finish_reason = model_response.choices[  # type: ignore
+            0
+        ].finish_reason
+        # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+        streaming_choice = litellm.utils.StreamingChoices()
+        streaming_choice.index = model_response.choices[0].index
+        _tool_calls = []
+        print_verbose(
+            f"type of model_response.choices[0]: {type(model_response.choices[0])}"
+        )
+        print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
+        if isinstance(model_response.choices[0], litellm.Choices):
+            if getattr(
+                model_response.choices[0].message, "tool_calls", None
+            ) is not None and isinstance(
+                model_response.choices[0].message.tool_calls, list
+            ):
+                for tool_call in model_response.choices[0].message.tool_calls:
+                    _tool_call = {**tool_call.dict(), "index": 0}
+                    _tool_calls.append(_tool_call)
+            delta_obj = litellm.utils.Delta(
+                content=getattr(model_response.choices[0].message, "content", None),
+                role=model_response.choices[0].message.role,
+                tool_calls=_tool_calls,
+            )
+            streaming_choice.delta = delta_obj
+            streaming_model_response.choices = [streaming_choice]
+            completion_stream = ModelResponseIterator(
+                model_response=streaming_model_response
+            )
+            print_verbose(
+                "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
+            )
+            return CustomStreamWrapper(
+                completion_stream=completion_stream,
+                model=model,
+                custom_llm_provider="cached_response",
+                logging_obj=logging_obj,
+            )
+        else:
+            raise AnthropicError(
+                status_code=422,
+                message="Unprocessable response object - {}".format(response.text),
+            )
+
     def process_response(
         self,
-        model,
-        response,
-        model_response,
-        _is_function_call,
-        stream,
-        logging_obj,
-        api_key,
-        data,
-        messages,
+        model: str,
+        response: Union[requests.Response, httpx.Response],
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: litellm.utils.Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: List,
         print_verbose,
-    ):
+        encoding,
+    ) -> ModelResponse:
         ## LOGGING
         logging_obj.post_call(
             input=messages,
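process_streaming_response wraps an already-complete tool-call response in a fake stream so callers that asked for stream=True still receive an iterator. ModelResponseIterator is used but not defined in this diff; a minimal sketch of what such a single-chunk iterator could look like (names and behaviour assumed, not taken from the litellm source):

    # Hypothetical sketch - yields one pre-built ModelResponse, then stops.
    # The real litellm ModelResponseIterator may differ.
    class ModelResponseIterator:
        def __init__(self, model_response):
            self.model_response = model_response
            self.is_done = False

        def __iter__(self):
            return self

        def __next__(self):
            if self.is_done:
                raise StopIteration
            self.is_done = True
            return self.model_response

        def __aiter__(self):
            return self

        async def __anext__(self):
            if self.is_done:
                raise StopAsyncIteration
            self.is_done = True
            return self.model_response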
@@ -216,51 +317,6 @@ class AnthropicChatCompletion(BaseLLM):
             completion_response["stop_reason"]
         )
-
-        print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
-        if _is_function_call and stream:
-            print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
-            # return an iterator
-            streaming_model_response = ModelResponse(stream=True)
-            streaming_model_response.choices[0].finish_reason = model_response.choices[
-                0
-            ].finish_reason
-            # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
-            streaming_choice = litellm.utils.StreamingChoices()
-            streaming_choice.index = model_response.choices[0].index
-            _tool_calls = []
-            print_verbose(
-                f"type of model_response.choices[0]: {type(model_response.choices[0])}"
-            )
-            print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
-            if isinstance(model_response.choices[0], litellm.Choices):
-                if getattr(
-                    model_response.choices[0].message, "tool_calls", None
-                ) is not None and isinstance(
-                    model_response.choices[0].message.tool_calls, list
-                ):
-                    for tool_call in model_response.choices[0].message.tool_calls:
-                        _tool_call = {**tool_call.dict(), "index": 0}
-                        _tool_calls.append(_tool_call)
-                delta_obj = litellm.utils.Delta(
-                    content=getattr(model_response.choices[0].message, "content", None),
-                    role=model_response.choices[0].message.role,
-                    tool_calls=_tool_calls,
-                )
-                streaming_choice.delta = delta_obj
-                streaming_model_response.choices = [streaming_choice]
-                completion_stream = ModelResponseIterator(
-                    model_response=streaming_model_response
-                )
-                print_verbose(
-                    "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
-                )
-                return CustomStreamWrapper(
-                    completion_stream=completion_stream,
-                    model=model,
-                    custom_llm_provider="cached_response",
-                    logging_obj=logging_obj,
-                )

         ## CALCULATING USAGE
         prompt_tokens = completion_response["usage"]["input_tokens"]
         completion_tokens = completion_response["usage"]["output_tokens"]
@@ -273,7 +329,7 @@ class AnthropicChatCompletion(BaseLLM):
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
         )
-        model_response.usage = usage
+        setattr(model_response, "usage", usage)  # type: ignore
         return model_response

     async def acompletion_stream_function(
@@ -289,7 +345,7 @@ class AnthropicChatCompletion(BaseLLM):
         logging_obj,
         stream,
         _is_function_call,
-        data=None,
+        data: dict,
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
@@ -331,12 +387,12 @@ class AnthropicChatCompletion(BaseLLM):
         logging_obj,
         stream,
         _is_function_call,
-        data=None,
-        optional_params=None,
+        data: dict,
+        optional_params: dict,
         litellm_params=None,
         logger_fn=None,
         headers={},
-    ):
+    ) -> ModelResponse:
         self.async_handler = AsyncHTTPHandler(
             timeout=httpx.Timeout(timeout=600.0, connect=5.0)
         )
@@ -347,13 +403,14 @@ class AnthropicChatCompletion(BaseLLM):
             model=model,
             response=response,
             model_response=model_response,
-            _is_function_call=_is_function_call,
             stream=stream,
             logging_obj=logging_obj,
             api_key=api_key,
             data=data,
             messages=messages,
             print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
         )

     def completion(
@@ -367,7 +424,7 @@ class AnthropicChatCompletion(BaseLLM):
         encoding,
         api_key,
         logging_obj,
-        optional_params=None,
+        optional_params: dict,
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
@@ -526,17 +583,33 @@ class AnthropicChatCompletion(BaseLLM):
                 raise AnthropicError(
                     status_code=response.status_code, message=response.text
                 )

+        if stream and _is_function_call:
+            return self.process_streaming_response(
+                model=model,
+                response=response,
+                model_response=model_response,
+                stream=stream,
+                logging_obj=logging_obj,
+                api_key=api_key,
+                data=data,
+                messages=messages,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                encoding=encoding,
+            )
         return self.process_response(
             model=model,
             response=response,
             model_response=model_response,
-            _is_function_call=_is_function_call,
             stream=stream,
             logging_obj=logging_obj,
             api_key=api_key,
             data=data,
             messages=messages,
             print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
         )

     def embedding(self):
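In the completion flow above, an Anthropic tool-calling request made with stream=True is now answered by faking a stream over the fully parsed response (the `if stream and _is_function_call:` branch). A usage sketch of a call that would be expected to take this branch; the model id and tool schema are illustrative, not from this commit:

    import litellm

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ]

    # stream=True + tools => AnthropicChatCompletion.process_streaming_response
    stream = litellm.completion(
        model="claude-3-sonnet-20240229",
        messages=[{"role": "user", "content": "What's the weather in Paris?"}],
        tools=tools,
        stream=True,
    )
    for chunk in stream:
        print(chunk.choices[0].delta)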
@@ -1,12 +1,32 @@
 ## This is a template base class to be used for adding new LLM providers via API calls
 import litellm
-import httpx
-from typing import Optional
+import httpx, requests
+from typing import Optional, Union
+from litellm.utils import Logging


 class BaseLLM:
     _client_session: Optional[httpx.Client] = None

+    def process_response(
+        self,
+        model: str,
+        response: Union[requests.Response, httpx.Response],
+        model_response: litellm.utils.ModelResponse,
+        stream: bool,
+        logging_obj: Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: list,
+        print_verbose,
+        encoding,
+    ) -> litellm.utils.ModelResponse:
+        """
+        Helper function to process the response across sync + async completion calls
+        """
+        return model_response
+
     def create_client_session(self):
         if litellm.client_session:
             _client_session = litellm.client_session
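The base class now defines a common process_response hook that the provider classes in this commit (AnthropicChatCompletion, BedrockLLM, PredibaseChatCompletion) override, so the sync and async completion paths share one parsing routine. A minimal sketch of the pattern for a hypothetical new provider (class name, error handling, and the "output" field are made up for illustration):

    # Assumes BaseLLM from the template above is importable in the provider module.
    class MyProviderLLM(BaseLLM):
        def process_response(
            self,
            model,
            response,
            model_response,
            stream,
            logging_obj,
            optional_params,
            api_key,
            data,
            messages,
            print_verbose,
            encoding,
        ):
            # Parse the provider-specific JSON once, for both sync and async callers.
            completion_response = response.json()
            model_response.choices[0].message.content = completion_response["output"]  # type: ignore
            model_response["model"] = model
            return model_response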
@@ -16,6 +16,7 @@ from litellm.utils import (
     Message,
     Choices,
     get_secret,
+    Logging,
 )
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
@@ -255,6 +256,70 @@ class BedrockLLM(BaseLLM):

         return session.get_credentials()

+    def process_response(
+        self,
+        model: str,
+        response: Union[requests.Response, httpx.Response],
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: List,
+        print_verbose,
+        encoding,
+    ) -> ModelResponse:
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
+
+        ## RESPONSE OBJECT
+        try:
+            completion_response = response.json()
+        except:
+            raise BedrockError(message=response.text, status_code=422)
+
+        try:
+            model_response.choices[0].message.content = completion_response["text"]  # type: ignore
+        except Exception as e:
+            raise BedrockError(message=response.text, status_code=422)
+
+        ## CALCULATING USAGE - bedrock returns usage in the headers
+        prompt_tokens = int(
+            response.headers.get(
+                "x-amzn-bedrock-input-token-count",
+                len(encoding.encode("".join(m.get("content", "") for m in messages))),
+            )
+        )
+        completion_tokens = int(
+            response.headers.get(
+                "x-amzn-bedrock-output-token-count",
+                len(
+                    encoding.encode(
+                        model_response.choices[0].message.content,  # type: ignore
+                        disallowed_special=(),
+                    )
+                ),
+            )
+        )
+
+        model_response["created"] = int(time.time())
+        model_response["model"] = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        setattr(model_response, "usage", usage)
+
+        return model_response
+
     def completion(
         self,
         model: str,
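Bedrock reports token usage through response headers rather than the JSON body, so the parser reads `x-amzn-bedrock-input-token-count` / `x-amzn-bedrock-output-token-count` and only falls back to tokenizing locally when a header is missing. A small illustration of that fallback (the header name comes from the diff above; the plain dict stands in for real httpx response headers):

    def token_count_from_headers(headers: dict, fallback_text: str, encoding) -> int:
        # Prefer the count Bedrock already computed; otherwise tokenize locally.
        value = headers.get("x-amzn-bedrock-input-token-count")
        if value is not None:
            return int(value)
        return len(encoding.encode(fallback_text))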
@@ -268,8 +333,9 @@ class BedrockLLM(BaseLLM):
         timeout: Optional[Union[float, httpx.Timeout]],
         litellm_params=None,
         logger_fn=None,
+        acompletion: bool = False,
         extra_headers: Optional[dict] = None,
-        client: Optional[HTTPHandler] = None,
+        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         try:
             import boto3
@@ -381,13 +447,39 @@ class BedrockLLM(BaseLLM):

         ## COMPLETION CALL
         headers = {"Content-Type": "application/json"}
+        if extra_headers is not None:
+            headers = {"Content-Type": "application/json", **extra_headers}
         request = AWSRequest(
             method="POST", url=endpoint_url, data=data, headers=headers
         )
         sigv4.add_auth(request)
         prepped = request.prepare()

-        if client is None:
+        ### ROUTING (ASYNC, STREAMING, SYNC)
+        if acompletion:
+            if isinstance(client, HTTPHandler):
+                client = None
+
+            ### ASYNC COMPLETION
+            return self.async_completion(
+                model=model,
+                messages=messages,
+                data=data,
+                api_base=prepped.url,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                encoding=encoding,
+                logging_obj=logging_obj,
+                optional_params=optional_params,
+                stream=False,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                headers=prepped.headers,
+                timeout=timeout,
+                client=client,
+            )  # type: ignore
+
+        if client is None or isinstance(client, AsyncHTTPHandler):
             _params = {}
             if timeout is not None:
                 if isinstance(timeout, float) or isinstance(timeout, int):
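The signed AWS request is built once, then the call is routed: async callers are handed off to async_completion (any sync HTTPHandler passed in is discarded), while sync callers reuse a compatible client or build a fresh one. A sketch of that routing decision in isolation, under the assumption that AsyncHTTPHandler / HTTPHandler are the handler classes imported in this module (this is not litellm API, just an illustration of the branch above):

    def pick_route(acompletion: bool, client) -> str:
        if acompletion:
            # A sync HTTPHandler cannot be awaited, so it is dropped and a fresh
            # AsyncHTTPHandler is created inside async_completion when client is None.
            return "async"
        if client is None or isinstance(client, AsyncHTTPHandler):
            return "sync-new-client"
        return "sync-reuse-client"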
@@ -416,7 +508,62 @@ class BedrockLLM(BaseLLM):
             error_code = err.response.status_code
             raise BedrockError(status_code=error_code, message=response.text)

-        return response
+        return self.process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=stream,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            encoding=encoding,
+        )
+
+    async def async_completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        data: str,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        encoding,
+        logging_obj,
+        stream,
+        optional_params: dict,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> ModelResponse:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            self.client = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            self.client = client  # type: ignore
+
+        response = await self.client.post(api_base, headers=headers, data=data)  # type: ignore
+        return self.process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=stream,
+            logging_obj=logging_obj,
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
+        )

     def embedding(self, *args, **kwargs):
         return super().embedding(*args, **kwargs)
@@ -168,7 +168,7 @@ class PredibaseChatCompletion(BaseLLM):
         logging_obj: litellm.utils.Logging,
         optional_params: dict,
         api_key: str,
-        data: dict,
+        data: Union[dict, str],
         messages: list,
         print_verbose,
         encoding,
@@ -185,9 +185,7 @@ class PredibaseChatCompletion(BaseLLM):
         try:
             completion_response = response.json()
         except:
-            raise PredibaseError(
-                message=response.text, status_code=response.status_code
-            )
+            raise PredibaseError(message=response.text, status_code=422)
         if "error" in completion_response:
             raise PredibaseError(
                 message=str(completion_response["error"]),
@@ -363,7 +361,7 @@ class PredibaseChatCompletion(BaseLLM):
             },
         )
         ## COMPLETION CALL
-        if acompletion is True:
+        if acompletion == True:
             ### ASYNC STREAMING
             if stream == True:
                 return self.async_streaming(
@@ -1479,6 +1479,11 @@ class Router:
             return response
         except Exception as e:
             original_exception = e
+            """
+            - Check if available deployments - `get_healthy_deployments() -> List`
+            - if no, Check if available fallbacks - `is_fallback(model_group: str, exception) -> bool`
+            - if no, back-off and retry up till num_retries - `_router_should_retry -> float`
+            """
             ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available / Bad Request Error
             if (
                 isinstance(original_exception, litellm.ContextWindowExceededError)
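The new docstring sketches the intended error-handling order inside the Router's retry wrapper. A rough outline of that flow, using the helper names from the docstring; their exact signatures are not part of this diff, so treat this as a sketch rather than the Router implementation:

    import asyncio

    async def handle_failure_sketch(router, model_group, exception, num_retries):
        # 1. any healthy deployments left for this model group?
        healthy = router.get_healthy_deployments()  # assumed helper from the docstring
        if not healthy:
            # 2. is a fallback model group configured for this kind of exception?
            if router.is_fallback(model_group, exception):  # assumed helper
                return "use_fallback"
        # 3. otherwise back off and retry up to num_retries
        for attempt in range(num_retries):
            sleep_s = router._router_should_retry()  # assumed: returns back-off seconds
            await asyncio.sleep(sleep_s)
            # ...re-issue the request here...
        raise exception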
@@ -2584,12 +2584,65 @@ def test_completion_chat_sagemaker_mistral():
 # test_completion_chat_sagemaker_mistral()


-def test_completion_bedrock_command_r():
+def response_format_tests(response: litellm.ModelResponse):
+    assert isinstance(response.id, str)
+    assert response.id != ""
+
+    assert isinstance(response.object, str)
+    assert response.object != ""
+
+    assert isinstance(response.created, int)
+
+    assert isinstance(response.model, str)
+    assert response.model != ""
+
+    assert isinstance(response.choices, list)
+    assert len(response.choices) == 1
+    choice = response.choices[0]
+    assert isinstance(choice, litellm.Choices)
+    assert isinstance(choice.get("index"), int)
+
+    message = choice.get("message")
+    assert isinstance(message, litellm.Message)
+    assert isinstance(message.get("role"), str)
+    assert message.get("role") != ""
+    assert isinstance(message.get("content"), str)
+    assert message.get("content") != ""
+
+    assert choice.get("logprobs") is None
+    assert isinstance(choice.get("finish_reason"), str)
+    assert choice.get("finish_reason") != ""
+
+    assert isinstance(response.usage, litellm.Usage)  # type: ignore
+    assert isinstance(response.usage.prompt_tokens, int)  # type: ignore
+    assert isinstance(response.usage.completion_tokens, int)  # type: ignore
+    assert isinstance(response.usage.total_tokens, int)  # type: ignore
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_completion_bedrock_command_r(sync_mode):
     litellm.set_verbose = True
-    response = completion(
-        model="bedrock/cohere.command-r-plus-v1:0",
-        messages=[{"role": "user", "content": "Hey! how's it going?"}],
-    )
+
+    if sync_mode:
+        response = completion(
+            model="bedrock/cohere.command-r-plus-v1:0",
+            messages=[{"role": "user", "content": "Hey! how's it going?"}],
+        )
+
+        assert isinstance(response, litellm.ModelResponse)
+
+        response_format_tests(response=response)
+    else:
+        response = await litellm.acompletion(
+            model="bedrock/cohere.command-r-plus-v1:0",
+            messages=[{"role": "user", "content": "Hey! how's it going?"}],
+        )
+
+        assert isinstance(response, litellm.ModelResponse)
+
+        print(f"response: {response}")
+        response_format_tests(response=response)
+
     print(f"response: {response}")
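The updated test exercises both the sync and async paths via the sync_mode parametrization. Assuming the usual litellm test layout and AWS credentials in the environment, it could be run with something like the following (the test file path is an assumption, not stated in this diff):

    pytest litellm/tests/test_completion.py -k test_completion_bedrock_command_r -s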