forked from phoenix/litellm-mirror

Merge branch 'main' into feature/watsonx-integration

commit e1372de9ee
23 changed files with 8026 additions and 271 deletions

19 .github/workflows/interpret_load_test.py (vendored)
@@ -64,6 +64,11 @@ if __name__ == "__main__":
    )  # Replace with your repository's username and name
    latest_release = repo.get_latest_release()
    print("got latest release: ", latest_release)
    print(latest_release.title)
    print(latest_release.tag_name)

    release_version = latest_release.title

    print("latest release body: ", latest_release.body)
    print("markdown table: ", markdown_table)

@@ -74,8 +79,22 @@ if __name__ == "__main__":
    start_index = latest_release.body.find("Load Test LiteLLM Proxy Results")
    existing_release_body = latest_release.body[:start_index]

    docker_run_command = f"""
\n\n
## Docker Run LiteLLM Proxy

```
docker run \\
-e STORE_MODEL_IN_DB=True \\
-p 4000:4000 \\
ghcr.io/berriai/litellm:main-{release_version}
```
"""
    print("docker run command: ", docker_run_command)

    new_release_body = (
        existing_release_body
        + docker_run_command
        + "\n\n"
        + "### Don't want to maintain your internal proxy? get in touch 🎉"
        + "\nHosted Proxy Alpha: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat"

@@ -16,11 +16,11 @@ repos:
      name: Check if files match
      entry: python3 ci_cd/check_files_match.py
      language: system
# - repo: local
#   hooks:
#     - id: mypy
#       name: mypy
#       entry: python3 -m mypy --ignore-missing-imports
#       language: system
#       types: [python]
#       files: ^litellm/
- repo: local
  hooks:
    - id: mypy
      name: mypy
      entry: python3 -m mypy --ignore-missing-imports
      language: system
      types: [python]
      files: ^litellm/

@@ -83,6 +83,7 @@ def completion(
    top_p: Optional[float] = None,
    n: Optional[int] = None,
    stream: Optional[bool] = None,
    stream_options: Optional[dict] = None,
    stop=None,
    max_tokens: Optional[int] = None,
    presence_penalty: Optional[float] = None,

@@ -139,6 +140,10 @@ def completion(

- `stream`: *boolean or null (optional)* - If set to true, it sends partial message deltas. Tokens will be sent as they become available, with the stream terminated by a [DONE] message.

- `stream_options`: *dict or null (optional)* - Options for the streaming response. Only set this when you set `stream: true`.

  - `include_usage`: *boolean (optional)* - If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.

- `stop`: *string / array / null (optional)* - Up to 4 sequences where the API will stop generating further tokens.

- `max_tokens`: *integer (optional)* - The maximum number of tokens to generate in the chat completion.

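For reference, a minimal sketch of the new `stream_options` parameter in use, mirroring the test added later in this commit (the model name and prompt are illustrative):

```python
import litellm

# Stream a chat completion and request a final usage chunk.
# stream_options is only honored when stream=True.
response = litellm.completion(
    model="gpt-3.5-turbo",  # illustrative model
    messages=[{"role": "user", "content": "say hi"}],
    stream=True,
    stream_options={"include_usage": True},
)

chunks = list(response)
# Every chunk carries a `usage` field; it is None on all chunks except the
# last one, which reports token usage for the entire request.
print(chunks[-1].usage)
```
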
83 docs/my-website/docs/proxy/customer_routing.md (new file)

@@ -0,0 +1,83 @@
# Region-based Routing

Route specific customers to EU-only models.

By specifying 'allowed_model_region' for a customer, LiteLLM will filter out any models in a model group that are not in the allowed region (e.g. 'eu').

[**See Code**](https://github.com/BerriAI/litellm/blob/5eb12e30cc5faa73799ebc7e48fc86ebf449c879/litellm/router.py#L2938)

### 1. Create customer with region-specification

Use the litellm 'end-user' object for this.

End-users can be tracked / identified by passing the 'user' param to litellm in an openai chat completion/embedding call.

```bash
curl -X POST --location 'http://0.0.0.0:4000/end_user/new' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "user_id" : "ishaan-jaff-45",
    "allowed_model_region": "eu" # 👈 SPECIFY ALLOWED REGION='eu'
}'
```

### 2. Add EU models to model-group

Add EU models to a model group. For Azure models, litellm can automatically infer the region (no need to set it).

```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/gpt-35-turbo-eu # 👈 EU azure model
      api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
      api_key: os.environ/AZURE_EUROPE_API_KEY
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/chatgpt-v-2
      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
      api_version: "2023-05-15"
      api_key: os.environ/AZURE_API_KEY

router_settings:
  enable_pre_call_checks: true # 👈 IMPORTANT
```

Start the proxy

```bash
litellm --config /path/to/config.yaml
```

### 3. Test it!

Make a simple chat completions call to the proxy. In the response headers, you should see the returned api base.

```bash
curl -X POST --location 'http://localhost:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data '{
    "model": "gpt-3.5-turbo",
    "messages": [
        {
            "role": "user",
            "content": "what is the meaning of the universe? 1234"
        }],
    "user": "ishaan-jaff-45" # 👈 USER ID
}
'
```

Expected API Base in response headers

```
x-litellm-api-base: "https://my-endpoint-europe-berri-992.openai.azure.com/"
```

### FAQ

**What happens if there are no available models for that region?**

Since the router filters out models that are not in the specified region, the call returns an error to the user if no models in that region are available.

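The same request can also be made through the OpenAI Python SDK pointed at the proxy; a minimal sketch, assuming the proxy address and key from the curl example above:

```python
from openai import OpenAI

# Point the OpenAI SDK at the LiteLLM proxy; base URL and key match the
# curl example above (adjust for your deployment).
client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what is the meaning of the universe? 1234"}],
    user="ishaan-jaff-45",  # 👈 the end-user created in step 1
)
print(response)
# To inspect response headers such as x-litellm-api-base, the SDK's
# raw-response interface (client.chat.completions.with_raw_response.create)
# can be used instead.
```
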
@@ -50,6 +50,7 @@ const sidebars = {
        items: ["proxy/logging", "proxy/streaming_logging"],
      },
      "proxy/team_based_routing",
      "proxy/customer_routing",
      "proxy/ui",
      "proxy/cost_tracking",
      "proxy/token_auth",

@@ -1,3 +1,6 @@
### Hide pydantic namespace conflict warnings globally ###
import warnings

warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
### INIT VARIABLES ###
import threading, requests, os
from typing import Callable, List, Optional, Dict, Union, Any, Literal

@@ -71,9 +74,11 @@ maritalk_key: Optional[str] = None
ai21_key: Optional[str] = None
ollama_key: Optional[str] = None
openrouter_key: Optional[str] = None
predibase_key: Optional[str] = None
huggingface_key: Optional[str] = None
vertex_project: Optional[str] = None
vertex_location: Optional[str] = None
predibase_tenant_id: Optional[str] = None
togetherai_api_key: Optional[str] = None
cloudflare_api_key: Optional[str] = None
baseten_key: Optional[str] = None

@@ -532,6 +537,7 @@ provider_list: List = [
    "xinference",
    "fireworks_ai",
    "watsonx",
    "predibase",
    "custom",  # custom apis
]

@@ -644,6 +650,7 @@ from .utils import (
)
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig
from .llms.predibase import PredibaseConfig
from .llms.anthropic_text import AnthropicTextConfig
from .llms.replicate import ReplicateConfig
from .llms.cohere import CohereConfig

@ -322,9 +322,9 @@ class Huggingface(BaseLLM):
|
|||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
custom_prompt_dict={},
|
||||
acompletion: bool = False,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
|
@ -399,10 +399,11 @@ class Huggingface(BaseLLM):
|
|||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": optional_params,
|
||||
"stream": (
|
||||
"stream": ( # type: ignore
|
||||
True
|
||||
if "stream" in optional_params
|
||||
and optional_params["stream"] == True
|
||||
and isinstance(optional_params["stream"], bool)
|
||||
and optional_params["stream"] == True # type: ignore
|
||||
else False
|
||||
),
|
||||
}
|
||||
|
@ -433,7 +434,7 @@ class Huggingface(BaseLLM):
|
|||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": inference_params,
|
||||
"stream": (
|
||||
"stream": ( # type: ignore
|
||||
True
|
||||
if "stream" in optional_params
|
||||
and optional_params["stream"] == True
|
||||
|
|
|
@ -530,6 +530,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
model=model,
|
||||
custom_llm_provider="openai",
|
||||
logging_obj=logging_obj,
|
||||
stream_options=data.get("stream_options", None),
|
||||
)
|
||||
return streamwrapper
|
||||
|
||||
|
@ -579,6 +580,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
model=model,
|
||||
custom_llm_provider="openai",
|
||||
logging_obj=logging_obj,
|
||||
stream_options=data.get("stream_options", None),
|
||||
)
|
||||
return streamwrapper
|
||||
except (
|
||||
|
@ -1203,6 +1205,7 @@ class OpenAITextCompletion(BaseLLM):
|
|||
model=model,
|
||||
custom_llm_provider="text-completion-openai",
|
||||
logging_obj=logging_obj,
|
||||
stream_options=data.get("stream_options", None),
|
||||
)
|
||||
|
||||
for chunk in streamwrapper:
|
||||
|
@ -1241,6 +1244,7 @@ class OpenAITextCompletion(BaseLLM):
|
|||
model=model,
|
||||
custom_llm_provider="text-completion-openai",
|
||||
logging_obj=logging_obj,
|
||||
stream_options=data.get("stream_options", None),
|
||||
)
|
||||
|
||||
async for transformed_chunk in streamwrapper:
|
||||
|
|
520 litellm/llms/predibase.py (new file)
|
@ -0,0 +1,520 @@
|
|||
# What is this?
|
||||
## Controller file for Predibase Integration - https://predibase.com/
|
||||
|
||||
|
||||
import os, types
|
||||
import json
|
||||
from enum import Enum
|
||||
import requests, copy # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional, List, Literal, Union
|
||||
from litellm.utils import (
|
||||
ModelResponse,
|
||||
Usage,
|
||||
map_finish_reason,
|
||||
CustomStreamWrapper,
|
||||
Message,
|
||||
Choices,
|
||||
)
|
||||
import litellm
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from .base import BaseLLM
|
||||
import httpx # type: ignore
|
||||
|
||||
|
||||
class PredibaseError(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
status_code,
|
||||
message,
|
||||
request: Optional[httpx.Request] = None,
|
||||
response: Optional[httpx.Response] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
if request is not None:
|
||||
self.request = request
|
||||
else:
|
||||
self.request = httpx.Request(
|
||||
method="POST",
|
||||
url="https://docs.predibase.com/user-guide/inference/rest_api",
|
||||
)
|
||||
if response is not None:
|
||||
self.response = response
|
||||
else:
|
||||
self.response = httpx.Response(
|
||||
status_code=status_code, request=self.request
|
||||
)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class PredibaseConfig:
|
||||
"""
|
||||
Reference: https://docs.predibase.com/user-guide/inference/rest_api
|
||||
|
||||
"""
|
||||
|
||||
adapter_id: Optional[str] = None
|
||||
adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
|
||||
best_of: Optional[int] = None
|
||||
decoder_input_details: Optional[bool] = None
|
||||
details: bool = True # enables returning logprobs + best of
|
||||
max_new_tokens: int = (
|
||||
256 # openai default - requests hang if max_new_tokens not given
|
||||
)
|
||||
repetition_penalty: Optional[float] = None
|
||||
return_full_text: Optional[bool] = (
|
||||
False # by default don't return the input as part of the output
|
||||
)
|
||||
seed: Optional[int] = None
|
||||
stop: Optional[List[str]] = None
|
||||
temperature: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
truncate: Optional[int] = None
|
||||
typical_p: Optional[float] = None
|
||||
watermark: Optional[bool] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
best_of: Optional[int] = None,
|
||||
decoder_input_details: Optional[bool] = None,
|
||||
details: Optional[bool] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
return_full_text: Optional[bool] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
truncate: Optional[int] = None,
|
||||
typical_p: Optional[float] = None,
|
||||
watermark: Optional[bool] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
|
||||
|
||||
|
||||
class PredibaseChatCompletion(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict:
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
|
||||
)
|
||||
headers = {
|
||||
"content-type": "application/json",
|
||||
"Authorization": "Bearer {}".format(api_key),
|
||||
}
|
||||
if user_headers is not None and isinstance(user_headers, dict):
|
||||
headers = {**headers, **user_headers}
|
||||
return headers
|
||||
|
||||
def output_parser(self, generated_text: str):
|
||||
"""
|
||||
Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
|
||||
|
||||
Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
|
||||
"""
|
||||
chat_template_tokens = [
|
||||
"<|assistant|>",
|
||||
"<|system|>",
|
||||
"<|user|>",
|
||||
"<s>",
|
||||
"</s>",
|
||||
]
|
||||
for token in chat_template_tokens:
|
||||
if generated_text.strip().startswith(token):
|
||||
generated_text = generated_text.replace(token, "", 1)
|
||||
if generated_text.endswith(token):
|
||||
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
|
||||
return generated_text
|
||||
|
||||
def process_response(
|
||||
self,
|
||||
model: str,
|
||||
response: Union[requests.Response, httpx.Response],
|
||||
model_response: ModelResponse,
|
||||
stream: bool,
|
||||
logging_obj: litellm.utils.Logging,
|
||||
optional_params: dict,
|
||||
api_key: str,
|
||||
data: dict,
|
||||
messages: list,
|
||||
print_verbose,
|
||||
encoding,
|
||||
) -> ModelResponse:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = response.json()
|
||||
except:
|
||||
raise PredibaseError(
|
||||
message=response.text, status_code=response.status_code
|
||||
)
|
||||
if "error" in completion_response:
|
||||
raise PredibaseError(
|
||||
message=str(completion_response["error"]),
|
||||
status_code=response.status_code,
|
||||
)
|
||||
else:
|
||||
if (
|
||||
not isinstance(completion_response, dict)
|
||||
or "generated_text" not in completion_response
|
||||
):
|
||||
raise PredibaseError(
|
||||
status_code=422,
|
||||
message=f"response is not in expected format - {completion_response}",
|
||||
)
|
||||
|
||||
if len(completion_response["generated_text"]) > 0:
|
||||
model_response["choices"][0]["message"]["content"] = self.output_parser(
|
||||
completion_response["generated_text"]
|
||||
)
|
||||
## GETTING LOGPROBS + FINISH REASON
|
||||
if (
|
||||
"details" in completion_response
|
||||
and "tokens" in completion_response["details"]
|
||||
):
|
||||
model_response.choices[0].finish_reason = completion_response[
|
||||
"details"
|
||||
]["finish_reason"]
|
||||
sum_logprob = 0
|
||||
for token in completion_response["details"]["tokens"]:
|
||||
if token["logprob"] != None:
|
||||
sum_logprob += token["logprob"]
|
||||
model_response["choices"][0][
|
||||
"message"
|
||||
]._logprob = (
|
||||
sum_logprob # [TODO] move this to using the actual logprobs
|
||||
)
|
||||
if "best_of" in optional_params and optional_params["best_of"] > 1:
|
||||
if (
|
||||
"details" in completion_response
|
||||
and "best_of_sequences" in completion_response["details"]
|
||||
):
|
||||
choices_list = []
|
||||
for idx, item in enumerate(
|
||||
completion_response["details"]["best_of_sequences"]
|
||||
):
|
||||
sum_logprob = 0
|
||||
for token in item["tokens"]:
|
||||
if token["logprob"] != None:
|
||||
sum_logprob += token["logprob"]
|
||||
if len(item["generated_text"]) > 0:
|
||||
message_obj = Message(
|
||||
content=self.output_parser(item["generated_text"]),
|
||||
logprobs=sum_logprob,
|
||||
)
|
||||
else:
|
||||
message_obj = Message(content=None)
|
||||
choice_obj = Choices(
|
||||
finish_reason=item["finish_reason"],
|
||||
index=idx + 1,
|
||||
message=message_obj,
|
||||
)
|
||||
choices_list.append(choice_obj)
|
||||
model_response["choices"].extend(choices_list)
|
||||
|
||||
## CALCULATING USAGE
|
||||
prompt_tokens = 0
|
||||
try:
|
||||
prompt_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"]["content"])
|
||||
) ##[TODO] use a model-specific tokenizer here
|
||||
except:
|
||||
# this should remain non blocking we should not block a response returning if calculating usage fails
|
||||
pass
|
||||
output_text = model_response["choices"][0]["message"].get("content", "")
|
||||
if output_text is not None and len(output_text) > 0:
|
||||
completion_tokens = 0
|
||||
try:
|
||||
completion_tokens = len(
|
||||
encoding.encode(
|
||||
model_response["choices"][0]["message"].get("content", "")
|
||||
)
|
||||
) ##[TODO] use a model-specific tokenizer
|
||||
except:
|
||||
# this should remain non blocking we should not block a response returning if calculating usage fails
|
||||
pass
|
||||
else:
|
||||
completion_tokens = 0
|
||||
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
model_response.usage = usage # type: ignore
|
||||
return model_response
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key: str,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
tenant_id: str,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers: dict = {},
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
headers = self._validate_environment(api_key, headers)
|
||||
completion_url = ""
|
||||
input_text = ""
|
||||
base_url = "https://serving.app.predibase.com"
|
||||
if "https" in model:
|
||||
completion_url = model
|
||||
elif api_base:
|
||||
base_url = api_base
|
||||
elif "PREDIBASE_API_BASE" in os.environ:
|
||||
base_url = os.getenv("PREDIBASE_API_BASE", "")
|
||||
|
||||
completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"
|
||||
|
||||
if optional_params.get("stream", False) == True:
|
||||
completion_url += "/generate_stream"
|
||||
else:
|
||||
completion_url += "/generate"
|
||||
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages)
|
||||
|
||||
## Load Config
|
||||
config = litellm.PredibaseConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
stream = optional_params.pop("stream", False)
|
||||
|
||||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": optional_params,
|
||||
}
|
||||
input_text = prompt
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input_text,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"headers": headers,
|
||||
"api_base": completion_url,
|
||||
"acompletion": acompletion,
|
||||
},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
if acompletion is True:
|
||||
### ASYNC STREAMING
|
||||
if stream == True:
|
||||
return self.async_streaming(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=completion_url,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
) # type: ignore
|
||||
else:
|
||||
### ASYNC COMPLETION
|
||||
return self.async_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=completion_url,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=False,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
) # type: ignore
|
||||
|
||||
### SYNC STREAMING
|
||||
if stream == True:
|
||||
response = requests.post(
|
||||
completion_url,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=stream,
|
||||
)
|
||||
_response = CustomStreamWrapper(
|
||||
response.iter_lines(),
|
||||
model,
|
||||
custom_llm_provider="predibase",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return _response
|
||||
### SYNC COMPLETION
|
||||
else:
|
||||
response = requests.post(
|
||||
url=completion_url,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
)
|
||||
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=optional_params.get("stream", False),
|
||||
logging_obj=logging_obj, # type: ignore
|
||||
optional_params=optional_params,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
async def async_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
stream,
|
||||
data: dict,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
) -> ModelResponse:
|
||||
self.async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
)
|
||||
response = await self.async_handler.post(
|
||||
api_base, headers=headers, data=json.dumps(data)
|
||||
)
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
async def async_streaming(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
data: dict,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
) -> CustomStreamWrapper:
|
||||
self.async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
)
|
||||
data["stream"] = True
|
||||
response = await self.async_handler.post(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=True,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise PredibaseError(
|
||||
status_code=response.status_code, message=response.text
|
||||
)
|
||||
|
||||
completion_stream = response.aiter_lines()
|
||||
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="predibase",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streamwrapper
|
||||
|
||||
def embedding(self, *args, **kwargs):
|
||||
pass
|
|
@ -487,7 +487,7 @@ def format_prompt_togetherai(messages, prompt_format, chat_template):
|
|||
|
||||
def ibm_granite_pt(messages: list):
|
||||
"""
|
||||
IBM's Granite chat models uses the template:
|
||||
IBM's Granite models uses the template:
|
||||
<|system|> {system_message} <|user|> {user_message} <|assistant|> {assistant_message}
|
||||
|
||||
See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models
|
||||
|
@ -503,13 +503,12 @@ def ibm_granite_pt(messages: list):
|
|||
"pre_message": "<|user|>\n",
|
||||
"post_message": "\n",
|
||||
},
|
||||
'assistant': {
|
||||
'pre_message': '<|assistant|>\n',
|
||||
'post_message': '\n',
|
||||
"assistant": {
|
||||
"pre_message": "<|assistant|>\n",
|
||||
"post_message": "\n",
|
||||
},
|
||||
},
|
||||
final_prompt_value='<|assistant|>\n',
|
||||
)
|
||||
).strip()
|
||||
|
||||
|
||||
### ANTHROPIC ###
|
||||
|
@ -1525,9 +1524,24 @@ def prompt_factory(
|
|||
return mistral_instruct_pt(messages=messages)
|
||||
elif "meta-llama/llama-3" in model and "instruct" in model:
|
||||
# https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
|
||||
return hf_chat_template(
|
||||
model="meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
return custom_prompt(
|
||||
role_dict={
|
||||
"system": {
|
||||
"pre_message": "<|start_header_id|>system<|end_header_id|>\n",
|
||||
"post_message": "<|eot_id|>",
|
||||
},
|
||||
"user": {
|
||||
"pre_message": "<|start_header_id|>user<|end_header_id|>\n",
|
||||
"post_message": "<|eot_id|>",
|
||||
},
|
||||
"assistant": {
|
||||
"pre_message": "<|start_header_id|>assistant<|end_header_id|>\n",
|
||||
"post_message": "<|eot_id|>",
|
||||
},
|
||||
},
|
||||
messages=messages,
|
||||
initial_prompt_value="<|begin_of_text|>",
|
||||
final_prompt_value="<|start_header_id|>assistant<|end_header_id|>\n",
|
||||
)
|
||||
try:
|
||||
if "meta-llama/llama-2" in model and "chat" in model:
|
||||
|
|
|
@ -451,9 +451,6 @@ class IBMWatsonXAI(BaseLLM):
|
|||
return streamwrapper
|
||||
|
||||
# create the function to manage the request to watsonx.ai
|
||||
# manage_request = self._make_request_manager(
|
||||
# async_=(acompletion is True), logging_obj=logging_obj
|
||||
# )
|
||||
self.request_manager = RequestManager(logging_obj)
|
||||
|
||||
def handle_text_request(request_params: dict) -> ModelResponse:
|
||||
|
@ -576,9 +573,6 @@ class IBMWatsonXAI(BaseLLM):
|
|||
"json": payload,
|
||||
"params": request_params,
|
||||
}
|
||||
# manage_request = self._make_request_manager(
|
||||
# async_=(aembedding is True), logging_obj=logging_obj
|
||||
# )
|
||||
request_manager = RequestManager(logging_obj)
|
||||
|
||||
def process_embedding_response(json_resp: dict) -> ModelResponse:
|
||||
|
@ -654,143 +648,12 @@ class IBMWatsonXAI(BaseLLM):
|
|||
request_params = dict(version=api_params["api_version"])
|
||||
url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.AVAILABLE_MODELS
|
||||
req_params = dict(method="GET", url=url, headers=headers, params=request_params)
|
||||
# manage_request = self._make_request_manager(async_=False, logging_obj=None)
|
||||
with RequestManager(logging_obj=None).request(req_params) as resp:
|
||||
json_resp = resp.json()
|
||||
if not ids_only:
|
||||
return json_resp
|
||||
return [res["model_id"] for res in json_resp["resources"]]
|
||||
|
||||
def _make_request_manager(
|
||||
self, async_: bool, logging_obj=None
|
||||
) -> Callable[
|
||||
...,
|
||||
Union[ContextManager[requests.Response], AsyncContextManager[httpx.Response]],
|
||||
]:
|
||||
"""
|
||||
Returns a context manager that manages the response from the request.
|
||||
if async_ is True, returns an async context manager, otherwise returns a regular context manager.
|
||||
|
||||
Usage:
|
||||
```python
|
||||
manage_request = self._make_request_manager(async_=True, logging_obj=logging_obj)
|
||||
async with manage_request(request_params) as resp:
|
||||
...
|
||||
# or
|
||||
manage_request = self._make_request_manager(async_=False, logging_obj=logging_obj)
|
||||
with manage_request(request_params) as resp:
|
||||
...
|
||||
```
|
||||
"""
|
||||
|
||||
def pre_call(
|
||||
request_params: dict,
|
||||
input: Optional[Any] = None,
|
||||
):
|
||||
if logging_obj is None:
|
||||
return
|
||||
request_str = (
|
||||
f"response = {'await ' if async_ else ''}{request_params['method']}(\n"
|
||||
f"\turl={request_params['url']},\n"
|
||||
f"\tjson={request_params.get('json')},\n"
|
||||
f")"
|
||||
)
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=request_params["headers"].get("Authorization"),
|
||||
additional_args={
|
||||
"complete_input_dict": request_params.get("json"),
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
|
||||
def post_call(resp, request_params):
|
||||
if logging_obj is None:
|
||||
return
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=request_params["headers"].get("Authorization"),
|
||||
original_response=json.dumps(resp.json()),
|
||||
additional_args={
|
||||
"status_code": resp.status_code,
|
||||
"complete_input_dict": request_params.get(
|
||||
"data", request_params.get("json")
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def _manage_request(
|
||||
request_params: dict,
|
||||
stream: bool = False,
|
||||
input: Optional[Any] = None,
|
||||
timeout=None,
|
||||
) -> Generator[requests.Response, None, None]:
|
||||
"""
|
||||
Returns a context manager that yields the response from the request.
|
||||
"""
|
||||
pre_call(request_params, input)
|
||||
if timeout:
|
||||
request_params["timeout"] = timeout
|
||||
if stream:
|
||||
request_params["stream"] = stream
|
||||
try:
|
||||
resp = requests.request(**request_params)
|
||||
if not resp.ok:
|
||||
raise WatsonXAIError(
|
||||
status_code=resp.status_code,
|
||||
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
|
||||
)
|
||||
yield resp
|
||||
except Exception as e:
|
||||
raise WatsonXAIError(status_code=500, message=str(e))
|
||||
if not stream:
|
||||
post_call(resp, request_params)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _manage_request_async(
|
||||
request_params: dict,
|
||||
stream: bool = False,
|
||||
input: Optional[Any] = None,
|
||||
timeout=None,
|
||||
) -> AsyncGenerator[httpx.Response, None]:
|
||||
pre_call(request_params, input)
|
||||
if timeout:
|
||||
request_params["timeout"] = timeout
|
||||
if stream:
|
||||
request_params["stream"] = stream
|
||||
try:
|
||||
# async with AsyncHTTPHandler(timeout=timeout) as client:
|
||||
self.async_handler = AsyncHTTPHandler(
|
||||
timeout=httpx.Timeout(
|
||||
timeout=request_params.pop("timeout", 600.0), connect=5.0
|
||||
),
|
||||
)
|
||||
# async_handler.client.verify = False
|
||||
if "json" in request_params:
|
||||
request_params["data"] = json.dumps(request_params.pop("json", {}))
|
||||
method = request_params.pop("method")
|
||||
if method.upper() == "POST":
|
||||
resp = await self.async_handler.post(**request_params)
|
||||
else:
|
||||
resp = await self.async_handler.get(**request_params)
|
||||
if resp.status_code not in [200, 201]:
|
||||
raise WatsonXAIError(
|
||||
status_code=resp.status_code,
|
||||
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
|
||||
)
|
||||
yield resp
|
||||
# await async_handler.close()
|
||||
except Exception as e:
|
||||
raise WatsonXAIError(status_code=500, message=str(e))
|
||||
if not stream:
|
||||
post_call(resp, request_params)
|
||||
|
||||
if async_:
|
||||
return _manage_request_async
|
||||
else:
|
||||
return _manage_request
|
||||
|
||||
class RequestManager:
|
||||
"""
|
||||
Returns a context manager that manages the response from the request.
|
||||
|
|
|
@ -14,6 +14,7 @@ import dotenv, traceback, random, asyncio, time, contextvars
|
|||
from copy import deepcopy
|
||||
import httpx
|
||||
import litellm
|
||||
|
||||
from ._logging import verbose_logger
|
||||
from litellm import ( # type: ignore
|
||||
client,
|
||||
|
@ -73,7 +74,7 @@ from .llms.azure_text import AzureTextCompletion
|
|||
from .llms.anthropic import AnthropicChatCompletion
|
||||
from .llms.anthropic_text import AnthropicTextCompletion
|
||||
from .llms.huggingface_restapi import Huggingface
|
||||
from .llms.watsonx import IBMWatsonXAI
|
||||
from .llms.predibase import PredibaseChatCompletion
|
||||
from .llms.prompt_templates.factory import (
|
||||
prompt_factory,
|
||||
custom_prompt,
|
||||
|
@ -110,7 +111,7 @@ anthropic_text_completions = AnthropicTextCompletion()
|
|||
azure_chat_completions = AzureChatCompletion()
|
||||
azure_text_completions = AzureTextCompletion()
|
||||
huggingface = Huggingface()
|
||||
watsonxai = IBMWatsonXAI()
|
||||
predibase_chat_completions = PredibaseChatCompletion()
|
||||
####### COMPLETION ENDPOINTS ################
|
||||
|
||||
|
||||
|
@ -189,6 +190,7 @@ async def acompletion(
|
|||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[dict] = None,
|
||||
stop=None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
|
@ -208,6 +210,7 @@ async def acompletion(
|
|||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
model_list: Optional[list] = None, # pass in a list of api_base,keys, etc.
|
||||
extra_headers: Optional[dict] = None,
|
||||
# Optional liteLLM function params
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -225,6 +228,7 @@ async def acompletion(
|
|||
top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
|
||||
n (int, optional): The number of completions to generate (default is 1).
|
||||
stream (bool, optional): If True, return a streaming response (default is False).
|
||||
stream_options (dict, optional): A dictionary containing options for the streaming response. Only use this if stream is True.
|
||||
stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
|
||||
max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
|
||||
presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
|
||||
|
@ -262,6 +266,7 @@ async def acompletion(
|
|||
"top_p": top_p,
|
||||
"n": n,
|
||||
"stream": stream,
|
||||
"stream_options": stream_options,
|
||||
"stop": stop,
|
||||
"max_tokens": max_tokens,
|
||||
"presence_penalty": presence_penalty,
|
||||
|
@ -315,7 +320,7 @@ async def acompletion(
|
|||
or custom_llm_provider == "gemini"
|
||||
or custom_llm_provider == "sagemaker"
|
||||
or custom_llm_provider == "anthropic"
|
||||
or custom_llm_provider == "watsonx"
|
||||
or custom_llm_provider == "predibase"
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
@ -460,6 +465,7 @@ def completion(
|
|||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[dict] = None,
|
||||
stop=None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
|
@ -499,6 +505,7 @@ def completion(
|
|||
top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
|
||||
n (int, optional): The number of completions to generate (default is 1).
|
||||
stream (bool, optional): If True, return a streaming response (default is False).
|
||||
stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true.
|
||||
stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
|
||||
max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
|
||||
presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
|
||||
|
@ -576,6 +583,7 @@ def completion(
|
|||
"top_p",
|
||||
"n",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"stop",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
|
@ -788,6 +796,7 @@ def completion(
|
|||
top_p=top_p,
|
||||
n=n,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
stop=stop,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
|
@ -1779,6 +1788,52 @@ def completion(
|
|||
)
|
||||
return response
|
||||
response = model_response
|
||||
elif custom_llm_provider == "predibase":
|
||||
tenant_id = (
|
||||
optional_params.pop("tenant_id", None)
|
||||
or optional_params.pop("predibase_tenant_id", None)
|
||||
or litellm.predibase_tenant_id
|
||||
or get_secret("PREDIBASE_TENANT_ID")
|
||||
)
|
||||
|
||||
api_base = (
|
||||
optional_params.pop("api_base", None)
|
||||
or optional_params.pop("base_url", None)
|
||||
or litellm.api_base
|
||||
or get_secret("PREDIBASE_API_BASE")
|
||||
)
|
||||
|
||||
api_key = (
|
||||
api_key
|
||||
or litellm.api_key
|
||||
or litellm.predibase_key
|
||||
or get_secret("PREDIBASE_API_KEY")
|
||||
)
|
||||
|
||||
_model_response = predibase_chat_completions.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
acompletion=acompletion,
|
||||
api_base=api_base,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
api_key=api_key,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
if (
|
||||
"stream" in optional_params
|
||||
and optional_params["stream"] == True
|
||||
and acompletion == False
|
||||
):
|
||||
return _model_response
|
||||
response = _model_response
|
||||
elif custom_llm_provider == "ai21":
|
||||
custom_llm_provider = "ai21"
|
||||
ai21_key = (
|
||||
|
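A usage sketch for the new Predibase branch above, mirroring the test added later in this commit (the tenant id and model name are the test's illustrative values):

```python
import os
import litellm

# Call the new Predibase provider. tenant_id / api_base / api_key can be passed
# directly, or picked up from litellm.* globals and PREDIBASE_* env vars,
# following the resolution order shown in the branch above.
response = litellm.completion(
    model="predibase/llama-3-8b-instruct",
    tenant_id="c4768f95",  # illustrative tenant id from this commit's tests
    api_base="https://serving.app.predibase.com",
    api_key=os.getenv("PREDIBASE_API_KEY"),
    messages=[{"role": "user", "content": "What is the meaning of life?"}],
)
print(response)
```
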
@ -1911,7 +1966,7 @@ def completion(
|
|||
response = response
|
||||
elif custom_llm_provider == "watsonx":
|
||||
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
||||
response = watsonxai.completion(
|
||||
response = watsonx.IBMWatsonXAI().completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
|
@ -1922,8 +1977,7 @@ def completion(
|
|||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
acompletion=acompletion,
|
||||
timeout=timeout,
|
||||
timeout=timeout, # type: ignore
|
||||
)
|
||||
if (
|
||||
"stream" in optional_params
|
||||
|
@ -2576,7 +2630,6 @@ async def aembedding(*args, **kwargs):
|
|||
or custom_llm_provider == "fireworks_ai"
|
||||
or custom_llm_provider == "ollama"
|
||||
or custom_llm_provider == "vertex_ai"
|
||||
or custom_llm_provider == "watsonx"
|
||||
): # currently implemented aiohttp calls for just azure and openai, soon all.
|
||||
# Await normally
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
@ -3034,14 +3087,13 @@ def embedding(
|
|||
aembedding=aembedding,
|
||||
)
|
||||
elif custom_llm_provider == "watsonx":
|
||||
response = watsonxai.embedding(
|
||||
response = watsonx.IBMWatsonXAI().embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
model_response=EmbeddingResponse(),
|
||||
aembedding=aembedding,
|
||||
)
|
||||
else:
|
||||
args = locals()
|
||||
|
@ -3197,6 +3249,7 @@ def text_completion(
|
|||
Union[str, List[str]]
|
||||
] = None, # Optional: Sequences where the API will stop generating further tokens.
|
||||
stream: Optional[bool] = None, # Optional: Whether to stream back partial progress.
|
||||
stream_options: Optional[dict] = None,
|
||||
suffix: Optional[
|
||||
str
|
||||
] = None, # Optional: The suffix that comes after a completion of inserted text.
|
||||
|
@ -3274,6 +3327,8 @@ def text_completion(
|
|||
optional_params["stop"] = stop
|
||||
if stream is not None:
|
||||
optional_params["stream"] = stream
|
||||
if stream_options is not None:
|
||||
optional_params["stream_options"] = stream_options
|
||||
if suffix is not None:
|
||||
optional_params["suffix"] = suffix
|
||||
if temperature is not None:
|
||||
|
@ -3384,7 +3439,9 @@ def text_completion(
|
|||
if kwargs.get("acompletion", False) == True:
|
||||
return response
|
||||
if stream == True or kwargs.get("stream", False) == True:
|
||||
response = TextCompletionStreamWrapper(completion_stream=response, model=model)
|
||||
response = TextCompletionStreamWrapper(
|
||||
completion_stream=response, model=model, stream_options=stream_options
|
||||
)
|
||||
return response
|
||||
transformed_logprobs = None
|
||||
# only supported for TGI models
|
||||
|
|
|
@ -206,11 +206,9 @@ async def get_end_user_object(
|
|||
|
||||
if end_user_id is None:
|
||||
return None
|
||||
|
||||
_key = "end_user_id:{}".format(end_user_id)
|
||||
# check if in cache
|
||||
cached_user_obj = user_api_key_cache.async_get_cache(
|
||||
key="end_user_id:{}".format(end_user_id)
|
||||
)
|
||||
cached_user_obj = await user_api_key_cache.async_get_cache(key=_key)
|
||||
if cached_user_obj is not None:
|
||||
if isinstance(cached_user_obj, dict):
|
||||
return LiteLLM_EndUserTable(**cached_user_obj)
|
||||
|
|
|
@ -1086,9 +1086,7 @@ async def user_api_key_auth(
|
|||
user_id_information, list
|
||||
):
|
||||
_user = user_id_information[0]
|
||||
user_role = _user.get("user_role", {}).get(
|
||||
"user_role", "unknown"
|
||||
)
|
||||
user_role = _user.get("user_role", "unknown")
|
||||
user_id = _user.get("user_id", "unknown")
|
||||
raise Exception(
|
||||
f"Only proxy admin can be used to generate, delete, update info for new keys/users/teams. Route={route}. Your role={user_role}. Your user_id={user_id}"
|
||||
|
@ -1834,6 +1832,9 @@ async def update_cache(
|
|||
)
|
||||
|
||||
async def _update_end_user_cache():
|
||||
if end_user_id is None or response_cost is None:
|
||||
return
|
||||
|
||||
_id = "end_user_id:{}".format(end_user_id)
|
||||
try:
|
||||
# Fetch the existing cost for the given user
|
||||
|
@ -1846,7 +1847,7 @@ async def update_cache(
|
|||
if litellm.max_end_user_budget is not None:
|
||||
max_end_user_budget = litellm.max_end_user_budget
|
||||
existing_spend_obj = LiteLLM_EndUserTable(
|
||||
user_id=_id,
|
||||
user_id=end_user_id,
|
||||
spend=0,
|
||||
blocked=False,
|
||||
litellm_budget_table=LiteLLM_BudgetTable(
|
||||
|
@ -1874,7 +1875,7 @@ async def update_cache(
|
|||
existing_spend_obj.spend = new_spend
|
||||
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj.json())
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
verbose_proxy_logger.error(
|
||||
f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
|
@ -7310,6 +7311,43 @@ async def unblock_team(
|
|||
return record
|
||||
|
||||
|
||||
@router.get(
|
||||
"/team/list", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
|
||||
)
|
||||
async def list_team(
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
[Admin-only] List all available teams
|
||||
|
||||
```
|
||||
curl --location --request GET 'http://0.0.0.0:4000/team/list' \
|
||||
--header 'Authorization: Bearer sk-1234'
|
||||
```
|
||||
"""
|
||||
global prisma_client
|
||||
|
||||
if user_api_key_dict.user_role != "proxy_admin":
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": "Admin-only endpoint. Your user role={}".format(
|
||||
user_api_key_dict.user_role
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||
)
|
||||
|
||||
response = await prisma_client.db.litellm_teamtable.find_many()
|
||||
|
||||
return response
|
||||
|
||||
|
||||
#### ORGANIZATION MANAGEMENT ####
|
||||
|
||||
|
||||
|
|
File diff suppressed because it is too large
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
from litellm import acompletion
|
||||
from litellm import completion
|
||||
|
||||
|
||||
def test_acompletion_params():
|
||||
|
@ -7,17 +8,29 @@ def test_acompletion_params():
|
|||
from litellm.types.completion import CompletionRequest
|
||||
|
||||
acompletion_params_odict = inspect.signature(acompletion).parameters
|
||||
acompletion_params = {name: param.annotation for name, param in acompletion_params_odict.items()}
|
||||
completion_params = {field_name: field_type for field_name, field_type in CompletionRequest.__annotations__.items()}
|
||||
completion_params_dict = inspect.signature(completion).parameters
|
||||
|
||||
# remove kwargs
|
||||
acompletion_params.pop("kwargs", None)
|
||||
acompletion_params = {
|
||||
name: param.annotation for name, param in acompletion_params_odict.items()
|
||||
}
|
||||
completion_params = {
|
||||
name: param.annotation for name, param in completion_params_dict.items()
|
||||
}
|
||||
|
||||
keys_acompletion = set(acompletion_params.keys())
|
||||
keys_completion = set(completion_params.keys())
|
||||
|
||||
print(keys_acompletion)
|
||||
print("\n\n\n")
|
||||
print(keys_completion)
|
||||
|
||||
print("diff=", keys_completion - keys_acompletion)
|
||||
|
||||
# Assert that the parameters are the same
|
||||
if keys_acompletion != keys_completion:
|
||||
pytest.fail("The parameters of the acompletion function and the CompletionRequest class are not the same.")
|
||||
pytest.fail(
|
||||
"The parameters of the litellm.acompletion function and litellm.completion are not the same."
|
||||
)
|
||||
|
||||
|
||||
# test_acompletion_params()
|
||||
|
|
|
@ -85,6 +85,42 @@ def test_completion_azure_command_r():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# @pytest.mark.skip(reason="local test")
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_predibase(sync_mode):
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
|
||||
if sync_mode:
|
||||
response = completion(
|
||||
model="predibase/llama-3-8b-instruct",
|
||||
tenant_id="c4768f95",
|
||||
api_base="https://serving.app.predibase.com",
|
||||
api_key=os.getenv("PREDIBASE_API_KEY"),
|
||||
messages=[{"role": "user", "content": "What is the meaning of life?"}],
|
||||
)
|
||||
|
||||
print(response)
|
||||
else:
|
||||
response = await litellm.acompletion(
|
||||
model="predibase/llama-3-8b-instruct",
|
||||
tenant_id="c4768f95",
|
||||
api_base="https://serving.app.predibase.com",
|
||||
api_key=os.getenv("PREDIBASE_API_KEY"),
|
||||
messages=[{"role": "user", "content": "What is the meaning of life?"}],
|
||||
)
|
||||
|
||||
print(response)
|
||||
except litellm.Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_completion_predibase()
|
||||
|
||||
|
||||
def test_completion_claude():
|
||||
litellm.set_verbose = True
|
||||
litellm.cache = None
|
||||
|
|
|
@ -418,9 +418,16 @@ def test_call_with_user_over_budget(prisma_client):
|
|||
print(vars(e))
|
||||
|
||||
|
||||
def test_end_user_cache_write_unit_test():
|
||||
"""
|
||||
assert end user object is being written to cache as expected
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def test_call_with_end_user_over_budget(prisma_client):
|
||||
# Test if a user passed to /chat/completions is tracked & fails when they cross their budget
|
||||
# we only check this when litellm.max_user_budget is set
|
||||
# we only check this when litellm.max_end_user_budget is set
|
||||
import random
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
|
|
|
@ -150,9 +150,9 @@ async def test_router_atext_completion_streaming():
|
|||
{
|
||||
"model_name": "azure-model",
|
||||
"litellm_params": {
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_key": "os.environ/AZURE_EUROPE_API_KEY",
|
||||
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
|
||||
"model": "azure/gpt-turbo",
|
||||
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
|
||||
"api_base": "https://openai-france-1234.openai.azure.com",
|
||||
"rpm": 6,
|
||||
},
|
||||
"model_info": {"id": 2},
|
||||
|
@ -160,9 +160,9 @@ async def test_router_atext_completion_streaming():
|
|||
{
|
||||
"model_name": "azure-model",
|
||||
"litellm_params": {
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_key": "os.environ/AZURE_CANADA_API_KEY",
|
||||
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
|
||||
"model": "azure/gpt-turbo",
|
||||
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
|
||||
"api_base": "https://openai-france-1234.openai.azure.com",
|
||||
"rpm": 6,
|
||||
},
|
||||
"model_info": {"id": 3},
|
||||
|
@ -193,7 +193,7 @@ async def test_router_atext_completion_streaming():
|
|||
## check if calls equally distributed
|
||||
cache_dict = router.cache.get_cache(key=cache_key)
|
||||
for k, v in cache_dict.items():
|
||||
assert v == 1
|
||||
assert v == 1, f"Failed. K={k} called v={v} times, cache_dict={cache_dict}"
|
||||
|
||||
|
||||
# asyncio.run(test_router_atext_completion_streaming())
|
||||
|
|
|
@ -16,7 +16,7 @@ litellm.set_verbose = True
|
|||
model_alias_map = {"good-model": "anyscale/meta-llama/Llama-2-7b-chat-hf"}
|
||||
|
||||
|
||||
def test_model_alias_map():
|
||||
def test_model_alias_map(caplog):
|
||||
try:
|
||||
litellm.model_alias_map = model_alias_map
|
||||
response = completion(
|
||||
|
@ -27,9 +27,15 @@ def test_model_alias_map():
|
|||
max_tokens=10,
|
||||
)
|
||||
print(response.model)
|
||||
|
||||
captured_logs = [rec.levelname for rec in caplog.records]
|
||||
|
||||
for log in captured_logs:
|
||||
assert "ERROR" not in log
|
||||
|
||||
assert "Llama-2-7b-chat-hf" in response.model
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
test_model_alias_map()
|
||||
# test_model_alias_map()
|
||||
|
|
|
@ -5,6 +5,7 @@ import sys, os, asyncio
|
|||
import traceback
|
||||
import time, pytest
|
||||
from pydantic import BaseModel
|
||||
from typing import Tuple
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
|
@ -142,7 +143,7 @@ def validate_last_format(chunk):
|
|||
), "'finish_reason' should be a string."
|
||||
|
||||
|
||||
def streaming_format_tests(idx, chunk):
|
||||
def streaming_format_tests(idx, chunk) -> Tuple[str, bool]:
|
||||
extracted_chunk = ""
|
||||
finished = False
|
||||
print(f"chunk: {chunk}")
|
||||
|
@ -306,6 +307,70 @@ def test_completion_azure_stream():
|
|||
|
||||
|
||||
# test_completion_azure_stream()
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_predibase_streaming(sync_mode):
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
|
||||
if sync_mode:
|
||||
response = completion(
|
||||
model="predibase/llama-3-8b-instruct",
|
||||
tenant_id="c4768f95",
|
||||
api_base="https://serving.app.predibase.com",
|
||||
api_key=os.getenv("PREDIBASE_API_KEY"),
|
||||
messages=[{"role": "user", "content": "What is the meaning of life?"}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
complete_response = ""
|
||||
for idx, init_chunk in enumerate(response):
|
||||
chunk, finished = streaming_format_tests(idx, init_chunk)
|
||||
complete_response += chunk
|
||||
custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"]
|
||||
print(f"custom_llm_provider: {custom_llm_provider}")
|
||||
assert custom_llm_provider == "predibase"
|
||||
if finished:
|
||||
assert isinstance(
|
||||
init_chunk.choices[0], litellm.utils.StreamingChoices
|
||||
)
|
||||
break
|
||||
if complete_response.strip() == "":
|
||||
raise Exception("Empty response received")
|
||||
else:
|
||||
response = await litellm.acompletion(
|
||||
model="predibase/llama-3-8b-instruct",
|
||||
tenant_id="c4768f95",
|
||||
api_base="https://serving.app.predibase.com",
|
||||
api_key=os.getenv("PREDIBASE_API_KEY"),
|
||||
messages=[{"role": "user", "content": "What is the meaning of life?"}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# await response
|
||||
|
||||
complete_response = ""
|
||||
idx = 0
|
||||
async for init_chunk in response:
|
||||
chunk, finished = streaming_format_tests(idx, init_chunk)
|
||||
complete_response += chunk
|
||||
custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"]
|
||||
print(f"custom_llm_provider: {custom_llm_provider}")
|
||||
assert custom_llm_provider == "predibase"
|
||||
idx += 1
|
||||
if finished:
|
||||
assert isinstance(
|
||||
init_chunk.choices[0], litellm.utils.StreamingChoices
|
||||
)
|
||||
break
|
||||
if complete_response.strip() == "":
|
||||
raise Exception("Empty response received")
|
||||
|
||||
print(f"complete_response: {complete_response}")
|
||||
except litellm.Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_completion_azure_function_calling_stream():
|
||||
|
@ -1501,6 +1566,70 @@ def test_openai_chat_completion_complete_response_call():


# test_openai_chat_completion_complete_response_call()
def test_openai_stream_options_call():
    litellm.set_verbose = False
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "system", "content": "say GM - we're going to make it "}],
        stream=True,
        stream_options={"include_usage": True},
        max_tokens=10,
    )
    usage = None
    chunks = []
    for chunk in response:
        print("chunk: ", chunk)
        chunks.append(chunk)

    last_chunk = chunks[-1]
    print("last chunk: ", last_chunk)

    """
    Assert that:
    - Last Chunk includes Usage
    - All chunks prior to last chunk have usage=None
    """

    assert last_chunk.usage is not None
    assert last_chunk.usage.total_tokens > 0
    assert last_chunk.usage.prompt_tokens > 0
    assert last_chunk.usage.completion_tokens > 0

    # assert all non last chunks have usage=None
    assert all(chunk.usage is None for chunk in chunks[:-1])


def test_openai_stream_options_call_text_completion():
    litellm.set_verbose = False
    response = litellm.text_completion(
        model="gpt-3.5-turbo-instruct",
        prompt="say GM - we're going to make it ",
        stream=True,
        stream_options={"include_usage": True},
        max_tokens=10,
    )
    usage = None
    chunks = []
    for chunk in response:
        print("chunk: ", chunk)
        chunks.append(chunk)

    last_chunk = chunks[-1]
    print("last chunk: ", last_chunk)

    """
    Assert that:
    - Last Chunk includes Usage
    - All chunks prior to last chunk have usage=None
    """

    assert last_chunk.usage is not None
    assert last_chunk.usage.total_tokens > 0
    assert last_chunk.usage.prompt_tokens > 0
    assert last_chunk.usage.completion_tokens > 0

    # assert all non last chunks have usage=None
    assert all(chunk.usage is None for chunk in chunks[:-1])


def test_openai_text_completion_call():
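
For reference, a minimal caller-side sketch of the behaviour the two tests above assert: with `stream_options={"include_usage": True}`, only the final chunk carries token usage, while every earlier chunk has `usage=None`. The model name and prompt below are illustrative only, and a configured OpenAI API key is assumed.

```python
import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",  # illustrative model choice
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
    stream_options={"include_usage": True},
    max_tokens=10,
)

final_usage = None
for chunk in response:
    # earlier chunks stream content; the last chunk carries the aggregate usage
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
    if getattr(chunk, "usage", None) is not None:
        final_usage = chunk.usage

if final_usage is not None:
    print("\ntotal tokens:", final_usage.total_tokens)
```
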
127
litellm/utils.py

@ -369,7 +369,7 @@ class ChatCompletionMessageToolCall(OpenAIObject):
class Message(OpenAIObject):
    def __init__(
        self,
        content="default",
        content: Optional[str] = "default",
        role="assistant",
        logprobs=None,
        function_call=None,

@ -612,6 +612,7 @@ class ModelResponse(OpenAIObject):
        system_fingerprint=None,
        usage=None,
        stream=None,
        stream_options=None,
        response_ms=None,
        hidden_params=None,
        **params,

@ -658,6 +659,12 @@ class ModelResponse(OpenAIObject):
            usage = usage
        elif stream is None or stream == False:
            usage = Usage()
        elif (
            stream == True
            and stream_options is not None
            and stream_options.get("include_usage") == True
        ):
            usage = Usage()
        if hidden_params:
            self._hidden_params = hidden_params
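
To make the branching above easier to follow, here is a hypothetical standalone restatement of the rule the constructor now applies (the helper name is mine, not part of the diff): a `Usage` object is attached for non-streaming responses, and for streaming responses only when `include_usage` is requested.

```python
from typing import Optional


def should_attach_usage(stream: Optional[bool], stream_options: Optional[dict]) -> bool:
    # non-streaming responses always carry usage
    if stream is None or stream is False:
        return True
    # streaming responses carry usage only when explicitly requested
    return bool(stream_options and stream_options.get("include_usage") is True)


assert should_attach_usage(stream=None, stream_options=None) is True
assert should_attach_usage(stream=True, stream_options=None) is False
assert should_attach_usage(stream=True, stream_options={"include_usage": True}) is True
```
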
@ -4839,6 +4846,7 @@ def get_optional_params(
    top_p=None,
    n=None,
    stream=False,
    stream_options=None,
    stop=None,
    max_tokens=None,
    presence_penalty=None,

@ -4908,6 +4916,7 @@ def get_optional_params(
        "top_p": None,
        "n": None,
        "stream": None,
        "stream_options": None,
        "stop": None,
        "max_tokens": None,
        "presence_penalty": None,

@ -5779,6 +5788,8 @@ def get_optional_params(
            optional_params["n"] = n
        if stream is not None:
            optional_params["stream"] = stream
        if stream_options is not None:
            optional_params["stream_options"] = stream_options
        if stop is not None:
            optional_params["stop"] = stop
        if max_tokens is not None:

@ -5927,13 +5938,15 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
            model=model, **optional_params
        )  # convert to pydantic object
    except Exception as e:
        verbose_logger.error("Error occurred in getting api base - {}".format(str(e)))
        verbose_logger.debug("Error occurred in getting api base - {}".format(str(e)))
        return None
    # get llm provider

    if _optional_params.api_base is not None:
        return _optional_params.api_base

    if litellm.model_alias_map and model in litellm.model_alias_map:
        model = litellm.model_alias_map[model]
    try:
        model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
            get_llm_provider(

@ -6083,6 +6096,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
            "top_p",
            "n",
            "stream",
            "stream_options",
            "stop",
            "max_tokens",
            "presence_penalty",
@ -9500,7 +9514,12 @@ def get_secret(
# replicate/anthropic/cohere
class CustomStreamWrapper:
    def __init__(
        self, completion_stream, model, custom_llm_provider=None, logging_obj=None
        self,
        completion_stream,
        model,
        custom_llm_provider=None,
        logging_obj=None,
        stream_options=None,
    ):
        self.model = model
        self.custom_llm_provider = custom_llm_provider

@ -9526,6 +9545,7 @@ class CustomStreamWrapper:
        self.response_id = None
        self.logging_loop = None
        self.rules = Rules()
        self.stream_options = stream_options

    def __iter__(self):
        return self

@ -9737,6 +9757,50 @@ class CustomStreamWrapper:
                "finish_reason": finish_reason,
            }

    def handle_predibase_chunk(self, chunk):
        try:
            if type(chunk) != str:
                chunk = chunk.decode(
                    "utf-8"
                )  # DO NOT REMOVE this: This is required for HF inference API + Streaming
            text = ""
            is_finished = False
            finish_reason = ""
            print_verbose(f"chunk: {chunk}")
            if chunk.startswith("data:"):
                data_json = json.loads(chunk[5:])
                print_verbose(f"data json: {data_json}")
                if "token" in data_json and "text" in data_json["token"]:
                    text = data_json["token"]["text"]
                if data_json.get("details", False) and data_json["details"].get(
                    "finish_reason", False
                ):
                    is_finished = True
                    finish_reason = data_json["details"]["finish_reason"]
                elif data_json.get(
                    "generated_text", False
                ):  # if full generated text exists, then stream is complete
                    text = ""  # don't return the final bos token
                    is_finished = True
                    finish_reason = "stop"
                elif data_json.get("error", False):
                    raise Exception(data_json.get("error"))
                return {
                    "text": text,
                    "is_finished": is_finished,
                    "finish_reason": finish_reason,
                }
            elif "error" in chunk:
                raise ValueError(chunk)
            return {
                "text": text,
                "is_finished": is_finished,
                "finish_reason": finish_reason,
            }
        except Exception as e:
            traceback.print_exc()
            raise e

    def handle_huggingface_chunk(self, chunk):
        try:
            if type(chunk) != str:
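
As an illustration of the inputs the new handler accepts, the sketch below builds two SSE-style lines in the shape `handle_predibase_chunk` parses and notes the dicts it would return. The field names come from the parser above; real Predibase payloads may include additional fields, and the wrapper instance is assumed to already exist.

```python
import json

intermediate = "data:" + json.dumps({"token": {"text": "Hello"}})
final = "data:" + json.dumps(
    {"token": {"text": ""}, "details": {"finish_reason": "stop"}, "generated_text": "Hello"}
)

# wrapper = CustomStreamWrapper(...)  # assumed existing instance
# wrapper.handle_predibase_chunk(intermediate)
#   -> {"text": "Hello", "is_finished": False, "finish_reason": ""}
# wrapper.handle_predibase_chunk(final)
#   -> {"text": "", "is_finished": True, "finish_reason": "stop"}
```
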
@ -9966,6 +10030,7 @@ class CustomStreamWrapper:
            is_finished = False
            finish_reason = None
            logprobs = None
            usage = None
            original_chunk = None  # this is used for function/tool calling
            if len(str_line.choices) > 0:
                if (

@ -10000,12 +10065,15 @@ class CustomStreamWrapper:
                else:
                    logprobs = None

            usage = getattr(str_line, "usage", None)

            return {
                "text": text,
                "is_finished": is_finished,
                "finish_reason": finish_reason,
                "logprobs": logprobs,
                "original_chunk": str_line,
                "usage": usage,
            }
        except Exception as e:
            traceback.print_exc()

@ -10038,16 +10106,19 @@ class CustomStreamWrapper:
            text = ""
            is_finished = False
            finish_reason = None
            usage = None
            choices = getattr(chunk, "choices", [])
            if len(choices) > 0:
                text = choices[0].text
                if choices[0].finish_reason is not None:
                    is_finished = True
                    finish_reason = choices[0].finish_reason
            usage = getattr(chunk, "usage", None)
            return {
                "text": text,
                "is_finished": is_finished,
                "finish_reason": finish_reason,
                "usage": usage,
            }

        except Exception as e:
@ -10308,7 +10379,9 @@ class CustomStreamWrapper:
            raise e

    def model_response_creator(self):
        model_response = ModelResponse(stream=True, model=self.model)
        model_response = ModelResponse(
            stream=True, model=self.model, stream_options=self.stream_options
        )
        if self.response_id is not None:
            model_response.id = self.response_id
        else:

@ -10365,6 +10438,11 @@ class CustomStreamWrapper:
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider and self.custom_llm_provider == "predibase":
                response_obj = self.handle_predibase_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
            elif (
                self.custom_llm_provider and self.custom_llm_provider == "baseten"
            ):  # baseten doesn't provide streaming

@ -10567,18 +10645,6 @@ class CustomStreamWrapper:
            elif self.custom_llm_provider == "watsonx":
                response_obj = self.handle_watsonx_stream(chunk)
                completion_obj["content"] = response_obj["text"]
                print_verbose(f"completion obj content: {completion_obj['content']}")
                if getattr(model_response, "usage", None) is None:
                    model_response.usage = Usage()
                if response_obj.get("prompt_tokens") is not None:
                    prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
                    model_response.usage.prompt_tokens = (prompt_token_count + response_obj["prompt_tokens"])
                if response_obj.get("completion_tokens") is not None:
                    model_response.usage.completion_tokens = response_obj["completion_tokens"]
                model_response.usage.total_tokens = (
                    getattr(model_response.usage, "prompt_tokens", 0)
                    + getattr(model_response.usage, "completion_tokens", 0)
                )
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "text-completion-openai":
@ -10587,6 +10653,11 @@ class CustomStreamWrapper:
                print_verbose(f"completion obj content: {completion_obj['content']}")
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
                if (
                    self.stream_options
                    and self.stream_options.get("include_usage", False) == True
                ):
                    model_response.usage = response_obj["usage"]
            elif self.custom_llm_provider == "azure_text":
                response_obj = self.handle_azure_text_completion_chunk(chunk)
                completion_obj["content"] = response_obj["text"]

@ -10640,6 +10711,12 @@ class CustomStreamWrapper:
                if response_obj["logprobs"] is not None:
                    model_response.choices[0].logprobs = response_obj["logprobs"]

                if (
                    self.stream_options is not None
                    and self.stream_options["include_usage"] == True
                ):
                    model_response.usage = response_obj["usage"]

            model_response.model = self.model
            print_verbose(
                f"model_response finish reason 3: {self.received_finish_reason}; response_obj={response_obj}"

@ -10727,6 +10804,11 @@ class CustomStreamWrapper:
                    except Exception as e:
                        model_response.choices[0].delta = Delta()
                else:
                    if (
                        self.stream_options is not None
                        and self.stream_options["include_usage"] == True
                    ):
                        return model_response
                    return
            print_verbose(
                f"model_response.choices[0].delta: {model_response.choices[0].delta}; completion_obj: {completion_obj}"

@ -10983,7 +11065,7 @@ class CustomStreamWrapper:
                or self.custom_llm_provider == "sagemaker"
                or self.custom_llm_provider == "gemini"
                or self.custom_llm_provider == "cached_response"
                or self.custom_llm_provider == "watsonx"
                or self.custom_llm_provider == "predibase"
                or self.custom_llm_provider in litellm.openai_compatible_endpoints
            ):
                async for chunk in self.completion_stream:
@ -11106,9 +11188,10 @@ class CustomStreamWrapper:


class TextCompletionStreamWrapper:
    def __init__(self, completion_stream, model):
    def __init__(self, completion_stream, model, stream_options: Optional[dict] = None):
        self.completion_stream = completion_stream
        self.model = model
        self.stream_options = stream_options

    def __iter__(self):
        return self

@ -11132,6 +11215,14 @@ class TextCompletionStreamWrapper:
            text_choices["index"] = chunk["choices"][0]["index"]
            text_choices["finish_reason"] = chunk["choices"][0]["finish_reason"]
            response["choices"] = [text_choices]

            # only pass usage when stream_options["include_usage"] is True
            if (
                self.stream_options
                and self.stream_options.get("include_usage", False) == True
            ):
                response["usage"] = chunk.get("usage", None)

            return response
        except Exception as e:
            raise Exception(

pyproject.toml

@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "1.36.4"
version = "1.37.0"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"

@ -80,7 +80,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"

[tool.commitizen]
version = "1.36.4"
version = "1.37.0"
version_files = [
    "pyproject.toml:^version"
]