Merge pull request #4788 from BerriAI/litellm_ai21_jamba
feat(bedrock_httpx.py): add ai21 jamba instruct as bedrock model
commit 614e292bed
7 changed files with 158 additions and 30 deletions
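For orientation, here is a minimal usage sketch (not part of the diff) of the capability this PR adds. The model id and the aws_region_name value are taken from the diff below; the rest is standard litellm.completion usage and assumes AWS credentials are configured in the environment:

    import litellm

    # Streaming for this model is faked internally: the provider call is made
    # without streaming and the full response comes back as a single chunk.
    response = litellm.completion(
        model="bedrock/ai21.jamba-instruct-v1:0",
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        max_tokens=10,
        stream=True,
        aws_region_name="us-east-1",
    )
    for chunk in response:
        print(chunk)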
@@ -895,7 +895,7 @@ class AnthropicChatCompletion(BaseLLM):
        ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
            print_verbose("makes anthropic streaming POST request")
            data["stream"] = stream
-           response = requests.post(
+           response = client.post(
                api_base,
                headers=headers,
                data=json.dumps(data),
@@ -75,6 +75,7 @@ BEDROCK_CONVERSE_MODELS = [
     "anthropic.claude-v2:1",
     "anthropic.claude-v1",
     "anthropic.claude-instant-v1",
+    "ai21.jamba-instruct-v1:0",
 ]
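A rough, hypothetical sketch (not from the diff) of what membership in BEDROCK_CONVERSE_MODELS implies, purely to illustrate the check; the real routing logic inside litellm is more involved:

    # Hypothetical helper: models in BEDROCK_CONVERSE_MODELS are handled by the
    # Converse-style BedrockConverseLLM path instead of the legacy invoke path.
    def uses_converse_route(model: str, converse_models: list) -> bool:
        return model.replace("bedrock/", "") in converse_models

    print(uses_converse_route("bedrock/ai21.jamba-instruct-v1:0",
                              ["ai21.jamba-instruct-v1:0"]))  # True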
@@ -195,13 +196,39 @@ async def make_call(
     if client is None:
         client = _get_async_httpx_client()  # Create a new client if none provided

-    response = await client.post(api_base, headers=headers, data=data, stream=True)
+    response = await client.post(
+        api_base,
+        headers=headers,
+        data=data,
+        stream=True if "ai21" not in api_base else False,
+    )

     if response.status_code != 200:
         raise BedrockError(status_code=response.status_code, message=response.text)

-    decoder = AWSEventStreamDecoder(model=model)
-    completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
+    if "ai21" in api_base:
+        aws_bedrock_process_response = BedrockConverseLLM()
+        model_response: (
+            ModelResponse
+        ) = aws_bedrock_process_response.process_response(
+            model=model,
+            response=response,
+            model_response=litellm.ModelResponse(),
+            stream=True,
+            logging_obj=logging_obj,
+            optional_params={},
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=litellm.print_verbose,
+            encoding=litellm.encoding,
+        )  # type: ignore
+        completion_stream: Any = MockResponseIterator(model_response=model_response)
+    else:
+        decoder = AWSEventStreamDecoder(model=model)
+        completion_stream = decoder.aiter_bytes(
+            response.aiter_bytes(chunk_size=1024)
+        )

     # LOGGING
     logging_obj.post_call(
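The pattern introduced above, reduced to a standalone sketch with illustrative names (not the library's own): when a provider cannot stream, make a blocking call and wrap the complete response in an iterator that yields it exactly once, so downstream streaming code stays unchanged:

    from typing import Any


    class SingleChunkStream:
        """Illustrative stand-in for MockResponseIterator: wraps one complete
        response and yields it once, mimicking a one-chunk stream."""

        def __init__(self, full_response: Any):
            self.full_response = full_response
            self.done = False

        def __iter__(self):
            return self

        def __next__(self) -> Any:
            if self.done:
                raise StopIteration
            self.done = True
            return self.full_response


    fake_stream = SingleChunkStream({"text": "complete answer", "finish_reason": "stop"})
    for chunk in fake_stream:
        print(chunk)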
@@ -233,11 +260,33 @@ def make_sync_call(
     if client is None:
         client = _get_httpx_client()  # Create a new client if none provided

-    response = client.post(api_base, headers=headers, data=data, stream=True)
+    response = client.post(
+        api_base,
+        headers=headers,
+        data=data,
+        stream=True if "ai21" not in api_base else False,
+    )

     if response.status_code != 200:
         raise BedrockError(status_code=response.status_code, message=response.read())

-    decoder = AWSEventStreamDecoder(model=model)
-    completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
+    if "ai21" in api_base:
+        aws_bedrock_process_response = BedrockConverseLLM()
+        model_response: ModelResponse = aws_bedrock_process_response.process_response(
+            model=model,
+            response=response,
+            model_response=litellm.ModelResponse(),
+            stream=True,
+            logging_obj=logging_obj,
+            optional_params={},
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=litellm.print_verbose,
+            encoding=litellm.encoding,
+        )  # type: ignore
+        completion_stream: Any = MockResponseIterator(model_response=model_response)
+    else:
+        decoder = AWSEventStreamDecoder(model=model)
+        completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))

@@ -1348,7 +1397,7 @@ class BedrockConverseLLM(BaseLLM):
        response: Union[requests.Response, httpx.Response],
        model_response: ModelResponse,
        stream: bool,
-       logging_obj: Logging,
+       logging_obj: Optional[Logging],
        optional_params: dict,
        api_key: str,
        data: Union[dict, str],
@@ -1358,6 +1407,7 @@ class BedrockConverseLLM(BaseLLM):
    ) -> Union[ModelResponse, CustomStreamWrapper]:

        ## LOGGING
+       if logging_obj is not None:
            logging_obj.post_call(
                input=messages,
                api_key=api_key,
@@ -1900,7 +1950,7 @@ class BedrockConverseLLM(BaseLLM):
        if acompletion:
            if isinstance(client, HTTPHandler):
                client = None
-           if stream is True and provider != "ai21":
+           if stream is True:
                return self.async_streaming(
                    model=model,
                    messages=messages,
@@ -1937,7 +1987,7 @@ class BedrockConverseLLM(BaseLLM):
                client=client,
            )  # type: ignore

-       if (stream is not None and stream is True) and provider != "ai21":
+       if stream is not None and stream is True:

            streaming_response = CustomStreamWrapper(
                completion_stream=None,
@@ -1981,7 +2031,7 @@ class BedrockConverseLLM(BaseLLM):
            model=model,
            response=response,
            model_response=model_response,
-           stream=stream,
+           stream=stream if isinstance(stream, bool) else False,
            logging_obj=logging_obj,
            optional_params=optional_params,
            api_key="",
@@ -2168,3 +2218,49 @@ class AWSEventStreamDecoder:
            return None

        return chunk.decode()  # type: ignore[no-any-return]
+
+
+class MockResponseIterator:  # for returning ai21 streaming responses
+    def __init__(self, model_response):
+        self.model_response = model_response
+        self.is_done = False
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def _chunk_parser(self, chunk_data: ModelResponse) -> GenericStreamingChunk:
+
+        try:
+            chunk_usage: litellm.Usage = getattr(chunk_data, "usage")
+            processed_chunk = GenericStreamingChunk(
+                text=chunk_data.choices[0].message.content or "",  # type: ignore
+                tool_use=None,
+                is_finished=True,
+                finish_reason=chunk_data.choices[0].finish_reason,  # type: ignore
+                usage=ConverseTokenUsageBlock(
+                    inputTokens=chunk_usage.prompt_tokens,
+                    outputTokens=chunk_usage.completion_tokens,
+                    totalTokens=chunk_usage.total_tokens,
+                ),
+                index=0,
+            )
+            return processed_chunk
+        except Exception:
+            raise ValueError(f"Failed to decode chunk: {chunk_data}")
+
+    def __next__(self):
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        return self._chunk_parser(self.model_response)
+
+    # Async iterator
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        return self._chunk_parser(self.model_response)
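A hedged usage sketch for the new MockResponseIterator. The import path is assumed from the file named in the commit title, and the ModelResponse fields are filled in by hand only for illustration; in the real code path they come from process_response:

    import litellm
    from litellm.llms.bedrock_httpx import MockResponseIterator  # assumed path

    model_response = litellm.ModelResponse()
    model_response.choices[0].message.content = "Hello from Jamba"  # type: ignore
    model_response.choices[0].finish_reason = "stop"
    model_response.usage = litellm.Usage(  # type: ignore
        prompt_tokens=5, completion_tokens=4, total_tokens=9
    )

    mock_stream = MockResponseIterator(model_response=model_response)
    for chunk in mock_stream:  # yields exactly one GenericStreamingChunk
        print(chunk)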
@@ -2803,6 +2803,16 @@
        "litellm_provider": "bedrock",
        "mode": "chat"
    },
+   "ai21.jamba-instruct-v1:0": {
+       "max_tokens": 4096,
+       "max_input_tokens": 70000,
+       "max_output_tokens": 4096,
+       "input_cost_per_token": 0.0000005,
+       "output_cost_per_token": 0.0000007,
+       "litellm_provider": "bedrock",
+       "mode": "chat",
+       "supports_system_messages": true
+   },
    "amazon.titan-text-lite-v1": {
        "max_tokens": 4000,
        "max_input_tokens": 42000,
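A quick cost sanity check against the new pricing entry; the per-token figures come straight from the JSON above, the token counts are a hypothetical request:

    input_cost_per_token = 0.0000005   # from "ai21.jamba-instruct-v1:0"
    output_cost_per_token = 0.0000007

    prompt_tokens = 1000
    completion_tokens = 200

    total_cost = (
        prompt_tokens * input_cost_per_token
        + completion_tokens * output_cost_per_token
    )
    print(f"${total_cost:.6f}")  # 0.0005 + 0.00014 = $0.000640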
@@ -1,5 +1,13 @@
 model_list:
   - model_name: bad-azure-model
     litellm_params:
-      model: gpt-4
-      request_timeout: 1
+      model: azure/chatgpt-v-2
+      azure_ad_token: ""
+      api_base: os.environ/AZURE_API_BASE
+
+  - model_name: good-openai-model
+    litellm_params:
+      model: gpt-3.5-turbo
+
+litellm_settings:
+  fallbacks: [{"bad-azure-model": ["good-openai-model"]}]
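A hedged sketch of exercising the fallback configured above, assuming a LiteLLM proxy is running locally on its default port 4000 with the master key from the other config in this PR; only the base_url and api_key values are assumptions, the client calls are standard OpenAI SDK usage:

    from openai import OpenAI

    client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

    # Requests to the deliberately broken "bad-azure-model" should fall back to
    # "good-openai-model" per the litellm_settings.fallbacks entry above.
    resp = client.chat.completions.create(
        model="bad-azure-model",
        messages=[{"role": "user", "content": "ping"}],
    )
    print(resp.choices[0].message.content)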
@@ -4,5 +4,7 @@ model_list:
       model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
       api_key: "os.environ/FIREWORKS_AI_API_KEY"

+router_settings:
+  enable_tag_filtering: True # 👈 Key Change
 general_settings:
   master_key: sk-1234
@@ -1312,22 +1312,22 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
     # pytest.fail(f"Error occurred: {e}")


-@pytest.mark.parametrize("sync_mode", [True])  # False
+@pytest.mark.parametrize("sync_mode", [True, False])  #
 @pytest.mark.parametrize(
-    "model",
+    "model, region",
     [
-        "bedrock/cohere.command-r-plus-v1:0",
-        "anthropic.claude-3-sonnet-20240229-v1:0",
-        "anthropic.claude-instant-v1",
-        "bedrock/ai21.j2-mid",
-        "mistral.mistral-7b-instruct-v0:2",
-        "bedrock/amazon.titan-tg1-large",
-        "meta.llama3-8b-instruct-v1:0",
-        "cohere.command-text-v14",
+        ["bedrock/ai21.jamba-instruct-v1:0", "us-east-1"],
+        ["bedrock/cohere.command-r-plus-v1:0", None],
+        ["anthropic.claude-3-sonnet-20240229-v1:0", None],
+        ["anthropic.claude-instant-v1", None],
+        ["mistral.mistral-7b-instruct-v0:2", None],
+        ["bedrock/amazon.titan-tg1-large", None],
+        ["meta.llama3-8b-instruct-v1:0", None],
+        ["cohere.command-text-v14", None],
     ],
 )
 @pytest.mark.asyncio
-async def test_bedrock_httpx_streaming(sync_mode, model):
+async def test_bedrock_httpx_streaming(sync_mode, model, region):
     try:
         litellm.set_verbose = True
         if sync_mode:
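For readers less familiar with pytest, a minimal standalone illustration (not the litellm test itself) of how the two stacked parametrize decorators above expand: each sync_mode value is combined with each [model, region] pair, so 2 x 8 = 16 test invocations are generated by the real test:

    import pytest


    @pytest.mark.parametrize("sync_mode", [True, False])
    @pytest.mark.parametrize(
        "model, region",
        [
            ["bedrock/ai21.jamba-instruct-v1:0", "us-east-1"],
            ["bedrock/cohere.command-r-plus-v1:0", None],
        ],
    )
    def test_expansion(sync_mode, model, region):
        # 2 sync_mode values x 2 pairs here = 4 cases; the real test has 8 pairs.
        assert isinstance(sync_mode, bool)
        assert model.startswith("bedrock/")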
@@ -1337,6 +1337,7 @@ async def test_bedrock_httpx_streaming(sync_mode, model):
                messages=messages,
                max_tokens=10,  # type: ignore
                stream=True,
+               aws_region_name=region,
            )
            complete_response = ""
            # Add any assertions here to check the response
@@ -1358,6 +1359,7 @@ async def test_bedrock_httpx_streaming(sync_mode, model):
                messages=messages,
                max_tokens=100,  # type: ignore
                stream=True,
+               aws_region_name=region,
            )
            complete_response = ""
            # Add any assertions here to check the response
@@ -2803,6 +2803,16 @@
        "litellm_provider": "bedrock",
        "mode": "chat"
    },
+   "ai21.jamba-instruct-v1:0": {
+       "max_tokens": 4096,
+       "max_input_tokens": 70000,
+       "max_output_tokens": 4096,
+       "input_cost_per_token": 0.0000005,
+       "output_cost_per_token": 0.0000007,
+       "litellm_provider": "bedrock",
+       "mode": "chat",
+       "supports_system_messages": true
+   },
    "amazon.titan-text-lite-v1": {
        "max_tokens": 4000,
        "max_input_tokens": 42000,