formatting improvements

This commit is contained in:
ishaan-jaff 2023-08-28 09:20:50 -07:00
parent 3e0a16acf4
commit a69b7ffcfa
17 changed files with 464 additions and 323 deletions

View file

@ -103,7 +103,9 @@ def completion(
return completion_with_fallbacks(**args)
if litellm.model_alias_map and model in litellm.model_alias_map:
args["model_alias_map"] = litellm.model_alias_map
model = litellm.model_alias_map[model] # update the model to the actual value if an alias has been passed in
model = litellm.model_alias_map[
model
] # update the model to the actual value if an alias has been passed in
model_response = ModelResponse()
if azure: # this flag is deprecated, remove once notebooks are also updated.
custom_llm_provider = "azure"
@ -146,7 +148,7 @@ def completion(
custom_llm_provider=custom_llm_provider,
custom_api_base=custom_api_base,
litellm_call_id=litellm_call_id,
model_alias_map=litellm.model_alias_map
model_alias_map=litellm.model_alias_map,
)
logging = Logging(
model=model,
@ -216,7 +218,10 @@ def completion(
# note: if a user sets a custom base - we should ensure this works
# allow for the setting of dynamic and stateful api-bases
api_base = (
custom_api_base or litellm.api_base or get_secret("OPENAI_API_BASE") or "https://api.openai.com/v1"
custom_api_base
or litellm.api_base
or get_secret("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
openai.api_base = api_base
openai.api_version = None
@ -255,9 +260,11 @@ def completion(
original_response=response,
additional_args={"headers": litellm.headers},
)
elif (model in litellm.open_ai_text_completion_models or
"ft:babbage-002" in model or # support for finetuned completion models
"ft:davinci-002" in model):
elif (
model in litellm.open_ai_text_completion_models
or "ft:babbage-002" in model
or "ft:davinci-002" in model # support for finetuned completion models
):
openai.api_type = "openai"
openai.api_base = (
litellm.api_base
@ -544,7 +551,10 @@ def completion(
logging.pre_call(input=prompt, api_key=TOGETHER_AI_TOKEN)
print(f"TOGETHER_AI_TOKEN: {TOGETHER_AI_TOKEN}")
if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
if (
"stream_tokens" in optional_params
and optional_params["stream_tokens"] == True
):
res = requests.post(
endpoint,
json={
@ -698,9 +708,7 @@ def completion(
):
custom_llm_provider = "baseten"
baseten_key = (
api_key
or litellm.baseten_key
or os.environ.get("BASETEN_API_KEY")
api_key or litellm.baseten_key or os.environ.get("BASETEN_API_KEY")
)
baseten_client = BasetenLLM(
encoding=encoding, api_key=baseten_key, logging_obj=logging
@ -767,11 +775,14 @@ def completion(
model=model, custom_llm_provider=custom_llm_provider, original_exception=e
)
def completion_with_retries(*args, **kwargs):
import tenacity
retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(3), reraise=True)
return retryer(completion, *args, **kwargs)
def batch_completion(*args, **kwargs):
batch_messages = args[1] if len(args) > 1 else kwargs.get("messages")
completions = []
@ -865,14 +876,16 @@ def embedding(
custom_llm_provider="azure" if azure == True else None,
)
###### Text Completion ################
def text_completion(*args, **kwargs):
if 'prompt' in kwargs:
messages = [{'role': 'system', 'content': kwargs['prompt']}]
kwargs['messages'] = messages
kwargs.pop('prompt')
if "prompt" in kwargs:
messages = [{"role": "system", "content": kwargs["prompt"]}]
kwargs["messages"] = messages
kwargs.pop("prompt")
return completion(*args, **kwargs)
####### HELPER FUNCTIONS ################
## Set verbose to true -> ```litellm.set_verbose = True```
def print_verbose(print_statement):