Merge branch 'BerriAI:main' into main

Sunny Wan, 2025-03-13 19:37:22 -04:00 (committed by GitHub)
commit f9a5109203
317 changed files with 15980 additions and 5207 deletions

@@ -1159,6 +1159,18 @@ def completion( # type: ignore # noqa: PLR0915
             prompt_id=prompt_id,
             prompt_variables=prompt_variables,
             ssl_verify=ssl_verify,
+            merge_reasoning_content_in_choices=kwargs.get(
+                "merge_reasoning_content_in_choices", None
+            ),
+            api_version=api_version,
+            azure_ad_token=kwargs.get("azure_ad_token"),
+            tenant_id=kwargs.get("tenant_id"),
+            client_id=kwargs.get("client_id"),
+            client_secret=kwargs.get("client_secret"),
+            azure_username=kwargs.get("azure_username"),
+            azure_password=kwargs.get("azure_password"),
+            max_retries=max_retries,
+            timeout=timeout,
         )
         logging.update_environment_variables(
             model=model,
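Context for the hunk above: these kwargs are read straight off the public `completion()` call and forwarded into litellm_params. A minimal sketch of how a caller might exercise the new fields — the deployment name and env-var names are illustrative placeholders, not part of this diff:

```python
import os

import litellm

# Deployment name and env-var names below are placeholders.
response = litellm.completion(
    model="azure/my-gpt4o-deployment",
    messages=[{"role": "user", "content": "hi"}],
    # fold provider reasoning_content into choices[].message.content
    merge_reasoning_content_in_choices=True,
    # Azure AD service-principal credentials, now forwarded per-call
    tenant_id=os.getenv("AZURE_TENANT_ID"),
    client_id=os.getenv("AZURE_CLIENT_ID"),
    client_secret=os.getenv("AZURE_CLIENT_SECRET"),
)
print(response.choices[0].message.content)
```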
@@ -2271,23 +2283,22 @@ def completion( # type: ignore # noqa: PLR0915
         data = {"model": model, "messages": messages, **optional_params}

         ## COMPLETION CALL
-        response = openai_like_chat_completion.completion(
+        response = base_llm_http_handler.completion(
             model=model,
+            stream=stream,
             messages=messages,
-            headers=headers,
-            api_key=api_key,
+            acompletion=acompletion,
             api_base=api_base,
             model_response=model_response,
-            print_verbose=print_verbose,
             optional_params=optional_params,
             litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            logging_obj=logging,
-            acompletion=acompletion,
-            timeout=timeout,  # type: ignore
             custom_llm_provider="openrouter",
-            custom_prompt_dict=custom_prompt_dict,
+            timeout=timeout,
+            headers=headers,
             encoding=encoding,
+            api_key=api_key,
+            logging_obj=logging,  # model call logging done inside the class as we may need to modify I/O to fit aleph alpha's requirements
+            client=client,
         )
         ## LOGGING
         logging.post_call(
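The hunk above swaps OpenRouter from the OpenAI-compatible shim onto the shared base_llm_http_handler; the caller-facing surface is meant to be unchanged. A sketch of that unchanged surface (model id and key are placeholders):

```python
import litellm

# Model id and key are placeholders; OPENROUTER_API_KEY in the env also works.
response = litellm.completion(
    model="openrouter/openai/gpt-4o",
    messages=[{"role": "user", "content": "Say hello"}],
    api_key="sk-or-...",
)
print(response.choices[0].message.content)
```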
@@ -2853,6 +2864,7 @@ def completion( # type: ignore # noqa: PLR0915
                 acompletion=acompletion,
                 model_response=model_response,
                 encoding=encoding,
+                client=client,
             )
             if acompletion is True or optional_params.get("stream", False) is True:
                 return generator
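The single added line threads the caller-supplied client into this provider path (the provider itself lies outside the hunk). Presumably the point is reusing one pre-configured SDK client across calls; a sketch under that assumption:

```python
from openai import OpenAI

import litellm

# Reuse one pre-configured SDK client instead of having litellm build one
# per call (assumed motivation; the provider shown here is illustrative).
shared_client = OpenAI(timeout=30.0)
response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "ping"}],
    client=shared_client,
)
```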
@@ -3380,6 +3392,7 @@ def embedding( # noqa: PLR0915
             }
         }
     )
+    litellm_params_dict = get_litellm_params(**kwargs)

     logging: Logging = litellm_logging_obj  # type: ignore
@@ -3441,6 +3454,7 @@ def embedding( # noqa: PLR0915
             aembedding=aembedding,
             max_retries=max_retries,
             headers=headers or extra_headers,
+            litellm_params=litellm_params_dict,
         )
     elif (
         model in litellm.open_ai_embedding_models
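The two embedding hunks above follow one pattern repeated through the rest of this diff: sweep litellm-specific kwargs into a dict with get_litellm_params(**kwargs), then hand that dict to the provider handler as litellm_params. From the caller's side, per-call fields such as metadata now reach the handler; an illustrative call:

```python
import litellm

# metadata is one of the litellm-specific kwargs collected by
# get_litellm_params(**kwargs) and forwarded to the handler.
response = litellm.embedding(
    model="text-embedding-3-small",
    input=["hello world"],
    metadata={"trace_id": "abc-123"},
)
print(len(response.data[0]["embedding"]))
```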
@@ -3930,42 +3944,19 @@ async def atext_completion
     ctx = contextvars.copy_context()
     func_with_context = partial(ctx.run, func)

-    _, custom_llm_provider, _, _ = get_llm_provider(
-        model=model, api_base=kwargs.get("api_base", None)
-    )
-    if (
-        custom_llm_provider == "openai"
-        or custom_llm_provider == "azure"
-        or custom_llm_provider == "azure_text"
-        or custom_llm_provider == "custom_openai"
-        or custom_llm_provider == "anyscale"
-        or custom_llm_provider == "mistral"
-        or custom_llm_provider == "openrouter"
-        or custom_llm_provider == "deepinfra"
-        or custom_llm_provider == "perplexity"
-        or custom_llm_provider == "groq"
-        or custom_llm_provider == "nvidia_nim"
-        or custom_llm_provider == "cerebras"
-        or custom_llm_provider == "sambanova"
-        or custom_llm_provider == "ai21_chat"
-        or custom_llm_provider == "ai21"
-        or custom_llm_provider == "volcengine"
-        or custom_llm_provider == "text-completion-codestral"
-        or custom_llm_provider == "deepseek"
-        or custom_llm_provider == "text-completion-openai"
-        or custom_llm_provider == "huggingface"
-        or custom_llm_provider == "ollama"
-        or custom_llm_provider == "vertex_ai"
-        or custom_llm_provider in litellm.openai_compatible_providers
-    ):  # currently implemented aiohttp calls for just azure and openai, soon all.
-        # Await normally
-        response = await loop.run_in_executor(None, func_with_context)
-        if asyncio.iscoroutine(response):
-            response = await response
+    init_response = await loop.run_in_executor(None, func_with_context)
+    if isinstance(init_response, dict) or isinstance(
+        init_response, TextCompletionResponse
+    ):  ## CACHING SCENARIO
+        if isinstance(init_response, dict):
+            response = TextCompletionResponse(**init_response)
+        else:
+            response = init_response
+    elif asyncio.iscoroutine(init_response):
+        response = await init_response
     else:
-        # Call the synchronous function using run_in_executor
-        response = await loop.run_in_executor(None, func_with_context)
+        response = init_response  # type: ignore

     if (
         kwargs.get("stream", False) is True
         or isinstance(response, TextCompletionStreamWrapper)
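The atext_completion rewrite above drops the long provider allowlist: every provider now runs through run_in_executor once, and the result is dispatched purely on its type. A distilled sketch of the new branching (the TextCompletionResponse import path is assumed):

```python
import asyncio

from litellm import TextCompletionResponse  # import path assumed


async def dispatch(init_response):
    # Cache hits surface as a raw dict or an already-built response object.
    if isinstance(init_response, dict):
        return TextCompletionResponse(**init_response)
    if isinstance(init_response, TextCompletionResponse):
        return init_response
    # Async provider paths hand back a coroutine to await.
    if asyncio.iscoroutine(init_response):
        return await init_response
    # Sync paths already produced the finished response.
    return init_response
```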
@@ -4554,6 +4545,7 @@ def image_generation( # noqa: PLR0915
         non_default_params = {
             k: v for k, v in kwargs.items() if k not in default_params
         }  # model-specific params - pass them straight to the model/provider
+
         optional_params = get_optional_params_image_gen(
             model=model,
             n=n,
@@ -4565,6 +4557,9 @@ def image_generation( # noqa: PLR0915
             custom_llm_provider=custom_llm_provider,
             **non_default_params,
         )
+
+        litellm_params_dict = get_litellm_params(**kwargs)
+
         logging: Logging = litellm_logging_obj
         logging.update_environment_variables(
             model=model,
@@ -4635,6 +4630,7 @@ def image_generation( # noqa: PLR0915
             aimg_generation=aimg_generation,
             client=client,
             headers=headers,
+            litellm_params=litellm_params_dict,
         )
     elif (
         custom_llm_provider == "openai"
@@ -4663,6 +4659,7 @@ def image_generation( # noqa: PLR0915
             optional_params=optional_params,
             model_response=model_response,
             aimg_generation=aimg_generation,
+            client=client,
         )
     elif custom_llm_provider == "vertex_ai":
         vertex_ai_project = (
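image_generation picks up both recurring changes: a litellm_params_dict built from kwargs (threaded into the Azure path) and a client passthrough on the OpenAI path. A combined, illustrative call:

```python
from openai import OpenAI

import litellm

image = litellm.image_generation(
    model="dall-e-3",
    prompt="minimalist line art of a lighthouse",
    client=OpenAI(),            # new passthrough on the OpenAI path
    metadata={"user": "demo"},  # carried via litellm_params_dict
)
print(image.data[0].url)
```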
@@ -5029,6 +5026,7 @@ def transcription(
         custom_llm_provider=custom_llm_provider,
         drop_params=drop_params,
     )
+    litellm_params_dict = get_litellm_params(**kwargs)

     litellm_logging_obj.update_environment_variables(
         model=model,
@@ -5082,6 +5080,7 @@ def transcription(
             api_version=api_version,
             azure_ad_token=azure_ad_token,
             max_retries=max_retries,
+            litellm_params=litellm_params_dict,
         )
     elif (
         custom_llm_provider == "openai"
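transcription gets the same get_litellm_params threading, here into the Azure branch. Illustrative usage (file name is a placeholder):

```python
import litellm

# File name is a placeholder.
with open("sample.wav", "rb") as audio_file:
    transcript = litellm.transcription(
        model="whisper-1",
        file=audio_file,
        metadata={"job": "demo"},  # swept into litellm_params_dict
    )
print(transcript.text)
```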
@@ -5184,7 +5183,7 @@ async def aspeech(*args, **kwargs) -> HttpxBinaryResponseContent:

 @client
-def speech(
+def speech(  # noqa: PLR0915
     model: str,
     input: str,
     voice: Optional[Union[str, dict]] = None,
@@ -5225,7 +5224,7 @@ def speech(
     if max_retries is None:
         max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES
-
+    litellm_params_dict = get_litellm_params(**kwargs)
     logging_obj = kwargs.get("litellm_logging_obj", None)
     logging_obj.update_environment_variables(
         model=model,
@@ -5342,6 +5341,7 @@ def speech(
             timeout=timeout,
             client=client,  # pass AsyncOpenAI, OpenAI client
             aspeech=aspeech,
+            litellm_params=litellm_params_dict,
         )
     elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta":
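speech closes out the pattern: the function has grown past the lint threshold (hence the added # noqa: PLR0915), builds litellm_params_dict from kwargs, and forwards it alongside the existing client passthrough. Illustrative usage:

```python
import litellm

audio = litellm.speech(
    model="openai/tts-1",
    voice="alloy",
    input="litellm params now ride along with this call.",
)
# speech() returns HttpxBinaryResponseContent, per the aspeech signature above.
audio.stream_to_file("speech.mp3")
```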