style(test_completion.py): fix merge conflict

This commit is contained in:
Krrish Dholakia 2023-10-05 22:09:38 -07:00
parent 396d9d8e38
commit dd7e397650
22 changed files with 1535 additions and 250 deletions

View file

@ -1,9 +1,10 @@
import os
import json
import os, types
from enum import Enum
import json
import requests
import time
from typing import Callable
from typing import Callable, Optional
import litellm
from litellm.utils import ModelResponse, get_secret
import sys
from copy import deepcopy
@ -16,6 +17,32 @@ class SagemakerError(Exception):
self.message
) # Call the base class constructor with the parameters it needs
class SagemakerConfig():
"""
Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
"""
max_new_tokens: Optional[int]=None
top_p: Optional[float]=None
temperature: Optional[float]=None
return_full_text: Optional[bool]=None
def __init__(self,
max_new_tokens: Optional[int]=None,
top_p: Optional[float]=None,
temperature: Optional[float]=None,
return_full_text: Optional[bool]=None) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
"""
SAGEMAKER AUTH Keys/Vars
os.environ['AWS_ACCESS_KEY_ID'] = ""
@ -47,6 +74,16 @@ def completion(
region_name=region_name
)
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
inference_params = deepcopy(optional_params)
inference_params.pop("stream", None)
## Load Config
config = litellm.SagemakerConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
model = model
prompt = ""
for message in messages:
@ -61,9 +98,7 @@ def completion(
)
else:
prompt += f"{message['content']}"
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
inference_params = deepcopy(optional_params)
inference_params.pop("stream", None)
data = {
"inputs": prompt,
"parameters": inference_params