feat(bedrock_httpx.py): moves to using httpx client for bedrock cohere calls

Krrish Dholakia 2024-05-11 13:43:08 -07:00
parent c12af219af
commit 4a3b084961
29 changed files with 147 additions and 64 deletions
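
The gist of the change, as a minimal sketch (illustrative only, not this commit's code — `call_bedrock_cohere`, the model id, and the request body shape are assumptions): SigV4-sign a Bedrock `invoke` request with botocore, then send it through an async httpx client, which the blocking boto3-only path could not do.

```python
# Minimal sketch (assumed names, not this commit's code): SigV4-sign a
# Bedrock invoke request with botocore and send it via an async httpx client.
import asyncio
import json

import boto3
import httpx
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest


async def call_bedrock_cohere(prompt: str, region: str = "us-east-1") -> dict:
    # Illustrative model id; this commit's V0 targets cohere command-r.
    model = "cohere.command-r-v1:0"
    url = f"https://bedrock-runtime.{region}.amazonaws.com/model/{model}/invoke"
    body = json.dumps({"message": prompt})

    # Resolve credentials via the default boto3 chain, then SigV4-sign.
    credentials = boto3.Session().get_credentials().get_frozen_credentials()
    request = AWSRequest(
        method="POST",
        url=url,
        data=body,
        headers={"Content-Type": "application/json"},
    )
    SigV4Auth(credentials, "bedrock", region).add_auth(request)

    # httpx gives a real async request path, unlike the blocking boto3 client.
    async with httpx.AsyncClient() as client:
        resp = await client.post(url, content=body, headers=dict(request.headers))
        resp.raise_for_status()
        return resp.json()


# Example: asyncio.run(call_bedrock_cohere("Hi"))
```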

View file

@@ -10,7 +10,6 @@ from litellm.caching import DualCache
from typing import Literal, Union
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
@@ -19,8 +18,6 @@ import traceback
import dotenv, os
import requests
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm, uuid

View file

@@ -1,8 +1,6 @@
#### What this does ####
# On success + failure, log events to aispend.io
import dotenv, os
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime

View file

@@ -3,7 +3,6 @@
import dotenv, os
import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime

View file

@@ -8,8 +8,6 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal, Union
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
@@ -18,8 +16,6 @@ import traceback
import dotenv, os
import requests
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm, uuid

View file

@@ -6,8 +6,6 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal, Union, Optional
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback

View file

@@ -3,8 +3,6 @@
import dotenv, os
import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm, uuid

View file

@@ -3,8 +3,6 @@
import dotenv, os
import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm, uuid

View file

@@ -3,8 +3,6 @@
import dotenv, os
import requests # type: ignore
import litellm
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback

View file

@@ -1,8 +1,6 @@
#### What this does ####
# On success, logs events to Langfuse
import dotenv, os
dotenv.load_dotenv() # Loading env variables using dotenv
import os
import copy
import traceback
from packaging.version import Version

View file

@@ -3,8 +3,6 @@
import dotenv, os # type: ignore
import requests # type: ignore
from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import asyncio
import types

View file

@@ -2,14 +2,11 @@
# On success + failure, log events to lunary.ai
from datetime import datetime, timezone
import traceback
import dotenv
import importlib
import sys
import packaging
dotenv.load_dotenv()
# convert to {completion: xx, tokens: xx}
def parse_usage(usage):
@@ -62,14 +59,16 @@ class LunaryLogger:
version = importlib.metadata.version("lunary")
# if version < 0.1.43 then raise ImportError
if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
print(
print( # noqa
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
)
raise ImportError
self.lunary_client = lunary
except ImportError:
print("Lunary not installed. Please install it using 'pip install lunary'")
print( # noqa
"Lunary not installed. Please install it using 'pip install lunary'"
) # noqa
raise ImportError
def log_event(

View file

@@ -3,8 +3,6 @@
import dotenv, os, json
import litellm
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler

View file

@@ -4,8 +4,6 @@
import dotenv, os
import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm, uuid

View file

@@ -5,8 +5,6 @@
import dotenv, os
import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm, uuid

View file

@@ -3,8 +3,6 @@
import dotenv, os
import requests # type: ignore
from pydantic import BaseModel
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback

View file

@@ -1,9 +1,7 @@
#### What this does ####
# On success + failure, log events to Supabase
import dotenv, os
dotenv.load_dotenv() # Loading env variables using dotenv
import os
import traceback
import datetime, subprocess, sys
import litellm, uuid

View file

@@ -2,8 +2,6 @@
# Class for sending Slack Alerts #
import dotenv, os
from litellm.proxy._types import UserAPIKeyAuth
dotenv.load_dotenv() # Loading env variables using dotenv
from litellm._logging import verbose_logger, verbose_proxy_logger
import litellm, threading
from typing import List, Literal, Any, Union, Optional, Dict

View file

@@ -3,8 +3,6 @@
import dotenv, os
import requests # type: ignore
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
import datetime, subprocess, sys
import litellm

View file

@@ -21,11 +21,11 @@ try:
# contains a (known) object attribute
object: Literal["chat.completion", "edit", "text_completion"]
def __getitem__(self, key: K) -> V:
... # pragma: no cover
def __getitem__(self, key: K) -> V: ... # noqa
def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
... # pragma: no cover
def get( # noqa
self, key: K, default: Optional[V] = None
) -> Optional[V]: ... # pragma: no cover
class OpenAIRequestResponseResolver:
def __call__(
@@ -173,12 +173,11 @@ except:
#### What this does ####
# On success, logs events to Langfuse
import dotenv, os
import os
import requests
import requests
from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback

View file

@@ -0,0 +1,124 @@
# What is this?
## Initial implementation of calling bedrock via httpx client (allows for async calls).
## V0 - just covers cohere command-r support
import os, types
import json
from enum import Enum
import requests, copy # type: ignore
import time
from typing import Callable, Optional, List, Literal, Union
from litellm.utils import (
ModelResponse,
Usage,
map_finish_reason,
CustomStreamWrapper,
Message,
Choices,
get_secret,
)
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx # type: ignore
from .bedrock import BedrockError
class BedrockLLM(BaseLLM):
"""
Example call
```
curl --location --request POST 'https://bedrock-runtime.{aws_region_name}.amazonaws.com/model/{bedrock_model_name}/invoke' \
--header 'Content-Type: application/json' \
--header 'Accept: application/json' \
--user "$AWS_ACCESS_KEY_ID":"$AWS_SECRET_ACCESS_KEY" \
--aws-sigv4 "aws:amz:us-east-1:bedrock" \
--data-raw '{
"prompt": "Hi",
"temperature": 0,
"p": 0.9,
"max_tokens": 4096
}'
```
"""
def __init__(self) -> None:
super().__init__()
def get_credentials(
self,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
):
"""
Return a boto3.Credentials object
"""
import boto3
## CHECK IF 'os.environ/' is passed in
params_to_check: List[Optional[str]] = [
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith("os.environ/"):
_v = get_secret(param)
if _v is not None and isinstance(_v, str):
params_to_check[i] = _v
# Assign updated values back to parameters
(
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
) = params_to_check
### CHECK STS ###
if aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
)
sts_response = sts_client.assume_role(
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
return sts_response["Credentials"]
elif aws_profile_name is not None: ### CHECK SESSION ###
# uses auth values from AWS profile usually stored in ~/.aws/credentials
client = boto3.Session(profile_name=aws_profile_name)
return client.get_credentials()
else:
session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
return session.get_credentials()
def completion(self, *args, **kwargs) -> Union[ModelResponse, CustomStreamWrapper]:
## get credentials
## generate signature
## make request
return super().completion(*args, **kwargs)
def embedding(self, *args, **kwargs):
return super().embedding(*args, **kwargs)
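
A hedged usage sketch of the new `get_credentials` helper (env var names and values below are placeholders): parameters prefixed with `os.environ/` are resolved through `get_secret` before credentials are built.

```python
# Hypothetical usage of BedrockLLM.get_credentials (placeholder values).
import os

from litellm.llms.bedrock_httpx import BedrockLLM

os.environ["MY_AWS_KEY"] = "AKIA..."   # placeholders, not real credentials
os.environ["MY_AWS_SECRET"] = "..."

llm = BedrockLLM()
# "os.environ/<VAR>" values are resolved through get_secret before use;
# with no role/profile args this falls through to a plain boto3.Session.
creds = llm.get_credentials(
    aws_access_key_id="os.environ/MY_AWS_KEY",
    aws_secret_access_key="os.environ/MY_AWS_SECRET",
    aws_region_name="us-east-1",
)
print(creds.access_key)
```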

View file

@@ -75,6 +75,7 @@ from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
prompt_factory,
@@ -104,7 +105,6 @@ from litellm.utils import (
)
####### ENVIRONMENT VARIABLES ###################
dotenv.load_dotenv() # Loading env variables using dotenv
openai_chat_completions = OpenAIChatCompletion()
openai_text_completions = OpenAITextCompletion()
anthropic_chat_completions = AnthropicChatCompletion()
@@ -114,6 +114,7 @@ azure_text_completions = AzureTextCompletion()
huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
####### COMPLETION ENDPOINTS ################
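
With `bedrock_chat_completion` registered here, the httpx transport opens the door to non-blocking calls. A hedged sketch (in this commit `completion()` still defers to the base class, so this assumes the async path gets wired up):

```python
# Hedged sketch: async bedrock/cohere call via litellm.acompletion, assuming
# the httpx-based client is wired into the completion route.
import asyncio

import litellm


async def main():
    response = await litellm.acompletion(
        model="bedrock/cohere.command-r-plus-v1:0",
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
    )
    print(response)


asyncio.run(main())
```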

View file

@@ -1,10 +1,7 @@
from litellm.proxy._types import UserAPIKeyAuth, GenerateKeyRequest
from fastapi import Request
from dotenv import load_dotenv
import os
load_dotenv()
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
try:

View file

@@ -8,8 +8,6 @@
import dotenv, os, requests, random # type: ignore
from typing import Optional
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger

View file

@@ -1,12 +1,11 @@
#### What this does ####
# picks based on response time (for streaming, this is time to first token)
from pydantic import BaseModel, Extra, Field, root_validator
import dotenv, os, requests, random # type: ignore
import os, requests, random # type: ignore
from typing import Optional, Union, List, Dict
from datetime import datetime, timedelta
import random
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger

View file

@@ -5,8 +5,6 @@ import dotenv, os, requests, random # type: ignore
from typing import Optional, Union, List, Dict
from datetime import datetime, timedelta
import random
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger

View file

@@ -4,8 +4,6 @@
import dotenv, os, requests, random
from typing import Optional, Union, List, Dict
from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback
from litellm import token_counter
from litellm.caching import DualCache

View file

@@ -5,8 +5,6 @@ import dotenv, os, requests, random
from typing import Optional, Union, List, Dict
import datetime as datetime_og
from datetime import datetime
dotenv.load_dotenv() # Loading env variables using dotenv
import traceback, asyncio, httpx
import litellm
from litellm import token_counter

View file

@@ -2584,6 +2584,15 @@ def test_completion_chat_sagemaker_mistral():
# test_completion_chat_sagemaker_mistral()
def test_completion_bedrock_command_r():
response = completion(
model="bedrock/cohere.command-r-plus-v1:0",
messages=[{"role": "user", "content": "Hey! how's it going?"}],
)
print(f"response: {response}")
def test_completion_bedrock_titan_null_response():
try:
response = completion(

View file

@@ -117,7 +117,6 @@ MAX_THREADS = 100
# Create a ThreadPoolExecutor
executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
dotenv.load_dotenv() # Loading env variables using dotenv
sentry_sdk_instance = None
capture_exception = None
add_breadcrumb = None