feat(bedrock_httpx.py): moves to using httpx client for bedrock cohere calls

Krrish Dholakia 2024-05-11 13:43:08 -07:00
parent c12af219af
commit 4a3b084961
29 changed files with 147 additions and 64 deletions
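
The headline change below adds a `BedrockLLM` class that talks to Bedrock through litellm's httpx-based handlers, unlocking async calls; the commit also strips module-level `dotenv.load_dotenv()` side effects across the integrations and routers. A minimal sketch of the async call path this enables, assuming AWS credentials are already set in the environment (`acompletion` is litellm's existing async entry point; the snippet is illustrative and not part of the diff):

```python
import asyncio

import litellm


async def main():
    # Async completion against the Bedrock cohere command-r model covered by V0.
    response = await litellm.acompletion(
        model="bedrock/cohere.command-r-plus-v1:0",
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
    )
    print(response)


asyncio.run(main())
```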

View file

@@ -10,7 +10,6 @@ from litellm.caching import DualCache
 from typing import Literal, Union
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
@@ -19,8 +18,6 @@ import traceback
 import dotenv, os
 import requests
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid

View file

@@ -1,8 +1,6 @@
 #### What this does ####
 # On success + failure, log events to aispend.io
 import dotenv, os
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime

View file

@@ -3,7 +3,6 @@
 import dotenv, os
 import requests  # type: ignore
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime

View file

@@ -8,8 +8,6 @@ from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
 from typing import Literal, Union
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
@@ -18,8 +16,6 @@ import traceback
 import dotenv, os
 import requests
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid

View file

@@ -6,8 +6,6 @@ from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
 from typing import Literal, Union, Optional
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback

View file

@@ -3,8 +3,6 @@
 import dotenv, os
 import requests  # type: ignore
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid

View file

@@ -3,8 +3,6 @@
 import dotenv, os
 import requests  # type: ignore
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid

View file

@@ -3,8 +3,6 @@
 import dotenv, os
 import requests  # type: ignore
 import litellm
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback

View file

@@ -1,8 +1,6 @@
 #### What this does ####
 # On success, logs events to Langfuse
-import dotenv, os
-dotenv.load_dotenv()  # Loading env variables using dotenv
+import os
 import copy
 import traceback
 from packaging.version import Version

View file

@@ -3,8 +3,6 @@
 import dotenv, os  # type: ignore
 import requests  # type: ignore
 from datetime import datetime
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import asyncio
 import types

View file

@@ -2,14 +2,11 @@
 # On success + failure, log events to lunary.ai
 from datetime import datetime, timezone
 import traceback
-import dotenv
 import importlib
 import sys
 import packaging
-dotenv.load_dotenv()

 # convert to {completion: xx, tokens: xx}
 def parse_usage(usage):
@@ -62,14 +59,16 @@ class LunaryLogger:
             version = importlib.metadata.version("lunary")
             # if version < 0.1.43 then raise ImportError
             if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
-                print(
+                print(  # noqa
                     "Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
                 )
                 raise ImportError

             self.lunary_client = lunary
         except ImportError:
-            print("Lunary not installed. Please install it using 'pip install lunary'")
+            print(  # noqa
+                "Lunary not installed. Please install it using 'pip install lunary'"
+            )  # noqa
             raise ImportError

     def log_event(

View file

@@ -3,8 +3,6 @@
 import dotenv, os, json
 import litellm
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler

View file

@@ -4,8 +4,6 @@
 import dotenv, os
 import requests  # type: ignore
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid

View file

@@ -5,8 +5,6 @@
 import dotenv, os
 import requests  # type: ignore
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid

View file

@@ -3,8 +3,6 @@
 import dotenv, os
 import requests  # type: ignore
 from pydantic import BaseModel
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback

View file

@@ -1,9 +1,7 @@
 #### What this does ####
 # On success + failure, log events to Supabase
-import dotenv, os
-dotenv.load_dotenv()  # Loading env variables using dotenv
+import os
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid

View file

@@ -2,8 +2,6 @@
 # Class for sending Slack Alerts #
 import dotenv, os
 from litellm.proxy._types import UserAPIKeyAuth
-dotenv.load_dotenv()  # Loading env variables using dotenv
 from litellm._logging import verbose_logger, verbose_proxy_logger
 import litellm, threading
 from typing import List, Literal, Any, Union, Optional, Dict

View file

@@ -3,8 +3,6 @@
 import dotenv, os
 import requests  # type: ignore
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import datetime, subprocess, sys
 import litellm

View file

@@ -21,11 +21,11 @@ try:
         # contains a (known) object attribute
         object: Literal["chat.completion", "edit", "text_completion"]

-        def __getitem__(self, key: K) -> V:
-            ...  # pragma: no cover
+        def __getitem__(self, key: K) -> V: ...  # noqa

-        def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
-            ...  # pragma: no cover
+        def get(  # noqa
+            self, key: K, default: Optional[V] = None
+        ) -> Optional[V]: ...  # pragma: no cover

     class OpenAIRequestResponseResolver:
         def __call__(
@@ -173,12 +173,11 @@ except:
 #### What this does ####
 # On success, logs events to Langfuse
-import dotenv, os
+import os
 import requests
 import requests
 from datetime import datetime
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback

View file

@@ -0,0 +1,124 @@
+# What is this?
+## Initial implementation of calling bedrock via httpx client (allows for async calls).
+## V0 - just covers cohere command-r support
+
+import os, types
+import json
+from enum import Enum
+import requests, copy  # type: ignore
+import time
+from typing import Callable, Optional, List, Literal, Union
+from litellm.utils import (
+    ModelResponse,
+    Usage,
+    map_finish_reason,
+    CustomStreamWrapper,
+    Message,
+    Choices,
+    get_secret,
+)
+import litellm
+from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from .base import BaseLLM
+import httpx  # type: ignore
+from .bedrock import BedrockError
+
+
+class BedrockLLM(BaseLLM):
+    """
+    Example call
+
+    ```
+    curl --location --request POST 'https://bedrock-runtime.{aws_region_name}.amazonaws.com/model/{bedrock_model_name}/invoke' \
+    --header 'Content-Type: application/json' \
+    --header 'Accept: application/json' \
+    --user "$AWS_ACCESS_KEY_ID":"$AWS_SECRET_ACCESS_KEY" \
+    --aws-sigv4 "aws:amz:us-east-1:bedrock" \
+    --data-raw '{
+        "prompt": "Hi",
+        "temperature": 0,
+        "p": 0.9,
+        "max_tokens": 4096
+    }'
+    ```
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def get_credentials(
+        self,
+        aws_access_key_id: Optional[str] = None,
+        aws_secret_access_key: Optional[str] = None,
+        aws_region_name: Optional[str] = None,
+        aws_session_name: Optional[str] = None,
+        aws_profile_name: Optional[str] = None,
+        aws_role_name: Optional[str] = None,
+    ):
+        """
+        Return a boto3.Credentials object
+        """
+        import boto3
+
+        ## CHECK IS 'os.environ/' passed in
+        params_to_check: List[Optional[str]] = [
+            aws_access_key_id,
+            aws_secret_access_key,
+            aws_region_name,
+            aws_session_name,
+            aws_profile_name,
+            aws_role_name,
+        ]
+
+        # Iterate over parameters and update if needed
+        for i, param in enumerate(params_to_check):
+            if param and param.startswith("os.environ/"):
+                _v = get_secret(param)
+                if _v is not None and isinstance(_v, str):
+                    params_to_check[i] = _v
+
+        # Assign updated values back to parameters
+        (
+            aws_access_key_id,
+            aws_secret_access_key,
+            aws_region_name,
+            aws_session_name,
+            aws_profile_name,
+            aws_role_name,
+        ) = params_to_check
+
+        ### CHECK STS ###
+        if aws_role_name is not None and aws_session_name is not None:
+            sts_client = boto3.client(
+                "sts",
+                aws_access_key_id=aws_access_key_id,  # [OPTIONAL]
+                aws_secret_access_key=aws_secret_access_key,  # [OPTIONAL]
+            )
+
+            sts_response = sts_client.assume_role(
+                RoleArn=aws_role_name, RoleSessionName=aws_session_name
+            )
+
+            return sts_response["Credentials"]
+        elif aws_profile_name is not None:  ### CHECK SESSION ###
+            # uses auth values from AWS profile usually stored in ~/.aws/credentials
+            client = boto3.Session(profile_name=aws_profile_name)
+
+            return client.get_credentials()
+        else:
+            session = boto3.Session(
+                aws_access_key_id=aws_access_key_id,
+                aws_secret_access_key=aws_secret_access_key,
+                region_name=aws_region_name,
+            )
+
+            return session.get_credentials()
+
+    def completion(self, *args, **kwargs) -> Union[ModelResponse, CustomStreamWrapper]:
+        ## get credentials
+        ## generate signature
+        ## make request
+        return super().completion(*args, **kwargs)
+
+    def embedding(self, *args, **kwargs):
+        return super().embedding(*args, **kwargs)
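
`completion()` is still a stub here; the `get credentials / generate signature / make request` comments only outline the plan. A hedged sketch of how those steps could compose, assuming botocore's `SigV4Auth` for request signing (the helper name `_sign_and_post` and its wiring are illustrative, not code shipped by this commit):

```python
import json

import httpx
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest


def _sign_and_post(credentials, region: str, model: str, data: dict) -> httpx.Response:
    # Illustrative helper: SigV4-sign a Bedrock invoke request, then send it over httpx.
    url = f"https://bedrock-runtime.{region}.amazonaws.com/model/{model}/invoke"
    body = json.dumps(data)
    request = AWSRequest(
        method="POST",
        url=url,
        data=body,
        headers={"Content-Type": "application/json"},
    )
    SigV4Auth(credentials, "bedrock", region).add_auth(request)  # adds Authorization header
    prepped = request.prepare()
    return httpx.post(prepped.url, content=body, headers=dict(prepped.headers))
```

Here `credentials` would be the object returned by `get_credentials()` above; an async variant would swap `httpx.post` for the already-imported `AsyncHTTPHandler`.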

View file

@@ -75,6 +75,7 @@ from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.predibase import PredibaseChatCompletion
+from .llms.bedrock_httpx import BedrockLLM
 from .llms.triton import TritonChatCompletion
 from .llms.prompt_templates.factory import (
     prompt_factory,
@@ -104,7 +105,6 @@ from litellm.utils import (
 )

 ####### ENVIRONMENT VARIABLES ###################
-dotenv.load_dotenv()  # Loading env variables using dotenv
 openai_chat_completions = OpenAIChatCompletion()
 openai_text_completions = OpenAITextCompletion()
 anthropic_chat_completions = AnthropicChatCompletion()
@@ -114,6 +114,7 @@ azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
 predibase_chat_completions = PredibaseChatCompletion()
 triton_chat_completions = TritonChatCompletion()
+bedrock_chat_completion = BedrockLLM()

 ####### COMPLETION ENDPOINTS ################

View file

@@ -1,10 +1,7 @@
 from litellm.proxy._types import UserAPIKeyAuth, GenerateKeyRequest
 from fastapi import Request
-from dotenv import load_dotenv
 import os
-load_dotenv()

 async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
     try:

View file

@@ -8,8 +8,6 @@
 import dotenv, os, requests, random  # type: ignore
 from typing import Optional
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger

View file

@@ -1,12 +1,11 @@
 #### What this does ####
 # picks based on response time (for streaming, this is time to first token)
 from pydantic import BaseModel, Extra, Field, root_validator
-import dotenv, os, requests, random  # type: ignore
+import os, requests, random  # type: ignore
 from typing import Optional, Union, List, Dict
 from datetime import datetime, timedelta
 import random
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger

View file

@@ -5,8 +5,6 @@ import dotenv, os, requests, random  # type: ignore
 from typing import Optional, Union, List, Dict
 from datetime import datetime, timedelta
 import random
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger

View file

@@ -4,8 +4,6 @@
 import dotenv, os, requests, random
 from typing import Optional, Union, List, Dict
 from datetime import datetime
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 from litellm import token_counter
 from litellm.caching import DualCache

View file

@@ -5,8 +5,6 @@ import dotenv, os, requests, random
 from typing import Optional, Union, List, Dict
 import datetime as datetime_og
 from datetime import datetime
-dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback, asyncio, httpx
 import litellm
 from litellm import token_counter

View file

@@ -2584,6 +2584,15 @@ def test_completion_chat_sagemaker_mistral():
 # test_completion_chat_sagemaker_mistral()

+
+def test_completion_bedrock_command_r():
+    response = completion(
+        model="bedrock/cohere.command-r-plus-v1:0",
+        messages=[{"role": "user", "content": "Hey! how's it going?"}],
+    )
+
+    print(f"response: {response}")
+
 def test_completion_bedrock_titan_null_response():
     try:
         response = completion(
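
The new test relies on ambient AWS credentials. For reference, litellm also accepts Bedrock credentials as explicit kwargs, which flow into `get_credentials()` above; a hedged sketch with placeholder values (never hardcode real keys):

```python
from litellm import completion

# Placeholder credentials: in practice these come from env vars or a secret manager.
response = completion(
    model="bedrock/cohere.command-r-plus-v1:0",
    messages=[{"role": "user", "content": "Hey! how's it going?"}],
    aws_access_key_id="AKIA...",
    aws_secret_access_key="...",
    aws_region_name="us-east-1",
)
print(f"response: {response}")
```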

View file

@@ -117,7 +117,6 @@ MAX_THREADS = 100
 # Create a ThreadPoolExecutor
 executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
-dotenv.load_dotenv()  # Loading env variables using dotenv
 sentry_sdk_instance = None
 capture_exception = None
 add_breadcrumb = None