style(test_completion.py): fix merge conflict
parent 396d9d8e38
commit dd7e397650

22 changed files with 1535 additions and 250 deletions
--- a/litellm/llms/sagemaker.py
+++ b/litellm/llms/sagemaker.py
@@ -1,9 +1,10 @@
-import os
-import json
+import os, types
+from enum import Enum
+import json
 import requests
 import time
-from typing import Callable
+from typing import Callable, Optional
 import litellm
 from litellm.utils import ModelResponse, get_secret
 import sys
 from copy import deepcopy
@@ -16,6 +17,32 @@ class SagemakerError(Exception):
             self.message
         ) # Call the base class constructor with the parameters it needs

+class SagemakerConfig():
+    """
+    Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
+    """
+    max_new_tokens: Optional[int]=None
+    top_p: Optional[float]=None
+    temperature: Optional[float]=None
+    return_full_text: Optional[bool]=None
+
+    def __init__(self,
+                 max_new_tokens: Optional[int]=None,
+                 top_p: Optional[float]=None,
+                 temperature: Optional[float]=None,
+                 return_full_text: Optional[bool]=None) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != 'self' and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {k: v for k, v in cls.__dict__.items()
+                if not k.startswith('__')
+                and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
+                and v is not None}
+
 """
 SAGEMAKER AUTH Keys/Vars
 os.environ['AWS_ACCESS_KEY_ID'] = ""
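The new SagemakerConfig follows litellm's provider-config pattern: instantiating the config stores non-None arguments as class attributes, and get_config() returns only the attributes that were explicitly set, filtering out dunders, callables, and unset (None) values. A minimal sketch of the intended behavior; the parameter values below are illustrative:

    import litellm

    # Instantiating the config stores non-None arguments on the class itself.
    litellm.SagemakerConfig(max_new_tokens=256, temperature=0.7)

    # get_config() walks cls.__dict__ and keeps only the set, non-callable entries.
    print(litellm.SagemakerConfig.get_config())
    # {'max_new_tokens': 256, 'temperature': 0.7}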
@@ -47,6 +74,16 @@ def completion(
         region_name=region_name
     )

+    # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
+    inference_params = deepcopy(optional_params)
+    inference_params.pop("stream", None)
+
+    ## Load Config
+    config = litellm.SagemakerConfig.get_config()
+    for k, v in config.items():
+        if k not in inference_params: # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
+            inference_params[k] = v
+
     model = model
     prompt = ""
     for message in messages:
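Note the precedence rule in the loop above: a config value is copied in only when the key is absent from inference_params, so arguments passed directly to completion() always override SagemakerConfig defaults. A standalone sketch of the same merge with plain dicts; the values are illustrative:

    # Defaults as returned by SagemakerConfig.get_config()
    config = {"max_new_tokens": 256, "temperature": 0.2}
    # Per-call parameters passed to completion()
    inference_params = {"temperature": 0.9}

    for k, v in config.items():
        if k not in inference_params:  # dynamic per-call args win
            inference_params[k] = v

    print(inference_params)  # {'temperature': 0.9, 'max_new_tokens': 256}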
@@ -61,9 +98,7 @@ def completion(
                 )
         else:
             prompt += f"{message['content']}"
-    # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
-    inference_params = deepcopy(optional_params)
-    inference_params.pop("stream", None)
+
     data = {
         "inputs": prompt,
         "parameters": inference_params
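This second hunk removes the stream-popping block from its old position after prompt construction, since the first hunk moved it ahead of the config merge. Either way, 'stream' must be stripped before the request: per the code comment, the parameters dict is sent to the endpoint as-is and SageMaker errors on the unknown 'stream' key. A hedged sketch of what the data dict above becomes; the invoke call is an assumption about the surrounding code, not shown in this diff:

    import json

    # data = {"inputs": prompt, "parameters": inference_params}
    body = json.dumps(data)

    # Hypothetical call site: boto3's sagemaker-runtime client.
    # response = client.invoke_endpoint(
    #     EndpointName=model,
    #     ContentType="application/json",
    #     Body=body,
    # )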