LiteLLM Merged PRs (#5538)

* Fix typo in #5509 (#5532)

* Reapply "(bedrock): Fix usage with Cloudflare AI Gateway, and proxies in gener…" (#5519)

This reverts commit 995019c08a.

* (bedrock): Fix obvious typo

* test: cleanup linting error

---------

Co-authored-by: David Manouchehri <david.manouchehri@ai.moda>
This commit is contained in:
Krish Dholakia 2024-09-05 17:11:31 -07:00 committed by GitHub
parent d8ef8c133e
commit 6cd8951f56
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 33 additions and 17 deletions

View file

@@ -728,7 +728,7 @@ class BedrockLLM(BaseAWSLLM):
) )
### SET RUNTIME ENDPOINT ### ### SET RUNTIME ENDPOINT ###
endpoint_url = get_runtime_endpoint( endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
api_base=api_base, api_base=api_base,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_region_name=aws_region_name, aws_region_name=aws_region_name,
@@ -736,8 +736,10 @@ class BedrockLLM(BaseAWSLLM):
if (stream is not None and stream is True) and provider != "ai21": if (stream is not None and stream is True) and provider != "ai21":
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream" endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream"
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke-with-response-stream"
else: else:
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke" endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
@@ -903,7 +905,7 @@ class BedrockLLM(BaseAWSLLM):
api_key="", api_key="",
additional_args={ additional_args={
"complete_input_dict": data, "complete_input_dict": data,
"api_base": prepped.url, "api_base": proxy_endpoint_url,
"headers": prepped.headers, "headers": prepped.headers,
}, },
) )
@@ -917,7 +919,7 @@ class BedrockLLM(BaseAWSLLM):
model=model, model=model,
messages=messages, messages=messages,
data=data, data=data,
api_base=prepped.url, api_base=proxy_endpoint_url,
model_response=model_response, model_response=model_response,
print_verbose=print_verbose, print_verbose=print_verbose,
encoding=encoding, encoding=encoding,
@@ -935,7 +937,7 @@ class BedrockLLM(BaseAWSLLM):
model=model, model=model,
messages=messages, messages=messages,
data=data, data=data,
api_base=prepped.url, api_base=proxy_endpoint_url,
model_response=model_response, model_response=model_response,
print_verbose=print_verbose, print_verbose=print_verbose,
encoding=encoding, encoding=encoding,
@@ -960,7 +962,7 @@ class BedrockLLM(BaseAWSLLM):
self.client = client self.client = client
if (stream is not None and stream == True) and provider != "ai21": if (stream is not None and stream == True) and provider != "ai21":
response = self.client.post( response = self.client.post(
url=prepped.url, url=proxy_endpoint_url,
headers=prepped.headers, # type: ignore headers=prepped.headers, # type: ignore
data=data, data=data,
stream=stream, stream=stream,
@@ -991,7 +993,7 @@ class BedrockLLM(BaseAWSLLM):
return streaming_response return streaming_response
try: try:
response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore response = self.client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data) # type: ignore
response.raise_for_status() response.raise_for_status()
except httpx.HTTPStatusError as err: except httpx.HTTPStatusError as err:
error_code = err.response.status_code error_code = err.response.status_code
@@ -1604,15 +1606,17 @@ class BedrockConverseLLM(BaseAWSLLM):
) )
### SET RUNTIME ENDPOINT ### ### SET RUNTIME ENDPOINT ###
endpoint_url = get_runtime_endpoint( endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
api_base=api_base, api_base=api_base,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_region_name=aws_region_name, aws_region_name=aws_region_name,
) )
if (stream is not None and stream is True) and provider != "ai21": if (stream is not None and stream is True) and provider != "ai21":
endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream" endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream"
else: else:
endpoint_url = f"{endpoint_url}/model/{modelId}/converse" endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
@@ -1719,7 +1723,7 @@ class BedrockConverseLLM(BaseAWSLLM):
api_key="", api_key="",
additional_args={ additional_args={
"complete_input_dict": data, "complete_input_dict": data,
"api_base": prepped.url, "api_base": proxy_endpoint_url,
"headers": prepped.headers, "headers": prepped.headers,
}, },
) )
@@ -1733,7 +1737,7 @@ class BedrockConverseLLM(BaseAWSLLM):
model=model, model=model,
messages=messages, messages=messages,
data=data, data=data,
api_base=prepped.url, api_base=proxy_endpoint_url,
model_response=model_response, model_response=model_response,
print_verbose=print_verbose, print_verbose=print_verbose,
encoding=encoding, encoding=encoding,
@@ -1751,7 +1755,7 @@ class BedrockConverseLLM(BaseAWSLLM):
model=model, model=model,
messages=messages, messages=messages,
data=data, data=data,
api_base=prepped.url, api_base=proxy_endpoint_url,
model_response=model_response, model_response=model_response,
print_verbose=print_verbose, print_verbose=print_verbose,
encoding=encoding, encoding=encoding,
@@ -1772,7 +1776,7 @@ class BedrockConverseLLM(BaseAWSLLM):
make_call=partial( make_call=partial(
make_sync_call, make_sync_call,
client=None, client=None,
api_base=prepped.url, api_base=proxy_endpoint_url,
headers=prepped.headers, # type: ignore headers=prepped.headers, # type: ignore
data=data, data=data,
model=model, model=model,
@@ -1797,7 +1801,7 @@ class BedrockConverseLLM(BaseAWSLLM):
else: else:
client = client client = client
try: try:
response = client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore response = client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data) # type: ignore
response.raise_for_status() response.raise_for_status()
except httpx.HTTPStatusError as err: except httpx.HTTPStatusError as err:
error_code = err.response.status_code error_code = err.response.status_code

View file

@@ -5,7 +5,7 @@ Common utilities used across bedrock chat/embedding/image generation
import os import os
import types import types
from enum import Enum from enum import Enum
from typing import List, Optional, Union from typing import List, Optional, Union, Tuple
import httpx import httpx
@@ -729,7 +729,7 @@ def get_runtime_endpoint(
api_base: Optional[str], api_base: Optional[str],
aws_bedrock_runtime_endpoint: Optional[str], aws_bedrock_runtime_endpoint: Optional[str],
aws_region_name: str, aws_region_name: str,
) -> str: ) -> Tuple[str, str]:
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
if api_base is not None: if api_base is not None:
endpoint_url = api_base endpoint_url = api_base
@@ -744,7 +744,19 @@ def get_runtime_endpoint(
else: else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
return endpoint_url # Determine proxy_endpoint_url
if env_aws_bedrock_runtime_endpoint and isinstance(
env_aws_bedrock_runtime_endpoint, str
):
proxy_endpoint_url = env_aws_bedrock_runtime_endpoint
elif aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
proxy_endpoint_url = aws_bedrock_runtime_endpoint
else:
proxy_endpoint_url = endpoint_url
return endpoint_url, proxy_endpoint_url
class ModelResponseIterator: class ModelResponseIterator:

View file

@@ -393,7 +393,7 @@ class BedrockEmbedding(BaseAWSLLM):
batch_data.append(transformed_request) batch_data.append(transformed_request)
### SET RUNTIME ENDPOINT ### ### SET RUNTIME ENDPOINT ###
endpoint_url = get_runtime_endpoint( endpoint_url, proxy_endpoint_url = get_runtime_endpoint(
api_base=api_base, api_base=api_base,
aws_bedrock_runtime_endpoint=optional_params.pop( aws_bedrock_runtime_endpoint=optional_params.pop(
"aws_bedrock_runtime_endpoint", None "aws_bedrock_runtime_endpoint", None

View file

@@ -146,4 +146,4 @@ def trade(model_name: str) -> List[Trade]:  # type: ignore
) )
def test_function_call_parsing(model): def test_function_call_parsing(model):
trades = trade(model) trades = trade(model)
print([trade.order for trade in trades]) print([trade.order for trade in trades if trade is not None])