Merge branch 'main' into main

mogith-pn 2024-04-30 22:48:52 +05:30, committed by GitHub
commit d770df2259
198 changed files with 10972 additions and 7448 deletions

litellm/main.py

@@ -12,7 +12,6 @@ from typing import Any, Literal, Union, BinaryIO
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
-
 import httpx
 import litellm
 from ._logging import verbose_logger
@@ -64,6 +63,7 @@ from .llms import (
     vertex_ai,
     vertex_ai_anthropic,
     maritalk,
+    watsonx,
 )
 from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.azure import AzureChatCompletion
@@ -343,6 +343,7 @@ async def acompletion(
             custom_llm_provider=custom_llm_provider,
             original_exception=e,
             completion_kwargs=completion_kwargs,
+            extra_kwargs=kwargs,
         )
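
This `extra_kwargs=kwargs` argument is the thread running through the whole commit: every entry point's exception handler (the matching hunks below make the same change) now hands the caller's original kwargs to `exception_type`, so the mapped exception can carry request context such as metadata. A self-contained sketch of the pattern with stand-in functions (neither is litellm's real implementation):

# Stand-ins for illustration only; litellm's real exception_type maps to
# provider-specific exception classes.
def do_call(**params):
    raise RuntimeError("provider failure")

def exception_type(model, custom_llm_provider, original_exception,
                   completion_kwargs, extra_kwargs):
    # the mapper can now see the caller's original kwargs
    return RuntimeError(
        f"{custom_llm_provider}/{model} failed; "
        f"request context keys: {sorted(extra_kwargs)}"
    )

kwargs = {"metadata": {"trace_id": "abc-123"}}
completion_kwargs = {"model": "gpt-3.5-turbo", "messages": []}
try:
    do_call(**completion_kwargs)
except Exception as e:
    raise exception_type(
        model="gpt-3.5-turbo",
        custom_llm_provider="openai",
        original_exception=e,
        completion_kwargs=completion_kwargs,
        extra_kwargs=kwargs,  # the new argument
    ) from e
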
@@ -408,8 +409,10 @@ def mock_completion(
     model_response["created"] = int(time.time())
     model_response["model"] = model
-    model_response.usage = Usage(
-        prompt_tokens=10, completion_tokens=20, total_tokens=30
+    setattr(
+        model_response,
+        "usage",
+        Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
     )

     try:
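
The switch from `model_response.usage = Usage(...)` to `setattr(...)` is behavior-preserving at runtime; the likely motivation is keeping static type checkers quiet, since `usage` is not a declared attribute on the response model (that reading is an inference, not stated in the commit). The mocked numbers are observable directly; a usage sketch, assuming `mock_completion` is exposed on the `litellm` package as it is defined in this file:

import litellm

resp = litellm.mock_completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "ping"}],
)
print(resp.usage.total_tokens)  # 30, per the hardcoded Usage above
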
@@ -609,6 +612,7 @@ def completion(
         "client",
         "rpm",
         "tpm",
+        "max_parallel_requests",
         "input_cost_per_token",
         "output_cost_per_token",
         "input_cost_per_second",
@@ -652,6 +656,7 @@ def completion(
                 model
             ]  # update the model to the actual value if an alias has been passed in
         model_response = ModelResponse()
+        setattr(model_response, "usage", litellm.Usage())
         if (
             kwargs.get("azure", False) == True
         ):  # don't remove flag check, to remain backwards compatible for repos like Codium
@@ -1732,13 +1737,14 @@ def completion(
                 or optional_params.pop("vertex_ai_credentials", None)
                 or get_secret("VERTEXAI_CREDENTIALS")
             )
+
+            new_params = deepcopy(optional_params)
             if "claude-3" in model:
                 model_response = vertex_ai_anthropic.completion(
                     model=model,
                     messages=messages,
                     model_response=model_response,
                     print_verbose=print_verbose,
-                    optional_params=optional_params,
+                    optional_params=new_params,
                     litellm_params=litellm_params,
                     logger_fn=logger_fn,
                     encoding=encoding,
@@ -1754,7 +1760,7 @@ def completion(
                     messages=messages,
                     model_response=model_response,
                     print_verbose=print_verbose,
-                    optional_params=optional_params,
+                    optional_params=new_params,
                     litellm_params=litellm_params,
                     logger_fn=logger_fn,
                     encoding=encoding,
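
`new_params = deepcopy(optional_params)` shields the caller's dict from the Vertex AI handlers, which can mutate the params they receive while building their payloads; without the copy, a retry or fallback reusing `optional_params` would see it already stripped. A generic, runnable illustration of the hazard being avoided (the key name is illustrative):

from copy import deepcopy

def provider_call(params: dict) -> None:
    # provider handlers often pop entries while building a payload
    params.pop("vertex_credentials", None)

optional_params = {"vertex_credentials": "...", "temperature": 0.2}
provider_call(deepcopy(optional_params))        # the copy gets mutated
assert "vertex_credentials" in optional_params  # the original survives
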
@@ -1907,6 +1913,43 @@ def completion(
             ## RESPONSE OBJECT
             response = response
+        elif custom_llm_provider == "watsonx":
+            custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
+            response = watsonx.IBMWatsonXAI().completion(
+                model=model,
+                messages=messages,
+                custom_prompt_dict=custom_prompt_dict,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                litellm_params=litellm_params,  # type: ignore
+                logger_fn=logger_fn,
+                encoding=encoding,
+                logging_obj=logging,
+                timeout=timeout,
+            )
+            if (
+                "stream" in optional_params
+                and optional_params["stream"] == True
+                and not isinstance(response, CustomStreamWrapper)
+            ):
+                # don't try to access stream object,
+                response = CustomStreamWrapper(
+                    iter(response),
+                    model,
+                    custom_llm_provider="watsonx",
+                    logging_obj=logging,
+                )
+
+            if optional_params.get("stream", False):
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=None,
+                    original_response=response,
+                )
+            ## RESPONSE OBJECT
+            response = response
         elif custom_llm_provider == "vllm":
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
             model_response = vllm.completion(
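
The new branch routes `completion()` to IBM watsonx.ai and, because the handler can return a plain generator, wraps streaming responses in `CustomStreamWrapper` itself. A minimal call sketch; the `watsonx/` prefix follows the provider key in the diff, but the exact model id and credential setup (environment variables for the watsonx endpoint and API key) are assumptions, not confirmed here:

import litellm

# model id below is illustrative; watsonx credentials are assumed to be
# configured in the environment.
resp = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in resp:
    print(chunk.choices[0].delta.content or "", end="")
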
@@ -1990,9 +2033,16 @@ def completion(
                 or "http://localhost:11434"
             )
+            api_key = (
+                api_key
+                or litellm.ollama_key
+                or os.environ.get("OLLAMA_API_KEY")
+                or litellm.api_key
+            )
             ## LOGGING
             generator = ollama_chat.get_ollama_response(
                 api_base,
+                api_key,
                 model,
                 messages,
                 optional_params,
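
`ollama_chat` gains an API-key resolution chain: explicit `api_key` argument, then `litellm.ollama_key`, then the `OLLAMA_API_KEY` environment variable, then `litellm.api_key`, with the winner forwarded to `get_ollama_response`. This mainly matters when an Ollama server sits behind an authenticating proxy. A sketch (the env var name comes from the diff; the model id is illustrative):

import os
import litellm

os.environ["OLLAMA_API_KEY"] = "my-proxy-token"  # picked up by the chain above

resp = litellm.completion(
    model="ollama_chat/llama3",
    messages=[{"role": "user", "content": "Hi"}],
    api_base="http://localhost:11434",
)
print(resp.choices[0].message.content)
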
@@ -2188,6 +2238,7 @@ def completion(
             custom_llm_provider=custom_llm_provider,
             original_exception=e,
             completion_kwargs=args,
+            extra_kwargs=kwargs,
         )
@@ -2549,6 +2600,7 @@ async def aembedding(*args, **kwargs):
             custom_llm_provider=custom_llm_provider,
             original_exception=e,
             completion_kwargs=args,
+            extra_kwargs=kwargs,
         )
@@ -2600,6 +2652,7 @@ def embedding(
     client = kwargs.pop("client", None)
     rpm = kwargs.pop("rpm", None)
     tpm = kwargs.pop("tpm", None)
+    max_parallel_requests = kwargs.pop("max_parallel_requests", None)
     model_info = kwargs.get("model_info", None)
     metadata = kwargs.get("metadata", None)
     encoding_format = kwargs.get("encoding_format", None)
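
`max_parallel_requests` is treated like `rpm` and `tpm`: popped from kwargs here and added to the `litellm_params` exclusion lists elsewhere in this commit, so it is consumed client-side (typically for router concurrency limiting) rather than serialized into the provider request. A hedged usage sketch:

import litellm

# max_parallel_requests is accepted and stripped client-side; it is not
# forwarded in the embedding provider's payload.
resp = litellm.embedding(
    model="text-embedding-ada-002",
    input=["hello world"],
    max_parallel_requests=5,
)
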
@@ -2657,6 +2710,7 @@ def embedding(
         "client",
         "rpm",
         "tpm",
+        "max_parallel_requests",
         "input_cost_per_token",
         "output_cost_per_token",
         "input_cost_per_second",
@@ -2975,6 +3029,15 @@ def embedding(
             client=client,
             aembedding=aembedding,
         )
+    elif custom_llm_provider == "watsonx":
+        response = watsonx.IBMWatsonXAI().embedding(
+            model=model,
+            input=input,
+            encoding=encoding,
+            logging_obj=logging,
+            optional_params=optional_params,
+            model_response=EmbeddingResponse(),
+        )
     else:
         args = locals()
         raise ValueError(f"No valid embedding model args passed in - {args}")
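
Embeddings get the matching watsonx branch, returning a fresh `EmbeddingResponse`. A sketch mirroring the completion example; the model id and the response-shape access are assumptions, not confirmed by this diff:

import litellm

emb = litellm.embedding(
    model="watsonx/ibm/slate-30m-english-rtrvr",  # illustrative model id
    input=["the quick brown fox"],
)
print(len(emb.data[0]["embedding"]))
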
@@ -2990,7 +3053,10 @@ def embedding(
         )
         ## Map to OpenAI Exception
         raise exception_type(
-            model=model, original_exception=e, custom_llm_provider=custom_llm_provider
+            model=model,
+            original_exception=e,
+            custom_llm_provider=custom_llm_provider,
+            extra_kwargs=kwargs,
         )
@@ -3084,6 +3150,7 @@ async def atext_completion(*args, **kwargs):
             custom_llm_provider=custom_llm_provider,
             original_exception=e,
             completion_kwargs=args,
+            extra_kwargs=kwargs,
         )
@@ -3421,6 +3488,7 @@ async def aimage_generation(*args, **kwargs):
             custom_llm_provider=custom_llm_provider,
             original_exception=e,
             completion_kwargs=args,
+            extra_kwargs=kwargs,
         )
@@ -3511,6 +3579,7 @@ def image_generation(
         "client",
         "rpm",
         "tpm",
+        "max_parallel_requests",
         "input_cost_per_token",
         "output_cost_per_token",
         "hf_model_name",
@@ -3620,6 +3689,7 @@ def image_generation(
             custom_llm_provider=custom_llm_provider,
             original_exception=e,
             completion_kwargs=locals(),
+            extra_kwargs=kwargs,
         )
@@ -3669,6 +3739,7 @@ async def atranscription(*args, **kwargs):
             custom_llm_provider=custom_llm_provider,
             original_exception=e,
             completion_kwargs=args,
+            extra_kwargs=kwargs,
         )