fix(utils.py): support 'drop_params' for 'parallel_tool_calls'

Closes https://github.com/BerriAI/litellm/issues/4584

'parallel_tool_calls' is an OpenAI-only param
Krrish Dholakia 2024-07-08 07:36:41 -07:00
parent f889a7e4b0
commit a00a1267bc
5 changed files with 86 additions and 24 deletions
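
For context: 'parallel_tool_calls' is a request param only the OpenAI API understands, and 'drop_params' tells litellm to silently drop params a provider does not support instead of raising an exception. Below is a minimal sketch of the behavior this commit enables; the model name and tool definition are illustrative, and per-call 'drop_params' is assumed here (it can also be set globally via 'litellm.drop_params = True'):

import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

# 'parallel_tool_calls' is OpenAI-only; with drop_params=True, litellm
# should strip it from the request for other providers instead of
# raising an unsupported-params error.
response = litellm.completion(
    model="anthropic/claude-3-haiku-20240307",
    messages=[{"role": "user", "content": "What's the weather in SF and NYC?"}],
    tools=tools,
    parallel_tool_calls=True,
    drop_params=True,
)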

litellm/llms/watsonx.py

@@ -1,28 +1,31 @@
-import json, types, time  # noqa: E401
+import asyncio
+import json  # noqa: E401
+import time
+import types
+from contextlib import asynccontextmanager, contextmanager
 from datetime import datetime
 from enum import Enum
-from contextlib import asynccontextmanager, contextmanager
 from typing import (
+    Any,
+    AsyncContextManager,
+    AsyncGenerator,
+    AsyncIterator,
     Callable,
+    ContextManager,
     Dict,
     Generator,
-    AsyncGenerator,
     Iterator,
-    AsyncIterator,
-    Optional,
-    Any,
-    Union,
-    List,
-    ContextManager,
-    AsyncContextManager,
+    Optional,
+    Union,
 )
 import httpx  # type: ignore
 import requests  # type: ignore
 import litellm
-from litellm.utils import ModelResponse, Usage, get_secret
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.utils import ModelResponse, Usage, get_secret
 from .base import BaseLLM
 from .prompt_templates import factory as ptf
@@ -287,8 +290,8 @@ class IBMWatsonXAI(BaseLLM):
         )

     def _get_api_params(
-        self, 
-        params: dict, 
+        self,
+        params: dict,
         print_verbose: Optional[Callable] = None,
         generate_token: Optional[bool] = True,
     ) -> dict:
@@ -440,7 +443,7 @@ class IBMWatsonXAI(BaseLLM):
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
-        timeout=None
+        timeout=None,
     ):
         """
         Send a text generation request to the IBM Watsonx.ai API.
@@ -547,7 +550,9 @@ class IBMWatsonXAI(BaseLLM):
         except Exception as e:
             raise WatsonXAIError(status_code=500, message=str(e))

-    def _process_embedding_response(self, json_resp: dict, model_response:Union[ModelResponse,None]=None) -> ModelResponse:
+    def _process_embedding_response(
+        self, json_resp: dict, model_response: Union[ModelResponse, None] = None
+    ) -> ModelResponse:
         if model_response is None:
             model_response = ModelResponse(model=json_resp.get("model_id", None))
         results = json_resp.get("results", [])
@@ -563,10 +568,14 @@ class IBMWatsonXAI(BaseLLM):
         model_response["object"] = "list"
         model_response["data"] = embedding_response
         input_tokens = json_resp.get("input_token_count", 0)
-        model_response.usage = Usage(
-            prompt_tokens=input_tokens,
-            completion_tokens=0,
-            total_tokens=input_tokens,
+        setattr(
+            model_response,
+            "usage",
+            Usage(
+                prompt_tokens=input_tokens,
+                completion_tokens=0,
+                total_tokens=input_tokens,
+            ),
         )
         return model_response
@@ -580,7 +589,7 @@ class IBMWatsonXAI(BaseLLM):
         optional_params=None,
         encoding=None,
         print_verbose=None,
-        aembedding=None
+        aembedding=None,
     ):
         """
         Send a text embedding request to the IBM Watsonx.ai API.
@@ -593,8 +602,8 @@ class IBMWatsonXAI(BaseLLM):
             if k not in optional_params:
                 optional_params[k] = v

-        model_response['model'] = model
+        model_response["model"] = model

         # Load auth variables from environment variables
         if isinstance(input, str):
             input = [input]
@@ -685,6 +694,7 @@ class IBMWatsonXAI(BaseLLM):
             return json_resp
         return [res["model_id"] for res in json_resp["resources"]]

+
 class RequestManager:
     """
     A class to handle sync/async HTTP requests to the IBM Watsonx.ai API.
@@ -744,7 +754,7 @@ class RequestManager:

     @contextmanager
     def request(
-        self, 
+        self,
         request_params: dict,
         stream: bool = False,
         input: Optional[Any] = None,