refactor(main.py): only route anthropic calls through converse api

v0 scope: let's move function calling to the Converse API
Krrish Dholakia 2024-06-07 08:47:51 -07:00
parent 51ba5652a0
commit 35e4323095
6 changed files with 263 additions and 4332 deletions
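
At a glance, the behavioral change is that Bedrock model IDs beginning with "anthropic" are now served by the new BedrockConverseLLM (Converse API) handler, while every other Bedrock model stays on the legacy BedrockLLM invoke path. A minimal usage sketch of the call this routing affects (model IDs are examples taken from the test matrix below; assumes AWS credentials are configured in the environment):

import litellm

# Routed through BedrockConverseLLM (Bedrock Converse API) after this commit.
converse_resp = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",  # example Anthropic model ID
    messages=[{"role": "user", "content": "Hello"}],
)

# Still routed through the legacy BedrockLLM invoke path.
invoke_resp = litellm.completion(
    model="bedrock/cohere.command-text-v14",  # example non-Anthropic model ID
    messages=[{"role": "user", "content": "Hello"}],
)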

View file

@@ -1125,7 +1125,7 @@ class AmazonConverseConfig:
         maxTokens: Optional[int] = None,
         stopSequences: Optional[List[str]] = None,
         temperature: Optional[int] = None,
-        top_p: Optional[int] = None,
+        topP: Optional[int] = None,
     ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
@@ -1481,6 +1481,93 @@ class BedrockConverseLLM(BaseLLM):
         return session.get_credentials()
 
+    async def async_streaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        data: str,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        encoding,
+        logging_obj,
+        stream,
+        optional_params: dict,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> CustomStreamWrapper:
+        streaming_response = CustomStreamWrapper(
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                client=client,
+                api_base=api_base,
+                headers=headers,
+                data=data,
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
+            model=model,
+            custom_llm_provider="bedrock",
+            logging_obj=logging_obj,
+        )
+        return streaming_response
+
+    async def async_completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        data: str,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        encoding,
+        logging_obj,
+        stream,
+        optional_params: dict,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            client = client  # type: ignore
+
+        try:
+            response = await client.post(api_base, headers=headers, data=data)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise BedrockError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException as e:
+            raise BedrockError(status_code=408, message="Timeout error occurred.")
+
+        return self.process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=stream if isinstance(stream, bool) else False,
+            logging_obj=logging_obj,
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
+        )
+
     def completion(
         self,
         model: str,
@@ -1504,7 +1591,7 @@ class BedrockConverseLLM(BaseLLM):
             from botocore.auth import SigV4Auth
             from botocore.awsrequest import AWSRequest
             from botocore.credentials import Credentials
-        except ImportError as e:
+        except ImportError:
             raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
 
         ## SETUP ##
@@ -1658,6 +1745,46 @@ class BedrockConverseLLM(BaseLLM):
         )
 
         ### ROUTING (ASYNC, STREAMING, SYNC)
+        if acompletion:
+            if isinstance(client, HTTPHandler):
+                client = None
+            if stream is True and provider != "ai21":
+                return self.async_streaming(
+                    model=model,
+                    messages=messages,
+                    data=data,
+                    api_base=prepped.url,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    stream=True,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=prepped.headers,
+                    timeout=timeout,
+                    client=client,
+                )  # type: ignore
+            ### ASYNC COMPLETION
+            return self.async_completion(
+                model=model,
+                messages=messages,
+                data=data,
+                api_base=prepped.url,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                encoding=encoding,
+                logging_obj=logging_obj,
+                optional_params=optional_params,
+                stream=stream,  # type: ignore
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                headers=prepped.headers,
+                timeout=timeout,
+                client=client,
+            )  # type: ignore
+
         if (stream is not None and stream is True) and provider != "ai21":
             streaming_response = CustomStreamWrapper(
@@ -1666,7 +1793,7 @@ class BedrockConverseLLM(BaseLLM):
                     make_sync_call,
                     client=None,
                     api_base=prepped.url,
-                    headers=prepped.headers,
+                    headers=prepped.headers,  # type: ignore
                     data=data,
                     model=model,
                     messages=messages,
@@ -1702,7 +1829,7 @@ class BedrockConverseLLM(BaseLLM):
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code
             raise BedrockError(status_code=error_code, message=response.text)
-        except httpx.TimeoutException as e:
+        except httpx.TimeoutException:
             raise BedrockError(status_code=408, message="Timeout error occurred.")
 
         return self.process_response(
@@ -1737,7 +1864,7 @@ class AWSEventStreamDecoder:
         self.model = model
         self.parser = EventStreamJSONParser()
 
-    def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+    def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
         text = ""
         tool_str = ""
         is_finished = False
@@ -1762,7 +1889,7 @@ class AWSEventStreamDecoder:
         )
         return response
 
-    def _old_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+    def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
         text = ""
         is_finished = False
         finish_reason = ""
@@ -1774,19 +1901,8 @@ class AWSEventStreamDecoder:
                 is_finished = True
                 finish_reason = "stop"
         ######## bedrock.anthropic mappings ###############
-        elif "completion" in chunk_data:  # not claude-3
-            text = chunk_data["completion"]  # bedrock.anthropic
-            stop_reason = chunk_data.get("stop_reason", None)
-            if stop_reason != None:
-                is_finished = True
-                finish_reason = stop_reason
         elif "delta" in chunk_data:
-            if chunk_data["delta"].get("text", None) is not None:
-                text = chunk_data["delta"]["text"]
-            stop_reason = chunk_data["delta"].get("stop_reason", None)
-            if stop_reason != None:
-                is_finished = True
-                finish_reason = stop_reason
+            return self.converse_chunk_parser(chunk_data=chunk_data)
         ######## bedrock.mistral mappings ###############
         elif "outputs" in chunk_data:
             if (
@@ -1851,9 +1967,15 @@ class AWSEventStreamDecoder:
     def _parse_message_from_event(self, event) -> Optional[str]:
         response_dict = event.to_response_dict()
+        parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
         if response_dict["status_code"] != 200:
             raise ValueError(f"Bad response code, expected 200: {response_dict}")
+        if "chunk" in parsed_response:
+            chunk = parsed_response.get("chunk")
+            if not chunk:
+                return None
+            return chunk.get("bytes").decode()  # type: ignore[no-any-return]
+        else:
             chunk = response_dict.get("body")
             if not chunk:
                 return None
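
The async_streaming()/async_completion() methods and the acompletion routing added above mean that async streaming against a Bedrock Anthropic model now flows through the Converse handler, with "delta" events decoded by converse_chunk_parser(). A hedged consumption sketch (the model ID is an example; this is not code from the commit):

import asyncio

import litellm

async def main() -> None:
    # stream=True with acompletion() now returns the CustomStreamWrapper built
    # by async_streaming() for Bedrock Anthropic models.
    stream = await litellm.acompletion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",  # example model ID
        messages=[{"role": "user", "content": "Write one short sentence."}],
        stream=True,
    )
    async for chunk in stream:
        # Chunks arrive in OpenAI delta format; content may be None on
        # control chunks (e.g. the final finish_reason chunk).
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())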

View file

@@ -168,7 +168,6 @@ class HTTPHandler:
         return response
 
     def __del__(self) -> None:
-        traceback.print_stack()
         try:
             self.close()
         except Exception:

View file

@@ -121,7 +121,8 @@ azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
 predibase_chat_completions = PredibaseChatCompletion()
 triton_chat_completions = TritonChatCompletion()
-bedrock_chat_completion = BedrockConverseLLM()
+bedrock_chat_completion = BedrockLLM()
+bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
 
 ####### COMPLETION ENDPOINTS ################
@@ -2096,6 +2097,24 @@ def completion(
                     custom_llm_provider="bedrock",
                     logging_obj=logging,
                 )
+            else:
+                if model.startswith("anthropic"):
+                    response = bedrock_converse_chat_completion.completion(
+                        model=model,
+                        messages=messages,
+                        custom_prompt_dict=custom_prompt_dict,
+                        model_response=model_response,
+                        print_verbose=print_verbose,
+                        optional_params=optional_params,
+                        litellm_params=litellm_params,
+                        logger_fn=logger_fn,
+                        encoding=encoding,
+                        logging_obj=logging,
+                        extra_headers=extra_headers,
+                        timeout=timeout,
+                        acompletion=acompletion,
+                        client=client,
+                    )
                 else:
                     response = bedrock_chat_completion.completion(
                         model=model,
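
Since the commit message scopes this work to function calling, the kind of request that now reaches bedrock_converse_chat_completion for Anthropic models looks roughly like the sketch below; the tool schema and model ID are illustrative placeholders, not part of this diff:

import litellm

# Illustrative OpenAI-style tool definition (hypothetical tool, not from the repo).
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

# For an Anthropic model ID, this request is now translated to the Bedrock
# Converse API by bedrock_converse_chat_completion instead of the invoke path.
response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",  # example model ID
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=tools,
    tool_choice="auto",
)
print(response.choices[0].message.tool_calls)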

File diff suppressed because it is too large

View file

@@ -1288,14 +1288,14 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
 @pytest.mark.parametrize(
     "model",
     [
-        # "bedrock/cohere.command-r-plus-v1:0",
-        # "anthropic.claude-3-sonnet-20240229-v1:0",
-        # "anthropic.claude-instant-v1",
-        # "bedrock/ai21.j2-mid",
-        # "mistral.mistral-7b-instruct-v0:2",
-        # "bedrock/amazon.titan-tg1-large",
-        # "meta.llama3-8b-instruct-v1:0",
-        "cohere.command-text-v14"
+        "bedrock/cohere.command-r-plus-v1:0",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "anthropic.claude-instant-v1",
+        "bedrock/ai21.j2-mid",
+        "mistral.mistral-7b-instruct-v0:2",
+        "bedrock/amazon.titan-tg1-large",
+        "meta.llama3-8b-instruct-v1:0",
+        "cohere.command-text-v14",
     ],
 )
 @pytest.mark.asyncio
@@ -1324,8 +1324,6 @@ async def test_bedrock_httpx_streaming(sync_mode, model):
                 raise Exception("finish reason not set")
             if complete_response.strip() == "":
                 raise Exception("Empty response received")
-            assert False
         else:
             response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
                 model=model,

View file

@@ -5620,13 +5620,80 @@ def get_optional_params(
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
+        if "ai21" in model:
+            _check_valid_arg(supported_params=supported_params)
+            # params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[],
+            # https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
+            if max_tokens is not None:
+                optional_params["maxTokens"] = max_tokens
+            if temperature is not None:
+                optional_params["temperature"] = temperature
+            if top_p is not None:
+                optional_params["topP"] = top_p
+            if stream:
+                optional_params["stream"] = stream
+        elif "anthropic" in model:
             _check_valid_arg(supported_params=supported_params)
             optional_params = litellm.AmazonConverseConfig().map_openai_params(
                 model=model,
                 non_default_params=non_default_params,
                 optional_params=optional_params,
-                drop_params=drop_params,
+                drop_params=(
+                    drop_params
+                    if drop_params is not None and isinstance(drop_params, bool)
+                    else False
+                ),
             )
+        elif "amazon" in model:  # amazon titan llms
+            _check_valid_arg(supported_params=supported_params)
+            # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
+            if max_tokens is not None:
+                optional_params["maxTokenCount"] = max_tokens
+            if temperature is not None:
+                optional_params["temperature"] = temperature
+            if stop is not None:
+                filtered_stop = _map_and_modify_arg(
+                    {"stop": stop}, provider="bedrock", model=model
+                )
+                optional_params["stopSequences"] = filtered_stop["stop"]
+            if top_p is not None:
+                optional_params["topP"] = top_p
+            if stream:
+                optional_params["stream"] = stream
+        elif "meta" in model:  # amazon / meta llms
+            _check_valid_arg(supported_params=supported_params)
+            # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
+            if max_tokens is not None:
+                optional_params["max_gen_len"] = max_tokens
+            if temperature is not None:
+                optional_params["temperature"] = temperature
+            if top_p is not None:
+                optional_params["top_p"] = top_p
+            if stream:
+                optional_params["stream"] = stream
+        elif "cohere" in model:  # cohere models on bedrock
+            _check_valid_arg(supported_params=supported_params)
+            # handle cohere params
+            if stream:
+                optional_params["stream"] = stream
+            if temperature is not None:
+                optional_params["temperature"] = temperature
+            if max_tokens is not None:
+                optional_params["max_tokens"] = max_tokens
+        elif "mistral" in model:
+            _check_valid_arg(supported_params=supported_params)
+            # mistral params on bedrock
+            # \"max_tokens\":400,\"temperature\":0.7,\"top_p\":0.7,\"stop\":[\"\\\\n\\\\nHuman:\"]}"
+            if max_tokens is not None:
+                optional_params["max_tokens"] = max_tokens
+            if temperature is not None:
+                optional_params["temperature"] = temperature
+            if top_p is not None:
+                optional_params["top_p"] = top_p
+            if stop is not None:
+                optional_params["stop"] = stop
+            if stream is not None:
+                optional_params["stream"] = stream
     elif custom_llm_provider == "aleph_alpha":
         supported_params = [
             "max_tokens",