Merge branch 'main' into litellm_gemini_refactoring

commit 6f94456f40
Author: Krish Dholakia
Date:   2024-06-17 19:50:56 -07:00 (committed by GitHub)
27 changed files with 335 additions and 182 deletions

litellm/main.py

@@ -107,19 +107,17 @@ from .llms.databricks import DatabricksChatCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.predibase import PredibaseChatCompletion
-from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
-from .llms.vertex_httpx import VertexLLM
-from .llms.triton import TritonChatCompletion
-from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.prompt_templates.factory import (
     custom_prompt,
     function_call_prompt,
     map_system_message_pt,
     prompt_factory,
 )
+from .llms.text_completion_codestral import CodestralTextCompletion
+from .llms.triton import TritonChatCompletion
+from .llms.vertex_httpx import VertexLLM
 from .types.llms.openai import HttpxBinaryResponseContent
+from .types.utils import ChatCompletionMessageToolCall

 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (
@@ -431,6 +429,7 @@ def mock_completion(
     messages: List,
     stream: Optional[bool] = False,
     mock_response: Union[str, Exception] = "This is a mock request",
+    mock_tool_calls: Optional[List] = None,
     logging=None,
     custom_llm_provider=None,
     **kwargs,
@@ -499,6 +498,12 @@ def mock_completion(
         model_response["created"] = int(time.time())
         model_response["model"] = model
+        if mock_tool_calls:
+            model_response["choices"][0]["message"]["tool_calls"] = [
+                ChatCompletionMessageToolCall(**tool_call)
+                for tool_call in mock_tool_calls
+            ]
         setattr(
             model_response,
             "usage",
@@ -612,6 +617,7 @@ def completion(
     args = locals()
     api_base = kwargs.get("api_base", None)
     mock_response = kwargs.get("mock_response", None)
+    mock_tool_calls = kwargs.get("mock_tool_calls", None)
     force_timeout = kwargs.get("force_timeout", 600)  ## deprecated
     logger_fn = kwargs.get("logger_fn", None)
     verbose = kwargs.get("verbose", False)
@@ -930,12 +936,13 @@ def completion(
             litellm_params=litellm_params,
             custom_llm_provider=custom_llm_provider,
         )
-        if mock_response:
+        if mock_response or mock_tool_calls:
             return mock_completion(
                 model,
                 messages,
                 stream=stream,
                 mock_response=mock_response,
+                mock_tool_calls=mock_tool_calls,
                 logging=logging,
                 acompletion=acompletion,
                 mock_delay=kwargs.get("mock_delay", None),
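
Taken together, these changes let a caller exercise tool-call handling without any provider traffic: passing mock_tool_calls makes completion() route through mock_completion(), which attaches the tool calls to the mocked response. A usage sketch, assuming the tool-call dict shape above (model name and field values are placeholders):

import litellm

# Sketch of the new mock_tool_calls path; no request is sent to the model.
response = litellm.completion(
    model="gpt-3.5-turbo",  # placeholder; the provider is never called
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    mock_tool_calls=[
        {
            "id": "call_abc123",
            "type": "function",
            "function": {
                "name": "get_weather",
                "arguments": '{"city": "Paris"}',
            },
        }
    ],
)

# The mocked choice should carry tool_calls instead of plain content.
print(response.choices[0].message.tool_calls)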