Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26)

Merge branch 'main' into litellm_fix_httpx_transport

Commit c6b6dbeb6b: 142 changed files with 6725 additions and 2086 deletions

litellm/main.py (111 changes)
@@ -113,6 +113,7 @@ from .llms.prompt_templates.factory import (
     function_call_prompt,
     map_system_message_pt,
     prompt_factory,
+    stringify_json_tool_call_content,
 )
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.triton import TritonChatCompletion
@@ -984,6 +985,7 @@ def completion(
                 mock_delay=kwargs.get("mock_delay", None),
                 custom_llm_provider=custom_llm_provider,
             )

         if custom_llm_provider == "azure":
             # azure configs
             api_type = get_secret("AZURE_API_TYPE") or "azure"
@@ -1114,6 +1116,73 @@ def completion(
                     "api_base": api_base,
                 },
             )
+        elif custom_llm_provider == "azure_ai":
+            api_base = (
+                api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
+                or litellm.api_base
+                or get_secret("AZURE_AI_API_BASE")
+            )
+            # set API KEY
+            api_key = (
+                api_key
+                or litellm.api_key  # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
+                or litellm.openai_key
+                or get_secret("AZURE_AI_API_KEY")
+            )
+
+            headers = headers or litellm.headers
+
+            ## LOAD CONFIG - if set
+            config = litellm.OpenAIConfig.get_config()
+            for k, v in config.items():
+                if (
+                    k not in optional_params
+                ):  # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
+                    optional_params[k] = v
+
+            ## FOR COHERE
+            if "command-r" in model:  # make sure tool call in messages are str
+                messages = stringify_json_tool_call_content(messages=messages)
+
+            ## COMPLETION CALL
+            try:
+                response = openai_chat_completions.completion(
+                    model=model,
+                    messages=messages,
+                    headers=headers,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    api_key=api_key,
+                    api_base=api_base,
+                    acompletion=acompletion,
+                    logging_obj=logging,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    timeout=timeout,  # type: ignore
+                    custom_prompt_dict=custom_prompt_dict,
+                    client=client,  # pass AsyncOpenAI, OpenAI client
+                    organization=organization,
+                    custom_llm_provider=custom_llm_provider,
+                )
+            except Exception as e:
+                ## LOGGING - log the original exception returned
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=str(e),
+                    additional_args={"headers": headers},
+                )
+                raise e
+
+            if optional_params.get("stream", False):
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=response,
+                    additional_args={"headers": headers},
+                )
         elif (
             custom_llm_provider == "text-completion-openai"
             or "ft:babbage-002" in model
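A minimal usage sketch for the new azure_ai branch above, assuming the usual provider-prefix routing (model="azure_ai/<deployment>") and the AZURE_AI_* environment variables that get_secret() reads; the endpoint, key, and deployment name below are placeholders, not values from this diff:

import os
import litellm

# Assumed env vars, matching the get_secret() lookups in the new branch
os.environ["AZURE_AI_API_KEY"] = "my-azure-ai-key"  # placeholder
os.environ["AZURE_AI_API_BASE"] = "https://my-endpoint.eastus2.inference.ai.azure.com/v1"  # placeholder

# "azure_ai/<deployment>" should route through custom_llm_provider == "azure_ai",
# which reuses openai_chat_completions.completion() as shown in the hunk above.
response = litellm.completion(
    model="azure_ai/command-r-plus",  # hypothetical deployment name
    messages=[{"role": "user", "content": "Hello from the azure_ai route"}],
)
print(response.choices[0].message.content)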
@@ -2008,6 +2077,8 @@ def completion(
                     vertex_credentials=vertex_credentials,
                     logging_obj=logging,
                     acompletion=acompletion,
+                    headers=headers,
+                    custom_prompt_dict=custom_prompt_dict,
                 )
             else:
                 model_response = vertex_ai.completion(
@@ -2026,18 +2097,18 @@ def completion(
                     acompletion=acompletion,
                 )

-                if (
-                    "stream" in optional_params
-                    and optional_params["stream"] == True
-                    and acompletion == False
-                ):
-                    response = CustomStreamWrapper(
-                        model_response,
-                        model,
-                        custom_llm_provider="vertex_ai",
-                        logging_obj=logging,
-                    )
-                    return response
+            if (
+                "stream" in optional_params
+                and optional_params["stream"] == True
+                and acompletion == False
+            ):
+                response = CustomStreamWrapper(
+                    model_response,
+                    model,
+                    custom_llm_provider="vertex_ai",
+                    logging_obj=logging,
+                )
+                return response
             response = model_response
         elif custom_llm_provider == "predibase":
             tenant_id = (
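For reference, a hedged sketch of the synchronous streaming path this hunk adjusts: with stream=True and acompletion=False, the vertex_ai response is wrapped in CustomStreamWrapper and the caller simply iterates it. The model name below is a placeholder:

import litellm

response = litellm.completion(
    model="vertex_ai/gemini-1.5-pro",  # placeholder Vertex AI model name
    messages=[{"role": "user", "content": "Stream a short haiku"}],
    stream=True,  # triggers the CustomStreamWrapper branch shown above
)
for chunk in response:
    # streamed chunks follow the OpenAI delta format; content can be None on the final chunk
    print(chunk.choices[0].delta.content or "", end="")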
@@ -4297,6 +4368,8 @@ def transcription(

     model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore

+    if dynamic_api_key is not None:
+        api_key = dynamic_api_key
     optional_params = {
         "language": language,
         "prompt": prompt,
@@ -4338,7 +4411,7 @@ def transcription(
             azure_ad_token=azure_ad_token,
             max_retries=max_retries,
         )
-    elif custom_llm_provider == "openai":
+    elif custom_llm_provider == "openai" or custom_llm_provider == "groq":
         api_base = (
             api_base
             or litellm.api_base
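Since this hunk lets groq share the OpenAI-compatible transcription path (with the key resolved through get_llm_provider's dynamic_api_key a few hunks up), a hedged usage sketch; the file name and model identifier below are placeholders:

import os
import litellm

os.environ["GROQ_API_KEY"] = "my-groq-key"  # placeholder

with open("sample.wav", "rb") as audio_file:  # hypothetical local audio file
    transcript = litellm.transcription(
        model="groq/whisper-large-v3",  # assumed groq Whisper model name
        file=audio_file,
    )
print(transcript.text)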
@@ -4944,14 +5017,22 @@ def stream_chunk_builder(
     else:
         completion_output = ""
     # # Update usage information if needed
+    prompt_tokens = 0
+    completion_tokens = 0
+    for chunk in chunks:
+        if "usage" in chunk:
+            if "prompt_tokens" in chunk["usage"]:
+                prompt_tokens = chunk["usage"].get("prompt_tokens", 0) or 0
+            if "completion_tokens" in chunk["usage"]:
+                completion_tokens = chunk["usage"].get("completion_tokens", 0) or 0
     try:
-        response["usage"]["prompt_tokens"] = token_counter(
+        response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
             model=model, messages=messages
         )
     except:  # don't allow this failing to block a complete streaming response from being returned
         print_verbose(f"token_counter failed, assuming prompt tokens is 0")
         response["usage"]["prompt_tokens"] = 0
-    response["usage"]["completion_tokens"] = token_counter(
+    response["usage"]["completion_tokens"] = completion_tokens or token_counter(
         model=model,
         text=completion_output,
         count_response_tokens=True,  # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages
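The change above prefers the usage reported in the stream chunks and only falls back to local token counting when no chunk carried usage. A standalone sketch of that aggregation pattern (plain dicts and dummy fallback counts stand in for litellm's chunk objects and token_counter):

from typing import Dict, List

def aggregate_usage(chunks: List[Dict], fallback_prompt: int, fallback_completion: int) -> Dict[str, int]:
    # Scan every chunk; the last chunk that reports usage wins, mirroring the added loop.
    prompt_tokens = 0
    completion_tokens = 0
    for chunk in chunks:
        usage = chunk.get("usage") or {}
        if "prompt_tokens" in usage:
            prompt_tokens = usage.get("prompt_tokens", 0) or 0
        if "completion_tokens" in usage:
            completion_tokens = usage.get("completion_tokens", 0) or 0
    # Fall back to locally counted tokens when the provider sent no usage.
    return {
        "prompt_tokens": prompt_tokens or fallback_prompt,
        "completion_tokens": completion_tokens or fallback_completion,
    }

# Example: only the final chunk carries usage, as many providers do when streaming.
chunks = [
    {"choices": [{"delta": {"content": "Hi"}}]},
    {"choices": [{"delta": {"content": "!"}}], "usage": {"prompt_tokens": 12, "completion_tokens": 2}},
]
print(aggregate_usage(chunks, fallback_prompt=10, fallback_completion=3))
# -> {'prompt_tokens': 12, 'completion_tokens': 2}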