forked from phoenix/litellm-mirror
fix(bedrock_httpx.py): move anthropic bedrock calls to httpx
Fixing https://github.com/BerriAI/litellm/issues/2921
This commit is contained in:
parent 10a672634d
commit 180bc46ca4
7 changed files with 298 additions and 56 deletions
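With this change, Bedrock Anthropic (Claude) models are routed through the httpx-based BedrockLLM client, so they can be called asynchronously alongside the existing Cohere support. A minimal usage sketch based on the tests in this commit (the model name and AWS credentials are illustrative and must exist in your environment):

    import litellm

    # sync call through the new httpx-based bedrock client
    response = litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
    )
    print(response.choices[0].message.content)

    # async calls for bedrock + anthropic now also route through the async path
    # response = await litellm.acompletion(model=..., messages=...)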
@@ -21,7 +21,7 @@ class BaseLLM:
         messages: list,
         print_verbose,
         encoding,
-    ) -> litellm.utils.ModelResponse:
+    ) -> Union[litellm.utils.ModelResponse, litellm.utils.CustomStreamWrapper]:
         """
         Helper function to process the response across sync + async completion calls
         """
@@ -1,6 +1,6 @@
 # What is this?
 ## Initial implementation of calling bedrock via httpx client (allows for async calls).
-## V0 - just covers cohere command-r support
+## V1 - covers cohere + anthropic claude-3 support

 import os, types
 import json
@@ -29,12 +29,20 @@ from litellm.utils import (
     get_secret,
     Logging,
 )
-import litellm
-from .prompt_templates.factory import prompt_factory, custom_prompt, cohere_message_pt
+import litellm, uuid
+from .prompt_templates.factory import (
+    prompt_factory,
+    custom_prompt,
+    cohere_message_pt,
+    construct_tool_use_system_prompt,
+    extract_between_tags,
+    parse_xml_params,
+    contains_tag,
+)
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from .base import BaseLLM
 import httpx  # type: ignore
-from .bedrock import BedrockError, convert_messages_to_prompt
+from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
 from litellm.types.llms.bedrock import *

@@ -280,7 +288,8 @@ class BedrockLLM(BaseLLM):
         messages: List,
         print_verbose,
         encoding,
-    ) -> ModelResponse:
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
+        provider = model.split(".")[0]
         ## LOGGING
         logging_obj.post_call(
             input=messages,
@@ -297,26 +306,147 @@ class BedrockLLM(BaseLLM):
             raise BedrockError(message=response.text, status_code=422)

         try:
-            model_response.choices[0].message.content = completion_response["text"]  # type: ignore
+            if provider == "cohere":
+                model_response.choices[0].message.content = completion_response["text"]  # type: ignore
+            elif provider == "anthropic":
+                if model.startswith("anthropic.claude-3"):
+                    json_schemas: dict = {}
+                    _is_function_call = False
+                    ## Handle Tool Calling
+                    if "tools" in optional_params:
+                        _is_function_call = True
+                        for tool in optional_params["tools"]:
+                            json_schemas[tool["function"]["name"]] = tool[
+                                "function"
+                            ].get("parameters", None)
+                    outputText = completion_response.get("content")[0].get("text", None)
+                    if outputText is not None and contains_tag(
+                        "invoke", outputText
+                    ):  # OUTPUT PARSE FUNCTION CALL
+                        function_name = extract_between_tags("tool_name", outputText)[0]
+                        function_arguments_str = extract_between_tags(
+                            "invoke", outputText
+                        )[0].strip()
+                        function_arguments_str = (
+                            f"<invoke>{function_arguments_str}</invoke>"
+                        )
+                        function_arguments = parse_xml_params(
+                            function_arguments_str,
+                            json_schema=json_schemas.get(
+                                function_name, None
+                            ),  # check if we have a json schema for this function name
+                        )
+                        _message = litellm.Message(
+                            tool_calls=[
+                                {
+                                    "id": f"call_{uuid.uuid4()}",
+                                    "type": "function",
+                                    "function": {
+                                        "name": function_name,
+                                        "arguments": json.dumps(function_arguments),
+                                    },
+                                }
+                            ],
+                            content=None,
+                        )
+                        model_response.choices[0].message = _message  # type: ignore
+                        model_response._hidden_params["original_response"] = (
+                            outputText  # allow user to access raw anthropic tool calling response
+                        )
+                    if (
+                        _is_function_call == True
+                        and stream is not None
+                        and stream == True
+                    ):
+                        print_verbose(
+                            f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
+                        )
+                        # return an iterator
+                        streaming_model_response = ModelResponse(stream=True)
+                        streaming_model_response.choices[0].finish_reason = getattr(
+                            model_response.choices[0], "finish_reason", "stop"
+                        )
+                        # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+                        streaming_choice = litellm.utils.StreamingChoices()
+                        streaming_choice.index = model_response.choices[0].index
+                        _tool_calls = []
+                        print_verbose(
+                            f"type of model_response.choices[0]: {type(model_response.choices[0])}"
+                        )
+                        print_verbose(
+                            f"type of streaming_choice: {type(streaming_choice)}"
+                        )
+                        if isinstance(model_response.choices[0], litellm.Choices):
+                            if getattr(
+                                model_response.choices[0].message, "tool_calls", None
+                            ) is not None and isinstance(
+                                model_response.choices[0].message.tool_calls, list
+                            ):
+                                for tool_call in model_response.choices[
+                                    0
+                                ].message.tool_calls:
+                                    _tool_call = {**tool_call.dict(), "index": 0}
+                                    _tool_calls.append(_tool_call)
+                            delta_obj = litellm.utils.Delta(
+                                content=getattr(
+                                    model_response.choices[0].message, "content", None
+                                ),
+                                role=model_response.choices[0].message.role,
+                                tool_calls=_tool_calls,
+                            )
+                            streaming_choice.delta = delta_obj
+                        streaming_model_response.choices = [streaming_choice]
+                        completion_stream = ModelResponseIterator(
+                            model_response=streaming_model_response
+                        )
+                        print_verbose(
+                            f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
+                        )
+                        return litellm.CustomStreamWrapper(
+                            completion_stream=completion_stream,
+                            model=model,
+                            custom_llm_provider="cached_response",
+                            logging_obj=logging_obj,
+                        )
+
+                    model_response["finish_reason"] = map_finish_reason(
+                        completion_response.get("stop_reason", "")
+                    )
+                    _usage = litellm.Usage(
+                        prompt_tokens=completion_response["usage"]["input_tokens"],
+                        completion_tokens=completion_response["usage"]["output_tokens"],
+                        total_tokens=completion_response["usage"]["input_tokens"]
+                        + completion_response["usage"]["output_tokens"],
+                    )
+                    setattr(model_response, "usage", _usage)
+                else:
+                    outputText = completion_response["completion"]
+                    model_response["finish_reason"] = completion_response["stop_reason"]
         except Exception as e:
-            raise BedrockError(message=response.text, status_code=422)
+            raise BedrockError(
+                message="Error processing={}, Received error={}".format(
+                    response.text, str(e)
+                ),
+                status_code=422,
+            )

         ## CALCULATING USAGE - bedrock returns usage in the headers
-        prompt_tokens = int(
-            response.headers.get(
-                "x-amzn-bedrock-input-token-count",
-                len(encoding.encode("".join(m.get("content", "") for m in messages))),
-            )
-        )
-        completion_tokens = int(
-            response.headers.get(
-                "x-amzn-bedrock-output-token-count",
-                len(
-                    encoding.encode(
-                        model_response.choices[0].message.content,  # type: ignore
-                        disallowed_special=(),
-                    )
-                ),
-            )
-        )
+        bedrock_input_tokens = response.headers.get(
+            "x-amzn-bedrock-input-token-count", None
+        )
+        bedrock_output_tokens = response.headers.get(
+            "x-amzn-bedrock-output-token-count", None
+        )
+
+        prompt_tokens = int(
+            bedrock_input_tokens or litellm.token_counter(messages=messages)
+        )
+
+        completion_tokens = int(
+            bedrock_output_tokens
+            or litellm.token_counter(
+                text=model_response.choices[0].message.content,  # type: ignore
+                count_response_tokens=True,
+            )
+        )
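The new anthropic branch above parses Claude's XML-style invoke output back into OpenAI-style tool calls. A hedged sketch of what calling this path looks like from user code (the tool schema and model name are illustrative, not taken from this diff):

    import litellm

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ]

    response = litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": "What's the weather in Boston?"}],
        tools=tools,
    )
    # when Claude emits an <invoke> block, it is surfaced as OpenAI-style tool_calls
    print(response.choices[0].message.tool_calls)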
@@ -426,7 +556,7 @@ class BedrockLLM(BaseLLM):
             model, messages, provider, custom_prompt_dict
         )
         inference_params = copy.deepcopy(optional_params)
+        json_schemas: dict = {}
         if provider == "cohere":
             if model.startswith("cohere.command-r"):
                 ## LOAD CONFIG
@@ -453,6 +583,56 @@ class BedrockLLM(BaseLLM):
                         True  # cohere requires stream = True in inference params
                     )
                 data = json.dumps({"prompt": prompt, **inference_params})
+        elif provider == "anthropic":
+            if model.startswith("anthropic.claude-3"):
+                # Separate system prompt from rest of message
+                system_prompt_idx: list[int] = []
+                system_messages: list[str] = []
+                for idx, message in enumerate(messages):
+                    if message["role"] == "system":
+                        system_messages.append(message["content"])
+                        system_prompt_idx.append(idx)
+                if len(system_prompt_idx) > 0:
+                    inference_params["system"] = "\n".join(system_messages)
+                    messages = [
+                        i for j, i in enumerate(messages) if j not in system_prompt_idx
+                    ]
+                # Format rest of message according to anthropic guidelines
+                messages = prompt_factory(
+                    model=model, messages=messages, custom_llm_provider="anthropic_xml"
+                )  # type: ignore
+                ## LOAD CONFIG
+                config = litellm.AmazonAnthropicClaude3Config.get_config()
+                for k, v in config.items():
+                    if (
+                        k not in inference_params
+                    ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+                        inference_params[k] = v
+                ## Handle Tool Calling
+                if "tools" in inference_params:
+                    _is_function_call = True
+                    for tool in inference_params["tools"]:
+                        json_schemas[tool["function"]["name"]] = tool["function"].get(
+                            "parameters", None
+                        )
+                    tool_calling_system_prompt = construct_tool_use_system_prompt(
+                        tools=inference_params["tools"]
+                    )
+                    inference_params["system"] = (
+                        inference_params.get("system", "\n")
+                        + tool_calling_system_prompt
+                    )  # add the anthropic tool calling prompt to the system prompt
+                    inference_params.pop("tools")
+                data = json.dumps({"messages": messages, **inference_params})
+            else:
+                ## LOAD CONFIG
+                config = litellm.AmazonAnthropicConfig.get_config()
+                for k, v in config.items():
+                    if (
+                        k not in inference_params
+                    ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+                        inference_params[k] = v
+                data = json.dumps({"prompt": prompt, **inference_params})
         else:
             raise Exception("UNSUPPORTED PROVIDER")

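For reference, a rough sketch of the request body this branch produces for a claude-3 model (all field values are illustrative): system messages are pulled out of the message list into a top-level "system" field, any "tools" are folded into that system prompt via construct_tool_use_system_prompt, and the remaining messages are formatted with the "anthropic_xml" prompt template:

    # hypothetical shape of `data` for an anthropic.claude-3 request
    {
        "system": "You are a helpful assistant.\n<tool-use system prompt, if tools were passed>",
        "messages": [
            {"role": "user", "content": [{"type": "text", "text": "Hey! how's it going?"}]}
        ],
        "max_tokens": 256,
    }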
@@ -541,7 +721,7 @@ class BedrockLLM(BaseLLM):
                     status_code=response.status_code, message=response.text
                 )

-            decoder = AWSEventStreamDecoder()
+            decoder = AWSEventStreamDecoder(model=model)

             completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
             streaming_response = CustomStreamWrapper(
@@ -591,7 +771,7 @@ class BedrockLLM(BaseLLM):
         logger_fn=None,
         headers={},
         client: Optional[AsyncHTTPHandler] = None,
-    ) -> ModelResponse:
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
         if client is None:
             _params = {}
             if timeout is not None:
@@ -650,7 +830,7 @@ class BedrockLLM(BaseLLM):
         if response.status_code != 200:
             raise BedrockError(status_code=response.status_code, message=response.text)

-        decoder = AWSEventStreamDecoder()
+        decoder = AWSEventStreamDecoder(model=model)

         completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
         streaming_response = CustomStreamWrapper(
@@ -676,11 +856,70 @@ def get_response_stream_shape():


 class AWSEventStreamDecoder:
-    def __init__(self) -> None:
+    def __init__(self, model: str) -> None:
         from botocore.parsers import EventStreamJSONParser

+        self.model = model
         self.parser = EventStreamJSONParser()

+    def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+        text = ""
+        is_finished = False
+        finish_reason = ""
+        if "outputText" in chunk_data:
+            text = chunk_data["outputText"]
+        # ai21 mapping
+        if "ai21" in self.model:  # fake ai21 streaming
+            text = chunk_data.get("completions")[0].get("data").get("text")  # type: ignore
+            is_finished = True
+            finish_reason = "stop"
+        ######## bedrock.anthropic mappings ###############
+        elif "completion" in chunk_data:  # not claude-3
+            text = chunk_data["completion"]  # bedrock.anthropic
+            stop_reason = chunk_data.get("stop_reason", None)
+            if stop_reason != None:
+                is_finished = True
+                finish_reason = stop_reason
+        elif "delta" in chunk_data:
+            if chunk_data["delta"].get("text", None) is not None:
+                text = chunk_data["delta"]["text"]
+            stop_reason = chunk_data["delta"].get("stop_reason", None)
+            if stop_reason != None:
+                is_finished = True
+                finish_reason = stop_reason
+        ######## bedrock.mistral mappings ###############
+        elif "outputs" in chunk_data:
+            if (
+                len(chunk_data["outputs"]) == 1
+                and chunk_data["outputs"][0].get("text", None) is not None
+            ):
+                text = chunk_data["outputs"][0]["text"]
+            stop_reason = chunk_data.get("stop_reason", None)
+            if stop_reason != None:
+                is_finished = True
+                finish_reason = stop_reason
+        ######## bedrock.cohere mappings ###############
+        # meta mapping
+        elif "generation" in chunk_data:
+            text = chunk_data["generation"]  # bedrock.meta
+        # cohere mapping
+        elif "text" in chunk_data:
+            text = chunk_data["text"]  # bedrock.cohere
+        # cohere mapping for finish reason
+        elif "finish_reason" in chunk_data:
+            finish_reason = chunk_data["finish_reason"]
+            is_finished = True
+        elif chunk_data.get("completionReason", None):
+            is_finished = True
+            finish_reason = chunk_data["completionReason"]
+        return GenericStreamingChunk(
+            **{
+                "text": text,
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+            }
+        )
+
     def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
         """Given an iterator that yields lines, iterate over it & yield every event encountered"""
         from botocore.eventstream import EventStreamBuffer
|
||||||
if message:
|
if message:
|
||||||
# sse_event = ServerSentEvent(data=message, event="completion")
|
# sse_event = ServerSentEvent(data=message, event="completion")
|
||||||
_data = json.loads(message)
|
_data = json.loads(message)
|
||||||
streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
|
yield self._chunk_parser(chunk_data=_data)
|
||||||
text=_data.get("text", ""),
|
|
||||||
is_finished=_data.get("is_finished", False),
|
|
||||||
finish_reason=_data.get("finish_reason", ""),
|
|
||||||
)
|
|
||||||
yield streaming_chunk
|
|
||||||
|
|
||||||
async def aiter_bytes(
|
async def aiter_bytes(
|
||||||
self, iterator: AsyncIterator[bytes]
|
self, iterator: AsyncIterator[bytes]
|
||||||
|
@@ -713,12 +947,7 @@ class AWSEventStreamDecoder:
             message = self._parse_message_from_event(event)
             if message:
                 _data = json.loads(message)
-                streaming_chunk: GenericStreamingChunk = GenericStreamingChunk(
-                    text=_data.get("text", ""),
-                    is_finished=_data.get("is_finished", False),
-                    finish_reason=_data.get("finish_reason", ""),
-                )
-                yield streaming_chunk
+                yield self._chunk_parser(chunk_data=_data)

     def _parse_message_from_event(self, event) -> Optional[str]:
         response_dict = event.to_response_dict()
@@ -326,7 +326,10 @@ async def acompletion(
             or custom_llm_provider == "sagemaker"
             or custom_llm_provider == "anthropic"
             or custom_llm_provider == "predibase"
-            or (custom_llm_provider == "bedrock" and "cohere" in model)
+            or (
+                custom_llm_provider == "bedrock"
+                and ("cohere" in model or "anthropic" in model)
+            )
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
@@ -1979,7 +1982,7 @@ def completion(
             # boto3 reads keys from .env
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict

-            if "cohere" in model:
+            if "cohere" in model or "anthropic" in model:
                 response = bedrock_chat_completion.completion(
                     model=model,
                     messages=messages,
@@ -4874,11 +4874,12 @@ async def token_counter(request: TokenCountRequest):
     model_to_use = (
         litellm_model_name or request.model
     )  # use litellm model name, if it's not avalable then fallback to request.model
-    total_tokens, tokenizer_used = token_counter(
+    _tokenizer_used = litellm.utils._select_tokenizer(model=model_to_use)
+    tokenizer_used = str(_tokenizer_used["type"])
+    total_tokens = token_counter(
         model=model_to_use,
         text=prompt,
         messages=messages,
-        return_tokenizer_used=True,
     )
     return TokenCountResponse(
         total_tokens=total_tokens,
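The proxy endpoint now selects the tokenizer separately instead of asking token_counter to return it. A minimal sketch of the new split, mirroring the lines above (the model name is illustrative, and the exact tokenizer type string depends on the model):

    import litellm
    from litellm import token_counter

    model_to_use = "gpt-3.5-turbo"
    _tokenizer_used = litellm.utils._select_tokenizer(model=model_to_use)
    tokenizer_used = str(_tokenizer_used["type"])
    total_tokens = token_counter(model=model_to_use, text="Hello, world!")
    print(tokenizer_used, total_tokens)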
@@ -2663,13 +2663,17 @@ def response_format_tests(response: litellm.ModelResponse):


 @pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize(
+    "model",
+    ["bedrock/cohere.command-r-plus-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"],
+)
 @pytest.mark.asyncio
-async def test_completion_bedrock_command_r(sync_mode):
+async def test_completion_bedrock_httpx_models(sync_mode, model):
     litellm.set_verbose = True

     if sync_mode:
         response = completion(
-            model="bedrock/cohere.command-r-plus-v1:0",
+            model=model,
             messages=[{"role": "user", "content": "Hey! how's it going?"}],
         )

@@ -2678,7 +2682,7 @@ async def test_completion_bedrock_command_r(sync_mode):
         response_format_tests(response=response)
     else:
         response = await litellm.acompletion(
-            model="bedrock/cohere.command-r-plus-v1:0",
+            model=model,
             messages=[{"role": "user", "content": "Hey! how's it going?"}],
         )

@@ -1041,14 +1041,21 @@ async def test_completion_replicate_llama3_streaming(sync_mode):


 @pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize(
+    "model",
+    [
+        # "bedrock/cohere.command-r-plus-v1:0",
+        "anthropic.claude-3-sonnet-20240229-v1:0"
+    ],
+)
 @pytest.mark.asyncio
-async def test_bedrock_cohere_command_r_streaming(sync_mode):
+async def test_bedrock_httpx_streaming(sync_mode, model):
     try:
         litellm.set_verbose = True
         if sync_mode:
             final_chunk: Optional[litellm.ModelResponse] = None
             response: litellm.CustomStreamWrapper = completion(  # type: ignore
-                model="bedrock/cohere.command-r-plus-v1:0",
+                model=model,
                 messages=messages,
                 max_tokens=10,  # type: ignore
                 stream=True,
@@ -1069,7 +1076,7 @@ async def test_bedrock_cohere_command_r_streaming(sync_mode):
                 raise Exception("Empty response received")
         else:
             response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
-                model="bedrock/cohere.command-r-plus-v1:0",
+                model=model,
                 messages=messages,
                 max_tokens=100,  # type: ignore
                 stream=True,
@@ -4123,8 +4123,7 @@ def token_counter(
     text: Optional[Union[str, List[str]]] = None,
     messages: Optional[List] = None,
     count_response_tokens: Optional[bool] = False,
-    return_tokenizer_used: Optional[bool] = False,
-):
+) -> int:
     """
     Count the number of tokens in a given text using a specified model.

@@ -4216,10 +4215,6 @@ def token_counter(
         )
     else:
         num_tokens = len(encoding.encode(text, disallowed_special=()))  # type: ignore
-    _tokenizer_type = tokenizer_json["type"]
-    if return_tokenizer_used:
-        # used by litellm proxy server -> POST /utils/token_counter
-        return num_tokens, _tokenizer_type
     return num_tokens

@@ -10642,7 +10637,7 @@ class CustomStreamWrapper:
             raise e

     def handle_bedrock_stream(self, chunk):
-        if "cohere" in self.model:
+        if "cohere" in self.model or "anthropic" in self.model:
             return {
                 "text": chunk["text"],
                 "is_finished": chunk["is_finished"],
@@ -11513,7 +11508,10 @@ class CustomStreamWrapper:
             or self.custom_llm_provider == "replicate"
             or self.custom_llm_provider == "cached_response"
             or self.custom_llm_provider == "predibase"
-            or (self.custom_llm_provider == "bedrock" and "cohere" in self.model)
+            or (
+                self.custom_llm_provider == "bedrock"
+                and ("cohere" in self.model or "anthropic" in self.model)
+            )
             or self.custom_llm_provider in litellm.openai_compatible_endpoints
         ):
             async for chunk in self.completion_stream:
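With bedrock + anthropic now included in the async streaming branch above, streamed chunks can be consumed the same way as for the cohere path. A short hedged sketch (model name illustrative):

    import asyncio
    import litellm

    async def main():
        response = await litellm.acompletion(
            model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
            messages=[{"role": "user", "content": "Hey! how's it going?"}],
            stream=True,
            max_tokens=100,
        )
        async for chunk in response:
            print(chunk.choices[0].delta.content or "", end="")

    asyncio.run(main())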