Merge branch 'main' into feature/watsonx-integration

Simon Sanchez Viloria 2024-05-10 12:09:09 +02:00
commit e1372de9ee
23 changed files with 8026 additions and 271 deletions


@@ -369,7 +369,7 @@ class ChatCompletionMessageToolCall(OpenAIObject):
class Message(OpenAIObject):
def __init__(
self,
content="default",
content: Optional[str] = "default",
role="assistant",
logprobs=None,
function_call=None,
@@ -612,6 +612,7 @@ class ModelResponse(OpenAIObject):
system_fingerprint=None,
usage=None,
stream=None,
stream_options=None,
response_ms=None,
hidden_params=None,
**params,
@@ -658,6 +659,12 @@ class ModelResponse(OpenAIObject):
usage = usage
elif stream is None or stream == False:
usage = Usage()
elif (
stream == True
and stream_options is not None
and stream_options.get("include_usage") == True
):
usage = Usage()
if hidden_params:
self._hidden_params = hidden_params
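The hunk above attaches a Usage object to a non-streaming response, and to a streaming response only when the caller opted in through stream_options={"include_usage": True}. A minimal standalone sketch of that decision (the helper name is illustrative, not part of litellm):

    from typing import Optional

    def should_attach_usage(stream: Optional[bool], stream_options: Optional[dict]) -> bool:
        # non-streaming responses always carry a usage block
        if stream is None or stream is False:
            return True
        # streaming responses carry usage only when explicitly requested
        return bool(stream and stream_options and stream_options.get("include_usage") is True)

    # should_attach_usage(False, None)                    -> True
    # should_attach_usage(True, None)                     -> False
    # should_attach_usage(True, {"include_usage": True})  -> True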
@@ -4839,6 +4846,7 @@ def get_optional_params(
top_p=None,
n=None,
stream=False,
stream_options=None,
stop=None,
max_tokens=None,
presence_penalty=None,
@@ -4908,6 +4916,7 @@ def get_optional_params(
"top_p": None,
"n": None,
"stream": None,
"stream_options": None,
"stop": None,
"max_tokens": None,
"presence_penalty": None,
@@ -5779,6 +5788,8 @@ def get_optional_params(
optional_params["n"] = n
if stream is not None:
optional_params["stream"] = stream
if stream_options is not None:
optional_params["stream_options"] = stream_options
if stop is not None:
optional_params["stop"] = stop
if max_tokens is not None:
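With stream_options threaded through get_optional_params, it is forwarded to the provider like any other OpenAI-compatible parameter. A hedged usage sketch (the model and message are placeholders; assumes an OpenAI-style provider that honors include_usage):

    import litellm

    response = litellm.completion(
        model="gpt-3.5-turbo",  # placeholder model
        messages=[{"role": "user", "content": "Hi"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    for chunk in response:
        print(chunk)  # the final chunk should carry the usage block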
@@ -5927,13 +5938,15 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
model=model, **optional_params
) # convert to pydantic object
except Exception as e:
verbose_logger.error("Error occurred in getting api base - {}".format(str(e)))
verbose_logger.debug("Error occurred in getting api base - {}".format(str(e)))
return None
# get llm provider
if _optional_params.api_base is not None:
return _optional_params.api_base
if litellm.model_alias_map and model in litellm.model_alias_map:
model = litellm.model_alias_map[model]
try:
model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
get_llm_provider(
@@ -6083,6 +6096,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"top_p",
"n",
"stream",
"stream_options",
"stop",
"max_tokens",
"presence_penalty",
@@ -9500,7 +9514,12 @@ def get_secret(
# replicate/anthropic/cohere
class CustomStreamWrapper:
def __init__(
self, completion_stream, model, custom_llm_provider=None, logging_obj=None
self,
completion_stream,
model,
custom_llm_provider=None,
logging_obj=None,
stream_options=None,
):
self.model = model
self.custom_llm_provider = custom_llm_provider
@@ -9526,6 +9545,7 @@ class CustomStreamWrapper:
self.response_id = None
self.logging_loop = None
self.rules = Rules()
self.stream_options = stream_options
def __iter__(self):
return self
@@ -9737,6 +9757,50 @@ class CustomStreamWrapper:
"finish_reason": finish_reason,
}
def handle_predibase_chunk(self, chunk):
try:
if type(chunk) != str:
chunk = chunk.decode(
"utf-8"
) # DO NOT REMOVE this: This is required for HF inference API + Streaming
text = ""
is_finished = False
finish_reason = ""
print_verbose(f"chunk: {chunk}")
if chunk.startswith("data:"):
data_json = json.loads(chunk[5:])
print_verbose(f"data json: {data_json}")
if "token" in data_json and "text" in data_json["token"]:
text = data_json["token"]["text"]
if data_json.get("details", False) and data_json["details"].get(
"finish_reason", False
):
is_finished = True
finish_reason = data_json["details"]["finish_reason"]
elif data_json.get(
"generated_text", False
): # if full generated text exists, then stream is complete
text = "" # don't return the final bos token
is_finished = True
finish_reason = "stop"
elif data_json.get("error", False):
raise Exception(data_json.get("error"))
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
elif "error" in chunk:
raise ValueError(chunk)
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
except Exception as e:
traceback.print_exc()
raise e
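handle_predibase_chunk parses TGI-style server-sent-event lines ("data:{...}"), mirroring the Hugging Face handler below. A hedged illustration of what two such lines would reduce to (the payloads are made up, not captured output):

    # mid-stream token chunk
    chunk = 'data:{"token": {"text": " world"}, "details": null}'
    # handle_predibase_chunk(chunk) ->
    # {"text": " world", "is_finished": False, "finish_reason": ""}

    # final chunk: details.finish_reason is set, so the stream is marked finished
    chunk = 'data:{"token": {"text": ""}, "details": {"finish_reason": "stop"}, "generated_text": "hello world"}'
    # handle_predibase_chunk(chunk) ->
    # {"text": "", "is_finished": True, "finish_reason": "stop"}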
def handle_huggingface_chunk(self, chunk):
try:
if type(chunk) != str:
@@ -9966,6 +10030,7 @@ class CustomStreamWrapper:
is_finished = False
finish_reason = None
logprobs = None
usage = None
original_chunk = None # this is used for function/tool calling
if len(str_line.choices) > 0:
if (
@@ -10000,12 +10065,15 @@ class CustomStreamWrapper:
else:
logprobs = None
usage = getattr(str_line, "usage", None)
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
"logprobs": logprobs,
"original_chunk": str_line,
"usage": usage,
}
except Exception as e:
traceback.print_exc()
@@ -10038,16 +10106,19 @@ class CustomStreamWrapper:
text = ""
is_finished = False
finish_reason = None
usage = None
choices = getattr(chunk, "choices", [])
if len(choices) > 0:
text = choices[0].text
if choices[0].finish_reason is not None:
is_finished = True
finish_reason = choices[0].finish_reason
usage = getattr(chunk, "usage", None)
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
"usage": usage,
}
except Exception as e:
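Both chunk handlers above now surface a usage attribute when the provider sends one. For OpenAI-compatible streams that only happens on the very last chunk, and only when include_usage was requested, which is why usage stays None for every other chunk. An illustrative (made-up) final chunk:

    # shape of the usage-only chunk emitted with include_usage enabled:
    # choices is empty and usage is populated
    last_chunk = {
        "choices": [],
        "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
    }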
@@ -10308,7 +10379,9 @@ class CustomStreamWrapper:
raise e
def model_response_creator(self):
model_response = ModelResponse(stream=True, model=self.model)
model_response = ModelResponse(
stream=True, model=self.model, stream_options=self.stream_options
)
if self.response_id is not None:
model_response.id = self.response_id
else:
@@ -10365,6 +10438,11 @@ class CustomStreamWrapper:
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "predibase":
response_obj = self.handle_predibase_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif (
self.custom_llm_provider and self.custom_llm_provider == "baseten"
): # baseten doesn't provide streaming
@@ -10567,18 +10645,6 @@ class CustomStreamWrapper:
elif self.custom_llm_provider == "watsonx":
response_obj = self.handle_watsonx_stream(chunk)
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if getattr(model_response, "usage", None) is None:
model_response.usage = Usage()
if response_obj.get("prompt_tokens") is not None:
prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
model_response.usage.prompt_tokens = (prompt_token_count+response_obj["prompt_tokens"])
if response_obj.get("completion_tokens") is not None:
model_response.usage.completion_tokens = response_obj["completion_tokens"]
model_response.usage.total_tokens = (
getattr(model_response.usage, "prompt_tokens", 0)
+ getattr(model_response.usage, "completion_tokens", 0)
)
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "text-completion-openai":
@@ -10587,6 +10653,11 @@ class CustomStreamWrapper:
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) == True
):
model_response.usage = response_obj["usage"]
elif self.custom_llm_provider == "azure_text":
response_obj = self.handle_azure_text_completion_chunk(chunk)
completion_obj["content"] = response_obj["text"]
@@ -10640,6 +10711,12 @@ class CustomStreamWrapper:
if response_obj["logprobs"] is not None:
model_response.choices[0].logprobs = response_obj["logprobs"]
if (
self.stream_options is not None
and self.stream_options["include_usage"] == True
):
model_response.usage = response_obj["usage"]
model_response.model = self.model
print_verbose(
f"model_response finish reason 3: {self.received_finish_reason}; response_obj={response_obj}"
@@ -10727,6 +10804,11 @@ class CustomStreamWrapper:
except Exception as e:
model_response.choices[0].delta = Delta()
else:
if (
self.stream_options is not None
and self.stream_options["include_usage"] == True
):
return model_response
return
print_verbose(
f"model_response.choices[0].delta: {model_response.choices[0].delta}; completion_obj: {completion_obj}"
@@ -10983,7 +11065,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "sagemaker"
or self.custom_llm_provider == "gemini"
or self.custom_llm_provider == "cached_response"
or self.custom_llm_provider == "watsonx"
or self.custom_llm_provider == "predibase"
or self.custom_llm_provider in litellm.openai_compatible_endpoints
):
async for chunk in self.completion_stream:
@@ -11106,9 +11188,10 @@ class CustomStreamWrapper:
class TextCompletionStreamWrapper:
def __init__(self, completion_stream, model):
def __init__(self, completion_stream, model, stream_options: Optional[dict] = None):
self.completion_stream = completion_stream
self.model = model
self.stream_options = stream_options
def __iter__(self):
return self
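The hunk below copies the provider's usage block into each converted text-completion chunk when include_usage is set. A hedged consumer sketch, assuming litellm.text_completion forwards stream_options the same way completion does (model and prompt are placeholders):

    import litellm

    stream = litellm.text_completion(
        model="gpt-3.5-turbo-instruct",  # placeholder model
        prompt="Say hello",
        stream=True,
        stream_options={"include_usage": True},
    )
    usage = None
    for chunk in stream:
        # per this diff, only the final chunk carries a usage block
        usage = getattr(chunk, "usage", None) or usage
    print(usage)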
@@ -11132,6 +11215,14 @@ class TextCompletionStreamWrapper:
text_choices["index"] = chunk["choices"][0]["index"]
text_choices["finish_reason"] = chunk["choices"][0]["finish_reason"]
response["choices"] = [text_choices]
# only pass usage when stream_options["include_usage"] is True
if (
self.stream_options
and self.stream_options.get("include_usage", False) == True
):
response["usage"] = chunk.get("usage", None)
return response
except Exception as e:
raise Exception(