Merge pull request #4788 from BerriAI/litellm_ai21_jamba

feat(bedrock_httpx.py): add ai21 jamba instruct as bedrock model
Krish Dholakia 2024-07-19 17:11:36 -07:00 committed by GitHub
commit 614e292bed
7 changed files with 158 additions and 30 deletions
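For context, a minimal usage sketch of the feature this PR adds (the model id, streaming flag, and region mirror the updated test further down; AWS credentials for Bedrock are assumed to be configured in the environment):

import litellm

# Hedged sketch, not part of the diff: stream a completion from the newly
# added Bedrock AI21 Jamba Instruct model.
response = litellm.completion(
    model="bedrock/ai21.jamba-instruct-v1:0",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    max_tokens=10,
    stream=True,
    aws_region_name="us-east-1",
)
for chunk in response:
    print(chunk)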


@@ -895,7 +895,7 @@ class AnthropicChatCompletion(BaseLLM):
         ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
             print_verbose("makes anthropic streaming POST request")
             data["stream"] = stream
-            response = requests.post(
+            response = client.post(
                 api_base,
                 headers=headers,
                 data=json.dumps(data),


@@ -75,6 +75,7 @@ BEDROCK_CONVERSE_MODELS = [
     "anthropic.claude-v2:1",
     "anthropic.claude-v1",
     "anthropic.claude-instant-v1",
+    "ai21.jamba-instruct-v1:0",
 ]
@@ -195,13 +196,39 @@ async def make_call(
     if client is None:
         client = _get_async_httpx_client()  # Create a new client if none provided
-    response = await client.post(api_base, headers=headers, data=data, stream=True)
+    response = await client.post(
+        api_base,
+        headers=headers,
+        data=data,
+        stream=True if "ai21" not in api_base else False,
+    )
     if response.status_code != 200:
         raise BedrockError(status_code=response.status_code, message=response.text)
-    decoder = AWSEventStreamDecoder(model=model)
-    completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
+    if "ai21" in api_base:
+        aws_bedrock_process_response = BedrockConverseLLM()
+        model_response: (
+            ModelResponse
+        ) = aws_bedrock_process_response.process_response(
+            model=model,
+            response=response,
+            model_response=litellm.ModelResponse(),
+            stream=True,
+            logging_obj=logging_obj,
+            optional_params={},
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=litellm.print_verbose,
+            encoding=litellm.encoding,
+        )  # type: ignore
+        completion_stream: Any = MockResponseIterator(model_response=model_response)
+    else:
+        decoder = AWSEventStreamDecoder(model=model)
+        completion_stream = decoder.aiter_bytes(
+            response.aiter_bytes(chunk_size=1024)
+        )
     # LOGGING
     logging_obj.post_call(
@@ -233,13 +260,35 @@ def make_sync_call(
     if client is None:
         client = _get_httpx_client()  # Create a new client if none provided
-    response = client.post(api_base, headers=headers, data=data, stream=True)
+    response = client.post(
+        api_base,
+        headers=headers,
+        data=data,
+        stream=True if "ai21" not in api_base else False,
+    )
     if response.status_code != 200:
         raise BedrockError(status_code=response.status_code, message=response.read())
-    decoder = AWSEventStreamDecoder(model=model)
-    completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
+    if "ai21" in api_base:
+        aws_bedrock_process_response = BedrockConverseLLM()
+        model_response: ModelResponse = aws_bedrock_process_response.process_response(
+            model=model,
+            response=response,
+            model_response=litellm.ModelResponse(),
+            stream=True,
+            logging_obj=logging_obj,
+            optional_params={},
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=litellm.print_verbose,
+            encoding=litellm.encoding,
+        )  # type: ignore
+        completion_stream: Any = MockResponseIterator(model_response=model_response)
+    else:
+        decoder = AWSEventStreamDecoder(model=model)
+        completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
     # LOGGING
     logging_obj.post_call(
@@ -1348,7 +1397,7 @@ class BedrockConverseLLM(BaseLLM):
         response: Union[requests.Response, httpx.Response],
         model_response: ModelResponse,
         stream: bool,
-        logging_obj: Logging,
+        logging_obj: Optional[Logging],
         optional_params: dict,
         api_key: str,
         data: Union[dict, str],
@@ -1358,12 +1407,13 @@ class BedrockConverseLLM(BaseLLM):
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         ## LOGGING
-        logging_obj.post_call(
-            input=messages,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
-        )
+        if logging_obj is not None:
+            logging_obj.post_call(
+                input=messages,
+                api_key=api_key,
+                original_response=response.text,
+                additional_args={"complete_input_dict": data},
+            )
         print_verbose(f"raw model_response: {response.text}")
         ## RESPONSE OBJECT
@@ -1900,7 +1950,7 @@ class BedrockConverseLLM(BaseLLM):
         if acompletion:
             if isinstance(client, HTTPHandler):
                 client = None
-            if stream is True and provider != "ai21":
+            if stream is True:
                 return self.async_streaming(
                     model=model,
                     messages=messages,
@@ -1937,7 +1987,7 @@ class BedrockConverseLLM(BaseLLM):
             client=client,
         )  # type: ignore
-        if (stream is not None and stream is True) and provider != "ai21":
+        if stream is not None and stream is True:
             streaming_response = CustomStreamWrapper(
                 completion_stream=None,
@@ -1981,7 +2031,7 @@ class BedrockConverseLLM(BaseLLM):
             model=model,
             response=response,
             model_response=model_response,
-            stream=stream,
+            stream=stream if isinstance(stream, bool) else False,
             logging_obj=logging_obj,
             optional_params=optional_params,
             api_key="",
@@ -2168,3 +2218,49 @@ class AWSEventStreamDecoder:
             return None
         return chunk.decode()  # type: ignore[no-any-return]
+
+
+class MockResponseIterator:  # for returning ai21 streaming responses
+    def __init__(self, model_response):
+        self.model_response = model_response
+        self.is_done = False
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def _chunk_parser(self, chunk_data: ModelResponse) -> GenericStreamingChunk:
+        try:
+            chunk_usage: litellm.Usage = getattr(chunk_data, "usage")
+            processed_chunk = GenericStreamingChunk(
+                text=chunk_data.choices[0].message.content or "",  # type: ignore
+                tool_use=None,
+                is_finished=True,
+                finish_reason=chunk_data.choices[0].finish_reason,  # type: ignore
+                usage=ConverseTokenUsageBlock(
+                    inputTokens=chunk_usage.prompt_tokens,
+                    outputTokens=chunk_usage.completion_tokens,
+                    totalTokens=chunk_usage.total_tokens,
+                ),
+                index=0,
+            )
+            return processed_chunk
+        except Exception:
+            raise ValueError(f"Failed to decode chunk: {chunk_data}")
+
+    def __next__(self):
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        return self._chunk_parser(self.model_response)
+
+    # Async iterator
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        return self._chunk_parser(self.model_response)
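A short note on the class above: because the AI21 Jamba response is fetched without true streaming, the already-complete ModelResponse is wrapped so downstream streaming consumers still receive an iterator, one that yields a single GenericStreamingChunk and then stops. A hedged sketch of that behavior (model_response below is assumed to be a finished litellm.ModelResponse):

# Illustration only, not part of the diff.
stream = MockResponseIterator(model_response=model_response)
chunks = list(stream)  # _chunk_parser runs once, then StopIteration is raised
assert len(chunks) == 1  # the whole response arrives as a single "streaming" chunk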


@@ -2803,6 +2803,16 @@
         "litellm_provider": "bedrock",
         "mode": "chat"
     },
+    "ai21.jamba-instruct-v1:0": {
+        "max_tokens": 4096,
+        "max_input_tokens": 70000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000005,
+        "output_cost_per_token": 0.0000007,
+        "litellm_provider": "bedrock",
+        "mode": "chat",
+        "supports_system_messages": true
+    },
     "amazon.titan-text-lite-v1": {
         "max_tokens": 4000,
         "max_input_tokens": 42000,


@@ -1,5 +1,13 @@
 model_list:
   - model_name: bad-azure-model
     litellm_params:
-      model: gpt-4
-      request_timeout: 1
+      model: azure/chatgpt-v-2
+      azure_ad_token: ""
+      api_base: os.environ/AZURE_API_BASE
+  - model_name: good-openai-model
+    litellm_params:
+      model: gpt-3.5-turbo
+
+litellm_settings:
+  fallbacks: [{"bad-azure-model": ["good-openai-model"]}]
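A hedged sketch of how the fallback config above could be exercised against a locally running litellm proxy; the base_url, port, and api_key here are assumptions based on a default local setup and the master_key used elsewhere in this PR, not part of the change itself:

from openai import OpenAI

# Assumed local proxy address and key; adjust to your deployment.
client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

# Request the intentionally broken model; litellm_settings.fallbacks should
# retry the call on good-openai-model.
resp = client.chat.completions.create(
    model="bad-azure-model",
    messages=[{"role": "user", "content": "ping"}],
)
print(resp.choices[0].message.content)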


@@ -4,5 +4,7 @@ model_list:
     model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
     api_key: "os.environ/FIREWORKS_AI_API_KEY"
+router_settings:
+  enable_tag_filtering: True # 👈 Key Change
 general_settings:
   master_key: sk-1234


@@ -1312,22 +1312,22 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
 #             pytest.fail(f"Error occurred: {e}")
-@pytest.mark.parametrize("sync_mode", [True])  # False
+@pytest.mark.parametrize("sync_mode", [True, False])  #
 @pytest.mark.parametrize(
-    "model",
+    "model, region",
     [
-        "bedrock/cohere.command-r-plus-v1:0",
-        "anthropic.claude-3-sonnet-20240229-v1:0",
-        "anthropic.claude-instant-v1",
-        "bedrock/ai21.j2-mid",
-        "mistral.mistral-7b-instruct-v0:2",
-        "bedrock/amazon.titan-tg1-large",
-        "meta.llama3-8b-instruct-v1:0",
-        "cohere.command-text-v14",
+        ["bedrock/ai21.jamba-instruct-v1:0", "us-east-1"],
+        ["bedrock/cohere.command-r-plus-v1:0", None],
+        ["anthropic.claude-3-sonnet-20240229-v1:0", None],
+        ["anthropic.claude-instant-v1", None],
+        ["mistral.mistral-7b-instruct-v0:2", None],
+        ["bedrock/amazon.titan-tg1-large", None],
+        ["meta.llama3-8b-instruct-v1:0", None],
+        ["cohere.command-text-v14", None],
     ],
 )
 @pytest.mark.asyncio
-async def test_bedrock_httpx_streaming(sync_mode, model):
+async def test_bedrock_httpx_streaming(sync_mode, model, region):
     try:
         litellm.set_verbose = True
         if sync_mode:
@@ -1337,6 +1337,7 @@ async def test_bedrock_httpx_streaming(sync_mode, model):
             messages=messages,
             max_tokens=10,  # type: ignore
             stream=True,
+            aws_region_name=region,
         )
         complete_response = ""
         # Add any assertions here to check the response
@@ -1358,6 +1359,7 @@ async def test_bedrock_httpx_streaming(sync_mode, model):
             messages=messages,
             max_tokens=100,  # type: ignore
             stream=True,
+            aws_region_name=region,
         )
         complete_response = ""
         # Add any assertions here to check the response


@@ -2803,6 +2803,16 @@
         "litellm_provider": "bedrock",
         "mode": "chat"
     },
+    "ai21.jamba-instruct-v1:0": {
+        "max_tokens": 4096,
+        "max_input_tokens": 70000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.0000005,
+        "output_cost_per_token": 0.0000007,
+        "litellm_provider": "bedrock",
+        "mode": "chat",
+        "supports_system_messages": true
+    },
     "amazon.titan-text-lite-v1": {
         "max_tokens": 4000,
         "max_input_tokens": 42000,