Mirror of https://github.com/BerriAI/litellm.git
refactor(main.py): only route anthropic calls through converse api
v0 scope: let's move function calling to the Converse API.
This commit is contained in:
parent 51ba5652a0
commit 35e4323095

6 changed files with 263 additions and 4332 deletions
@@ -1125,7 +1125,7 @@ class AmazonConverseConfig:
         maxTokens: Optional[int] = None,
         stopSequences: Optional[List[str]] = None,
         temperature: Optional[int] = None,
-        top_p: Optional[int] = None,
+        topP: Optional[int] = None,
     ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
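For context (not part of this diff): the rename to topP matches the camelCase parameter names the Bedrock Converse API uses in its inferenceConfig block, along the lines of this hand-written fragment:

# Illustrative only; field names per the Bedrock Converse API inferenceConfig,
# values invented for the example.
inference_config = {
    "maxTokens": 256,
    "temperature": 0.0,
    "topP": 0.9,
    "stopSequences": ["\n\nHuman:"],
}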
@@ -1481,6 +1481,93 @@ class BedrockConverseLLM(BaseLLM):
         return session.get_credentials()

+    async def async_streaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        data: str,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        encoding,
+        logging_obj,
+        stream,
+        optional_params: dict,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> CustomStreamWrapper:
+        streaming_response = CustomStreamWrapper(
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                client=client,
+                api_base=api_base,
+                headers=headers,
+                data=data,
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
+            model=model,
+            custom_llm_provider="bedrock",
+            logging_obj=logging_obj,
+        )
+        return streaming_response
+
+    async def async_completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        data: str,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        encoding,
+        logging_obj,
+        stream,
+        optional_params: dict,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            client = client  # type: ignore
+
+        try:
+            response = await client.post(api_base, headers=headers, data=data)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise BedrockError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException as e:
+            raise BedrockError(status_code=408, message="Timeout error occurred.")
+
+        return self.process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=stream if isinstance(stream, bool) else False,
+            logging_obj=logging_obj,
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
+        )
+
     def completion(
         self,
         model: str,
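A rough usage sketch for the async paths added above; it assumes the standard litellm entry point and a Claude model id taken from the test matrix later in this commit (adjust credentials/region as needed):

# Hedged sketch, not part of the diff: drive the new async streaming path end to end.
import asyncio

import litellm


async def main():
    stream = await litellm.acompletion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        stream=True,
    )
    # chunks are yielded by the CustomStreamWrapper built in async_streaming()
    async for chunk in stream:
        print(chunk)


asyncio.run(main())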
@@ -1504,7 +1591,7 @@ class BedrockConverseLLM(BaseLLM):
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
            from botocore.credentials import Credentials
-        except ImportError as e:
+        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        ## SETUP ##
@ -1658,6 +1745,46 @@ class BedrockConverseLLM(BaseLLM):
|
||||||
)
|
)
|
||||||
|
|
||||||
### ROUTING (ASYNC, STREAMING, SYNC)
|
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||||
|
if acompletion:
|
||||||
|
if isinstance(client, HTTPHandler):
|
||||||
|
client = None
|
||||||
|
if stream is True and provider != "ai21":
|
||||||
|
return self.async_streaming(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
data=data,
|
||||||
|
api_base=prepped.url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
stream=True,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=prepped.headers,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
) # type: ignore
|
||||||
|
### ASYNC COMPLETION
|
||||||
|
return self.async_completion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
data=data,
|
||||||
|
api_base=prepped.url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
stream=stream, # type: ignore
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=prepped.headers,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
if (stream is not None and stream is True) and provider != "ai21":
|
if (stream is not None and stream is True) and provider != "ai21":
|
||||||
|
|
||||||
streaming_response = CustomStreamWrapper(
|
streaming_response = CustomStreamWrapper(
|
||||||
|
@@ -1666,7 +1793,7 @@ class BedrockConverseLLM(BaseLLM):
                    make_sync_call,
                    client=None,
                    api_base=prepped.url,
-                    headers=prepped.headers,
+                    headers=prepped.headers,  # type: ignore
                    data=data,
                    model=model,
                    messages=messages,
@@ -1702,7 +1829,7 @@ class BedrockConverseLLM(BaseLLM):
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=response.text)
-        except httpx.TimeoutException as e:
+        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        return self.process_response(
@@ -1737,7 +1864,7 @@ class AWSEventStreamDecoder:
        self.model = model
        self.parser = EventStreamJSONParser()

-    def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+    def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
        text = ""
        tool_str = ""
        is_finished = False
@@ -1762,7 +1889,7 @@ class AWSEventStreamDecoder:
            )
        return response

-    def _old_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+    def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
        text = ""
        is_finished = False
        finish_reason = ""
@@ -1774,19 +1901,8 @@ class AWSEventStreamDecoder:
                is_finished = True
                finish_reason = "stop"
        ######## bedrock.anthropic mappings ###############
-        elif "completion" in chunk_data:  # not claude-3
-            text = chunk_data["completion"]  # bedrock.anthropic
-            stop_reason = chunk_data.get("stop_reason", None)
-            if stop_reason != None:
-                is_finished = True
-                finish_reason = stop_reason
        elif "delta" in chunk_data:
-            if chunk_data["delta"].get("text", None) is not None:
-                text = chunk_data["delta"]["text"]
-            stop_reason = chunk_data["delta"].get("stop_reason", None)
-            if stop_reason != None:
-                is_finished = True
-                finish_reason = stop_reason
+            return self.converse_chunk_parser(chunk_data=chunk_data)
        ######## bedrock.mistral mappings ###############
        elif "outputs" in chunk_data:
            if (
@@ -1851,9 +1967,15 @@ class AWSEventStreamDecoder:

    def _parse_message_from_event(self, event) -> Optional[str]:
        response_dict = event.to_response_dict()
+        parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
        if response_dict["status_code"] != 200:
            raise ValueError(f"Bad response code, expected 200: {response_dict}")
+        if "chunk" in parsed_response:
+            chunk = parsed_response.get("chunk")
+            if not chunk:
+                return None
+            return chunk.get("bytes").decode()  # type: ignore[no-any-return]
+        else:
            chunk = response_dict.get("body")
            if not chunk:
                return None
@@ -168,7 +168,6 @@ class HTTPHandler:
        return response

    def __del__(self) -> None:
-        traceback.print_stack()
        try:
            self.close()
        except Exception:
@@ -121,7 +121,8 @@ azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
 predibase_chat_completions = PredibaseChatCompletion()
 triton_chat_completions = TritonChatCompletion()
-bedrock_chat_completion = BedrockConverseLLM()
+bedrock_chat_completion = BedrockLLM()
+bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
 ####### COMPLETION ENDPOINTS ################

@@ -2096,6 +2097,24 @@ def completion(
                    custom_llm_provider="bedrock",
                    logging_obj=logging,
                )
+            else:
+                if model.startswith("anthropic"):
+                    response = bedrock_converse_chat_completion.completion(
+                        model=model,
+                        messages=messages,
+                        custom_prompt_dict=custom_prompt_dict,
+                        model_response=model_response,
+                        print_verbose=print_verbose,
+                        optional_params=optional_params,
+                        litellm_params=litellm_params,
+                        logger_fn=logger_fn,
+                        encoding=encoding,
+                        logging_obj=logging,
+                        extra_headers=extra_headers,
+                        timeout=timeout,
+                        acompletion=acompletion,
+                        client=client,
+                    )
                else:
                    response = bedrock_chat_completion.completion(
                        model=model,
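For callers nothing changes at the litellm.completion level; anthropic Bedrock models are simply served by the new Converse handler internally. A hedged sketch, with model ids taken from the test matrix in this commit:

# Hedged sketch, not part of the diff.
import litellm

# After this commit, routed through bedrock_converse_chat_completion (Converse API):
resp = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=100,
)

# Non-anthropic Bedrock models stay on the legacy bedrock_chat_completion (BedrockLLM) path:
resp = litellm.completion(
    model="bedrock/cohere.command-text-v14",
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=100,
)
print(resp.choices[0].message.content)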
File diff suppressed because it is too large
@@ -1288,14 +1288,14 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
 @pytest.mark.parametrize(
     "model",
     [
-        # "bedrock/cohere.command-r-plus-v1:0",
-        # "anthropic.claude-3-sonnet-20240229-v1:0",
-        # "anthropic.claude-instant-v1",
-        # "bedrock/ai21.j2-mid",
-        # "mistral.mistral-7b-instruct-v0:2",
-        # "bedrock/amazon.titan-tg1-large",
-        # "meta.llama3-8b-instruct-v1:0",
-        "cohere.command-text-v14"
+        "bedrock/cohere.command-r-plus-v1:0",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "anthropic.claude-instant-v1",
+        "bedrock/ai21.j2-mid",
+        "mistral.mistral-7b-instruct-v0:2",
+        "bedrock/amazon.titan-tg1-large",
+        "meta.llama3-8b-instruct-v1:0",
+        "cohere.command-text-v14",
     ],
 )
 @pytest.mark.asyncio
@@ -1324,8 +1324,6 @@ async def test_bedrock_httpx_streaming(sync_mode, model):
                raise Exception("finish reason not set")
            if complete_response.strip() == "":
                raise Exception("Empty response received")
-
-            assert False
        else:
            response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
                model=model,
@@ -5620,13 +5620,80 @@ def get_optional_params(
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
        )
+        if "ai21" in model:
+            _check_valid_arg(supported_params=supported_params)
+            # params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[],
+            # https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
+            if max_tokens is not None:
+                optional_params["maxTokens"] = max_tokens
+            if temperature is not None:
+                optional_params["temperature"] = temperature
+            if top_p is not None:
+                optional_params["topP"] = top_p
+            if stream:
+                optional_params["stream"] = stream
+        elif "anthropic" in model:
            _check_valid_arg(supported_params=supported_params)
            optional_params = litellm.AmazonConverseConfig().map_openai_params(
                model=model,
                non_default_params=non_default_params,
                optional_params=optional_params,
-                drop_params=drop_params,
+                drop_params=(
+                    drop_params
+                    if drop_params is not None and isinstance(drop_params, bool)
+                    else False
+                ),
            )
+        elif "amazon" in model:  # amazon titan llms
+            _check_valid_arg(supported_params=supported_params)
+            # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
+            if max_tokens is not None:
+                optional_params["maxTokenCount"] = max_tokens
+            if temperature is not None:
+                optional_params["temperature"] = temperature
+            if stop is not None:
+                filtered_stop = _map_and_modify_arg(
+                    {"stop": stop}, provider="bedrock", model=model
+                )
+                optional_params["stopSequences"] = filtered_stop["stop"]
+            if top_p is not None:
+                optional_params["topP"] = top_p
+            if stream:
+                optional_params["stream"] = stream
+        elif "meta" in model:  # amazon / meta llms
+            _check_valid_arg(supported_params=supported_params)
+            # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
+            if max_tokens is not None:
+                optional_params["max_gen_len"] = max_tokens
+            if temperature is not None:
+                optional_params["temperature"] = temperature
+            if top_p is not None:
+                optional_params["top_p"] = top_p
+            if stream:
+                optional_params["stream"] = stream
+        elif "cohere" in model:  # cohere models on bedrock
+            _check_valid_arg(supported_params=supported_params)
+            # handle cohere params
+            if stream:
+                optional_params["stream"] = stream
+            if temperature is not None:
+                optional_params["temperature"] = temperature
+            if max_tokens is not None:
+                optional_params["max_tokens"] = max_tokens
+        elif "mistral" in model:
+            _check_valid_arg(supported_params=supported_params)
+            # mistral params on bedrock
+            # \"max_tokens\":400,\"temperature\":0.7,\"top_p\":0.7,\"stop\":[\"\\\\n\\\\nHuman:\"]}"
+            if max_tokens is not None:
+                optional_params["max_tokens"] = max_tokens
+            if temperature is not None:
+                optional_params["temperature"] = temperature
+            if top_p is not None:
+                optional_params["top_p"] = top_p
+            if stop is not None:
+                optional_params["stop"] = stop
+            if stream is not None:
+                optional_params["stream"] = stream
    elif custom_llm_provider == "aleph_alpha":
        supported_params = [
            "max_tokens",
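A hedged sketch of what the anthropic branch above produces; the call signature comes straight from the diff, while the expected output is inferred from the AmazonConverseConfig field names in the first hunk, so treat it as an approximation:

# Hedged sketch, not part of the diff.
import litellm

mapped = litellm.AmazonConverseConfig().map_openai_params(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    non_default_params={"max_tokens": 200, "temperature": 0.0, "top_p": 0.9},
    optional_params={},
    drop_params=False,
)
# Roughly: {"maxTokens": 200, "temperature": 0.0, "topP": 0.9}
print(mapped)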