LiteLLM Minor Fixes & Improvements (10/07/2024) (#6101)

* fix(utils.py): support dropping temperature param for azure o1 models
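
A hedged usage sketch (not from the commit): with `drop_params=True`, a `temperature` argument sent to an Azure o1 deployment should be dropped rather than rejected. The deployment name below is assumed.

```python
import litellm

# Assumed deployment name; o1 rejects `temperature`, so with
# drop_params=True litellm strips it before sending the request.
response = litellm.completion(
    model="azure/o1-preview",
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0.2,  # unsupported by o1 -> dropped, not an error
    drop_params=True,
)
```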

* fix(main.py): handle azure o1 streaming requests

o1 doesn't support streaming, so fake it to ensure calling code works as expected
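
As a minimal sketch of the idea (illustrative names, not the handler's actual code): fetch the full completion, then replay it through a generator so `stream=True` callers still receive an iterable.

```python
from typing import Iterator

def fake_stream(full_text: str, chunk_size: int = 20) -> Iterator[str]:
    # o1 returns the whole completion at once; slicing it into chunks
    # lets existing `for chunk in response:` loops work unchanged.
    for i in range(0, len(full_text), chunk_size):
        yield full_text[i : i + chunk_size]
```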

* feat(utils.py): expose `hosted_vllm/` endpoint, with tool handling for vllm

Fixes https://github.com/BerriAI/litellm/issues/6088
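
A hedged example of the new route (server URL and model name are placeholders); tool definitions pass through to the vLLM server's OpenAI-compatible endpoint.

```python
import litellm

response = litellm.completion(
    model="hosted_vllm/facebook/opt-125m",  # placeholder model
    api_base="http://localhost:8000/v1",    # your vLLM server
    messages=[{"role": "user", "content": "What's the weather in SF?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
)
```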

* refactor(internal_user_endpoints.py): cleanup unused params + update docstring

Closes https://github.com/BerriAI/litellm/issues/6100

* fix(main.py): expose custom image generation api support

Fixes https://github.com/BerriAI/litellm/issues/6097
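
Roughly how the hook is meant to be used, mirroring the call site in the diff below; treat the exact handler signature as an assumption and defer to the custom_llm_server.md docs added in this commit.

```python
import litellm
from litellm import CustomLLM
from litellm.types.utils import ImageObject, ImageResponse

class MyCustomLLM(CustomLLM):
    def image_generation(
        self, model, prompt, model_response, optional_params,
        logging_obj, timeout=None, client=None,
    ) -> ImageResponse:
        # Call your own image API here; a static URL stands in for it.
        return ImageResponse(
            created=1722692392,
            data=[ImageObject(url="https://example.com/image.png")],
        )

# Register the handler, then route calls via the custom provider prefix.
litellm.custom_provider_map = [
    {"provider": "my-custom-llm", "custom_handler": MyCustomLLM()}
]
image = litellm.image_generation(
    model="my-custom-llm/my-model", prompt="A cute baby sea otter"
)
```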

* fix: fix linting errors

* docs(custom_llm_server.md): add docs on custom api for image gen calls

* fix(types/utils.py): handle dict type

* fix(types/utils.py): fix linting errors
Krish Dholakia, 2024-10-08 01:17:22 -04:00 (committed via GitHub)
commit 6729c9ca7f (parent 5de69cb1b2)
17 changed files with 643 additions and 76 deletions

litellm/main.py

@@ -42,6 +42,7 @@ from litellm import (  # type: ignore
 )
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.secret_managers.main import get_secret_str
 from litellm.utils import (
     CustomStreamWrapper,
@@ -89,6 +90,7 @@ from .llms.azure_ai.embed import AzureAIEmbedding
 from .llms.azure_text import AzureTextCompletion
 from .llms.AzureOpenAI.audio_transcriptions import AzureAudioTranscription
 from .llms.AzureOpenAI.azure import AzureChatCompletion, _check_dynamic_azure_params
+from .llms.AzureOpenAI.chat.o1_handler import AzureOpenAIO1ChatCompletion
 from .llms.bedrock import image_generation as bedrock_image_generation  # type: ignore
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
@@ -178,6 +180,7 @@ azure_ai_embedding = AzureAIEmbedding()
 anthropic_chat_completions = AnthropicChatCompletion()
 anthropic_text_completions = AnthropicTextCompletion()
 azure_chat_completions = AzureChatCompletion()
+azure_o1_chat_completions = AzureOpenAIO1ChatCompletion()
 azure_text_completions = AzureTextCompletion()
 azure_audio_transcriptions = AzureAudioTranscription()
 huggingface = Huggingface()
@@ -1064,35 +1067,68 @@ def completion(  # type: ignore
             headers = headers or litellm.headers
 
-            ## LOAD CONFIG - if set
-            config = litellm.AzureOpenAIConfig.get_config()
-            for k, v in config.items():
-                if (
-                    k not in optional_params
-                ):  # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in
-                    optional_params[k] = v
+            if (
+                litellm.enable_preview_features
+                and litellm.AzureOpenAIO1Config().is_o1_model(model=model)
+            ):
+                ## LOAD CONFIG - if set
+                config = litellm.AzureOpenAIO1Config.get_config()
+                for k, v in config.items():
+                    if (
+                        k not in optional_params
+                    ):  # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in
+                        optional_params[k] = v
 
-            ## COMPLETION CALL
-            response = azure_chat_completions.completion(
-                model=model,
-                messages=messages,
-                headers=headers,
-                api_key=api_key,
-                api_base=api_base,
-                api_version=api_version,
-                api_type=api_type,
-                dynamic_params=dynamic_params,
-                azure_ad_token=azure_ad_token,
-                model_response=model_response,
-                print_verbose=print_verbose,
-                optional_params=optional_params,
-                litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                logging_obj=logging,
-                acompletion=acompletion,
-                timeout=timeout,  # type: ignore
-                client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
-            )
+                ## COMPLETION CALL
+                response = azure_o1_chat_completions.completion(
+                    model=model,
+                    messages=messages,
+                    headers=headers,
+                    api_key=api_key,
+                    api_base=api_base,
+                    api_version=api_version,
+                    api_type=api_type,
+                    dynamic_params=dynamic_params,
+                    azure_ad_token=azure_ad_token,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    logging_obj=logging,
+                    acompletion=acompletion,
+                    timeout=timeout,  # type: ignore
+                    client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
+                )
+            else:
+                ## LOAD CONFIG - if set
+                config = litellm.AzureOpenAIConfig.get_config()
+                for k, v in config.items():
+                    if (
+                        k not in optional_params
+                    ):  # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in
+                        optional_params[k] = v
+
+                ## COMPLETION CALL
+                response = azure_chat_completions.completion(
+                    model=model,
+                    messages=messages,
+                    headers=headers,
+                    api_key=api_key,
+                    api_base=api_base,
+                    api_version=api_version,
+                    api_type=api_type,
+                    dynamic_params=dynamic_params,
+                    azure_ad_token=azure_ad_token,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    logging_obj=logging,
+                    acompletion=acompletion,
+                    timeout=timeout,  # type: ignore
+                    client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
+                )
 
             if optional_params.get("stream", False):
                 ## LOGGING
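
Note the `litellm.enable_preview_features` gate above: o1 requests only take the new handler when preview features are opted into. A minimal sketch (deployment name assumed):

```python
import litellm

litellm.enable_preview_features = True  # opt in to the new o1 handler

response = litellm.completion(
    model="azure/o1-preview",  # assumed deployment name
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,  # o1 can't stream; the handler fakes the stream
)
for chunk in response:
    print(chunk)
```
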
@@ -4582,6 +4618,7 @@ def image_generation(
     Currently supports just Azure + OpenAI.
     """
     try:
+        args = locals()
         aimg_generation = kwargs.get("aimg_generation", False)
         litellm_call_id = kwargs.get("litellm_call_id", None)
         logger_fn = kwargs.get("logger_fn", None)
@@ -4787,6 +4824,51 @@ def image_generation(
                 vertex_credentials=vertex_credentials,
                 aimg_generation=aimg_generation,
             )
+        elif (
+            custom_llm_provider in litellm._custom_providers
+        ):  # Assume custom LLM provider
+            # Get the Custom Handler
+            custom_handler: Optional[CustomLLM] = None
+            for item in litellm.custom_provider_map:
+                if item["provider"] == custom_llm_provider:
+                    custom_handler = item["custom_handler"]
+
+            if custom_handler is None:
+                raise ValueError(
+                    f"Unable to map your input to a model. Check your input - {args}"
+                )
+
+            ## ROUTE LLM CALL ##
+            if aimg_generation is True:
+                async_custom_client: Optional[AsyncHTTPHandler] = None
+                if client is not None and isinstance(client, AsyncHTTPHandler):
+                    async_custom_client = client
+
+                ## CALL FUNCTION
+                model_response = custom_handler.aimage_generation(  # type: ignore
+                    model=model,
+                    prompt=prompt,
+                    model_response=model_response,
+                    optional_params=optional_params,
+                    logging_obj=litellm_logging_obj,
+                    timeout=timeout,
+                    client=async_custom_client,
+                )
+            else:
+                custom_client: Optional[HTTPHandler] = None
+                if client is not None and isinstance(client, HTTPHandler):
+                    custom_client = client
+
+                ## CALL FUNCTION
+                model_response = custom_handler.image_generation(
+                    model=model,
+                    prompt=prompt,
+                    model_response=model_response,
+                    optional_params=optional_params,
+                    logging_obj=litellm_logging_obj,
+                    timeout=timeout,
+                    client=custom_client,
+                )
 
         return model_response
     except Exception as e: