forked from phoenix/litellm-mirror
feat(bedrock_httpx.py): moves to using httpx client for bedrock cohere calls
This commit is contained in:
parent
c12af219af
commit
4a3b084961
29 changed files with 147 additions and 64 deletions
|
@ -10,7 +10,6 @@ from litellm.caching import DualCache
|
|||
|
||||
from typing import Literal, Union
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
@ -19,8 +18,6 @@ import traceback
|
|||
|
||||
import dotenv, os
|
||||
import requests
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
#### What this does ####
|
||||
# On success + failure, log events to aispend.io
|
||||
import dotenv, os
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime
|
||||
|
||||
|
|
|
@ -8,8 +8,6 @@ from litellm.proxy._types import UserAPIKeyAuth
|
|||
from litellm.caching import DualCache
|
||||
|
||||
from typing import Literal, Union
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
@ -18,8 +16,6 @@ import traceback
|
|||
|
||||
import dotenv, os
|
||||
import requests
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -6,8 +6,6 @@ from litellm.proxy._types import UserAPIKeyAuth
|
|||
from litellm.caching import DualCache
|
||||
|
||||
from typing import Literal, Union, Optional
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
import litellm
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
#### What this does ####
|
||||
# On success, logs events to Langfuse
|
||||
import dotenv, os
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import os
|
||||
import copy
|
||||
import traceback
|
||||
from packaging.version import Version
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
import dotenv, os # type: ignore
|
||||
import requests # type: ignore
|
||||
from datetime import datetime
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import asyncio
|
||||
import types
|
||||
|
|
|
@ -2,14 +2,11 @@
|
|||
# On success + failure, log events to lunary.ai
|
||||
from datetime import datetime, timezone
|
||||
import traceback
|
||||
import dotenv
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
import packaging
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
# convert to {completion: xx, tokens: xx}
|
||||
def parse_usage(usage):
|
||||
|
@ -62,14 +59,16 @@ class LunaryLogger:
|
|||
version = importlib.metadata.version("lunary")
|
||||
# if version < 0.1.43 then raise ImportError
|
||||
if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
|
||||
print(
|
||||
print( # noqa
|
||||
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
|
||||
)
|
||||
raise ImportError
|
||||
|
||||
self.lunary_client = lunary
|
||||
except ImportError:
|
||||
print("Lunary not installed. Please install it using 'pip install lunary'")
|
||||
print( # noqa
|
||||
"Lunary not installed. Please install it using 'pip install lunary'"
|
||||
) # noqa
|
||||
raise ImportError
|
||||
|
||||
def log_event(
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
|
||||
import dotenv, os, json
|
||||
import litellm
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
|
|
|
@ -4,8 +4,6 @@
|
|||
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -5,8 +5,6 @@
|
|||
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
from pydantic import BaseModel
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
#### What this does ####
|
||||
# On success + failure, log events to Supabase
|
||||
|
||||
import dotenv, os
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import os
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
|
|
|
@ -2,8 +2,6 @@
|
|||
# Class for sending Slack Alerts #
|
||||
import dotenv, os
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
import litellm, threading
|
||||
from typing import List, Literal, Any, Union, Optional, Dict
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm
|
||||
|
|
|
@ -21,11 +21,11 @@ try:
|
|||
# contains a (known) object attribute
|
||||
object: Literal["chat.completion", "edit", "text_completion"]
|
||||
|
||||
def __getitem__(self, key: K) -> V:
|
||||
... # pragma: no cover
|
||||
def __getitem__(self, key: K) -> V: ... # noqa
|
||||
|
||||
def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
|
||||
... # pragma: no cover
|
||||
def get( # noqa
|
||||
self, key: K, default: Optional[V] = None
|
||||
) -> Optional[V]: ... # pragma: no cover
|
||||
|
||||
class OpenAIRequestResponseResolver:
|
||||
def __call__(
|
||||
|
@ -173,12 +173,11 @@ except:
|
|||
|
||||
#### What this does ####
|
||||
# On success, logs events to Langfuse
|
||||
import dotenv, os
|
||||
import os
|
||||
import requests
|
||||
import requests
|
||||
from datetime import datetime
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
||||
|
|
124
litellm/llms/bedrock_httpx.py
Normal file
124
litellm/llms/bedrock_httpx.py
Normal file
|
@ -0,0 +1,124 @@
|
|||
# What is this?
|
||||
## Initial implementation of calling bedrock via httpx client (allows for async calls).
|
||||
## V0 - just covers cohere command-r support
|
||||
|
||||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests, copy # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional, List, Literal, Union
|
||||
from litellm.utils import (
|
||||
ModelResponse,
|
||||
Usage,
|
||||
map_finish_reason,
|
||||
CustomStreamWrapper,
|
||||
Message,
|
||||
Choices,
|
||||
get_secret,
|
||||
)
|
||||
import litellm
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from .base import BaseLLM
|
||||
import httpx # type: ignore
|
||||
from .bedrock import BedrockError
|
||||
|
||||
|
||||
class BedrockLLM(BaseLLM):
    """
    Handler for calling AWS Bedrock over plain HTTP.

    In this initial version only credential resolution is implemented;
    `completion` and `embedding` still delegate to `BaseLLM`.

    Example call

    ```
    curl --location --request POST 'https://bedrock-runtime.{aws_region_name}.amazonaws.com/model/{bedrock_model_name}/invoke' \
    --header 'Content-Type: application/json' \
    --header 'Accept: application/json' \
    --user "$AWS_ACCESS_KEY_ID":"$AWS_SECRET_ACCESS_KEY" \
    --aws-sigv4 "aws:amz:us-east-1:bedrock" \
    --data-raw '{
    "prompt": "Hi",
    "temperature": 0,
    "p": 0.9,
    "max_tokens": 4096
    }'
    ```
    """

    def __init__(self) -> None:
        super().__init__()

    def get_credentials(
        self,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        aws_session_name: Optional[str] = None,
        aws_profile_name: Optional[str] = None,
        aws_role_name: Optional[str] = None,
    ):
        """
        Return a credentials object, as produced by `boto3.Session.get_credentials()`.

        Any argument passed as an "os.environ/..." reference is first resolved
        through `get_secret`; the reference is kept as-is when `get_secret`
        does not return a string.

        Resolution order:
        1. STS assume-role, when both `aws_role_name` and `aws_session_name` are set.
        2. Named profile, when `aws_profile_name` is set.
        3. A session built from the access key / secret key / region.

        Every branch returns the same kind of object, so callers do not need
        to special-case the STS path.
        """
        import boto3

        ## CHECK IF 'os.environ/' passed in
        resolved: List[Optional[str]] = []
        for param in (
            aws_access_key_id,
            aws_secret_access_key,
            aws_region_name,
            aws_session_name,
            aws_profile_name,
            aws_role_name,
        ):
            if param and param.startswith("os.environ/"):
                _v = get_secret(param)
                if isinstance(_v, str):
                    param = _v
            resolved.append(param)

        # Assign updated values back to parameters
        (
            aws_access_key_id,
            aws_secret_access_key,
            aws_region_name,
            aws_session_name,
            aws_profile_name,
            aws_role_name,
        ) = resolved

        ### CHECK STS ###
        if aws_role_name is not None and aws_session_name is not None:
            sts_client = boto3.client(
                "sts",
                aws_access_key_id=aws_access_key_id,  # [OPTIONAL]
                aws_secret_access_key=aws_secret_access_key,  # [OPTIONAL]
            )

            sts_response = sts_client.assume_role(
                RoleArn=aws_role_name, RoleSessionName=aws_session_name
            )

            # assume_role returns a plain dict; wrap the temporary keys in a
            # session so this branch yields the same credentials object as the
            # profile / default branches instead of a raw dict.
            temporary = sts_response["Credentials"]
            session = boto3.Session(
                aws_access_key_id=temporary["AccessKeyId"],
                aws_secret_access_key=temporary["SecretAccessKey"],
                aws_session_token=temporary["SessionToken"],
                region_name=aws_region_name,
            )

            return session.get_credentials()

        if aws_profile_name is not None:  ### CHECK SESSION ###
            # uses auth values from AWS profile usually stored in ~/.aws/credentials
            session = boto3.Session(profile_name=aws_profile_name)

            return session.get_credentials()

        session = boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=aws_region_name,
        )

        return session.get_credentials()

    def completion(self, *args, **kwargs) -> Union[ModelResponse, CustomStreamWrapper]:
        """
        Run a completion; currently a pass-through to `BaseLLM.completion`.

        Planned steps (not implemented yet):
        - get credentials
        - generate signature
        - make request
        """
        return super().completion(*args, **kwargs)

    def embedding(self, *args, **kwargs):
        """Run an embedding call; currently a pass-through to `BaseLLM.embedding`."""
        return super().embedding(*args, **kwargs)
|
|
@ -75,6 +75,7 @@ from .llms.anthropic import AnthropicChatCompletion
|
|||
from .llms.anthropic_text import AnthropicTextCompletion
|
||||
from .llms.huggingface_restapi import Huggingface
|
||||
from .llms.predibase import PredibaseChatCompletion
|
||||
from .llms.bedrock_httpx import BedrockLLM
|
||||
from .llms.triton import TritonChatCompletion
|
||||
from .llms.prompt_templates.factory import (
|
||||
prompt_factory,
|
||||
|
@ -104,7 +105,6 @@ from litellm.utils import (
|
|||
)
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
openai_chat_completions = OpenAIChatCompletion()
|
||||
openai_text_completions = OpenAITextCompletion()
|
||||
anthropic_chat_completions = AnthropicChatCompletion()
|
||||
|
@ -114,6 +114,7 @@ azure_text_completions = AzureTextCompletion()
|
|||
huggingface = Huggingface()
|
||||
predibase_chat_completions = PredibaseChatCompletion()
|
||||
triton_chat_completions = TritonChatCompletion()
|
||||
bedrock_chat_completion = BedrockLLM()
|
||||
####### COMPLETION ENDPOINTS ################
|
||||
|
||||
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
from litellm.proxy._types import UserAPIKeyAuth, GenerateKeyRequest
|
||||
from fastapi import Request
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
|
||||
try:
|
||||
|
|
|
@ -8,8 +8,6 @@
|
|||
|
||||
import dotenv, os, requests, random # type: ignore
|
||||
from typing import Optional
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
#### What this does ####
|
||||
# picks based on response time (for streaming, this is time to first token)
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
import dotenv, os, requests, random # type: ignore
|
||||
import os, requests, random # type: ignore
|
||||
from typing import Optional, Union, List, Dict
|
||||
from datetime import datetime, timedelta
|
||||
import random
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
|
|
@ -5,8 +5,6 @@ import dotenv, os, requests, random # type: ignore
|
|||
from typing import Optional, Union, List, Dict
|
||||
from datetime import datetime, timedelta
|
||||
import random
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
|
|
@ -4,8 +4,6 @@
|
|||
import dotenv, os, requests, random
|
||||
from typing import Optional, Union, List, Dict
|
||||
from datetime import datetime
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
from litellm import token_counter
|
||||
from litellm.caching import DualCache
|
||||
|
|
|
@ -5,8 +5,6 @@ import dotenv, os, requests, random
|
|||
from typing import Optional, Union, List, Dict
|
||||
import datetime as datetime_og
|
||||
from datetime import datetime
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback, asyncio, httpx
|
||||
import litellm
|
||||
from litellm import token_counter
|
||||
|
|
|
@ -2584,6 +2584,15 @@ def test_completion_chat_sagemaker_mistral():
|
|||
# test_completion_chat_sagemaker_mistral()
|
||||
|
||||
|
||||
def test_completion_bedrock_command_r():
    """Call the Bedrock Cohere Command R+ model once and print whatever comes back."""
    chat_messages = [{"role": "user", "content": "Hey! how's it going?"}]
    response = completion(
        model="bedrock/cohere.command-r-plus-v1:0",
        messages=chat_messages,
    )

    print(f"response: {response}")
|
||||
|
||||
|
||||
def test_completion_bedrock_titan_null_response():
|
||||
try:
|
||||
response = completion(
|
||||
|
|
|
@ -117,7 +117,6 @@ MAX_THREADS = 100
|
|||
|
||||
# Create a ThreadPoolExecutor
|
||||
executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
sentry_sdk_instance = None
|
||||
capture_exception = None
|
||||
add_breadcrumb = None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue