Merge branch 'main' into litellm_fix_httpx_transport

Krish Dholakia 2024-07-06 19:12:06 -07:00 committed by GitHub
commit c6b6dbeb6b
142 changed files with 6725 additions and 2086 deletions


@@ -113,6 +113,7 @@ from .llms.prompt_templates.factory import (
function_call_prompt,
map_system_message_pt,
prompt_factory,
stringify_json_tool_call_content,
)
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.triton import TritonChatCompletion
@@ -984,6 +985,7 @@ def completion(
mock_delay=kwargs.get("mock_delay", None),
custom_llm_provider=custom_llm_provider,
)
if custom_llm_provider == "azure":
# azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure"
@@ -1114,6 +1116,73 @@ def completion(
"api_base": api_base,
},
)
elif custom_llm_provider == "azure_ai":
api_base = (
api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
or litellm.api_base
or get_secret("AZURE_AI_API_BASE")
)
# set API KEY
api_key = (
api_key
or litellm.api_key # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or get_secret("AZURE_AI_API_KEY")
)
headers = headers or litellm.headers
## LOAD CONFIG - if set
config = litellm.OpenAIConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
## FOR COHERE
if "command-r" in model: # make sure tool call in messages are str
messages = stringify_json_tool_call_content(messages=messages)
## COMPLETION CALL
try:
response = openai_chat_completions.completion(
model=model,
messages=messages,
headers=headers,
model_response=model_response,
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout, # type: ignore
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
organization=organization,
custom_llm_provider=custom_llm_provider,
)
except Exception as e:
## LOGGING - log the original exception returned
logging.post_call(
input=messages,
api_key=api_key,
original_response=str(e),
additional_args={"headers": headers},
)
raise e
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
additional_args={"headers": headers},
)
elif (
custom_llm_provider == "text-completion-openai"
or "ft:babbage-002" in model
@@ -2008,6 +2077,8 @@ def completion(
vertex_credentials=vertex_credentials,
logging_obj=logging,
acompletion=acompletion,
headers=headers,
custom_prompt_dict=custom_prompt_dict,
)
else:
model_response = vertex_ai.completion(
@@ -2026,18 +2097,18 @@ def completion(
acompletion=acompletion,
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and acompletion == False
):
response = CustomStreamWrapper(
model_response,
model,
custom_llm_provider="vertex_ai",
logging_obj=logging,
)
return response
if (
"stream" in optional_params
and optional_params["stream"] == True
and acompletion == False
):
response = CustomStreamWrapper(
model_response,
model,
custom_llm_provider="vertex_ai",
logging_obj=logging,
)
return response
response = model_response
elif custom_llm_provider == "predibase":
tenant_id = (
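
The re-indented vertex_ai block above keeps the gating pattern litellm uses for several providers: when stream=True is set on a synchronous call, the raw provider response is wrapped in CustomStreamWrapper and returned immediately, while async calls handle streaming on their own path. A rough, simplified sketch of that control flow, with a stand-in wrap_stream instead of the real wrapper:

from typing import Any, Iterator

def wrap_stream(raw_response: Iterator[Any], model: str, provider: str) -> Iterator[Any]:
    # Stand-in for CustomStreamWrapper: normalize raw provider chunks
    # into a common shape before yielding them.
    for chunk in raw_response:
        yield {"model": model, "provider": provider, "chunk": chunk}

def finish_completion(model_response, optional_params: dict, acompletion: bool, model: str):
    # Same gating idea as the diff: wrap and early-return only for
    # synchronous streaming; async streaming is handled elsewhere.
    if optional_params.get("stream", False) and not acompletion:
        return wrap_stream(model_response, model=model, provider="vertex_ai")
    return model_response
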
@@ -4297,6 +4368,8 @@ def transcription(
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
if dynamic_api_key is not None:
api_key = dynamic_api_key
optional_params = {
"language": language,
"prompt": prompt,
@@ -4338,7 +4411,7 @@ def transcription(
azure_ad_token=azure_ad_token,
max_retries=max_retries,
)
elif custom_llm_provider == "openai":
elif custom_llm_provider == "openai" or custom_llm_provider == "groq":
api_base = (
api_base
or litellm.api_base
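
With the condition change above, Groq transcription is routed through the same OpenAI-compatible client path as OpenAI itself. Assuming Groq's hosted Whisper model is exposed under a groq/-prefixed model name (the model identifier, env var, and response shape shown here are assumptions; check the litellm and Groq docs), usage would look roughly like:

import os
import litellm

os.environ["GROQ_API_KEY"] = "gsk_..."  # assumed env var name

with open("meeting.mp3", "rb") as audio_file:
    transcript = litellm.transcription(
        model="groq/whisper-large-v3",  # assumed model identifier
        file=audio_file,
    )
print(transcript.text)  # response shape assumed to mirror OpenAI's transcription response
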
@@ -4944,14 +5017,22 @@ def stream_chunk_builder(
else:
completion_output = ""
# # Update usage information if needed
prompt_tokens = 0
completion_tokens = 0
for chunk in chunks:
if "usage" in chunk:
if "prompt_tokens" in chunk["usage"]:
prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0
if "completion_tokens" in chunk["usage"]:
completion_tokens = chunk["usage"].get("completion_tokens", 0) or 0
try:
response["usage"]["prompt_tokens"] = token_counter(
response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
model=model, messages=messages
)
except: # don't allow this failing to block a complete streaming response from being returned
print_verbose(f"token_counter failed, assuming prompt tokens is 0")
response["usage"]["prompt_tokens"] = 0
response["usage"]["completion_tokens"] = token_counter(
response["usage"]["completion_tokens"] = completion_tokens or token_counter(
model=model,
text=completion_output,
count_response_tokens=True, # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages
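
The last hunk changes stream_chunk_builder() so that usage reported by the provider inside the streamed chunks is preferred, with local token counting kept only as the fallback. A condensed sketch of that aggregation logic, using a stand-in count_tokens() in place of litellm's token_counter():

def build_usage(chunks: list[dict], prompt_text: str, completion_text: str) -> dict:
    # Prefer usage reported by the provider in the streamed chunks
    # (typically only the final chunk carries it).
    prompt_tokens = 0
    completion_tokens = 0
    for chunk in chunks:
        if "usage" in chunk:
            prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0
            completion_tokens = chunk["usage"].get("completion_tokens", 0) or 0

    def count_tokens(text: str) -> int:
        # Stand-in for token_counter(): a crude whitespace-based count.
        return len(text.split())

    # Fall back to counting locally only when the provider reported nothing.
    return {
        "prompt_tokens": prompt_tokens or count_tokens(prompt_text),
        "completion_tokens": completion_tokens or count_tokens(completion_text),
    }
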