Merge branch 'main' into feature/watsonx-integration

commit e1372de9ee
Author: Simon Sanchez Viloria
Date: 2024-05-10 12:09:09 +02:00

23 changed files with 8026 additions and 271 deletions


@@ -64,6 +64,11 @@ if __name__ == "__main__":
) # Replace with your repository's username and name
latest_release = repo.get_latest_release()
print("got latest release: ", latest_release)
print(latest_release.title)
print(latest_release.tag_name)
release_version = latest_release.title
print("latest release body: ", latest_release.body)
print("markdown table: ", markdown_table)
@@ -74,8 +79,22 @@ if __name__ == "__main__":
start_index = latest_release.body.find("Load Test LiteLLM Proxy Results")
existing_release_body = latest_release.body[:start_index]
docker_run_command = f"""
\n\n
## Docker Run LiteLLM Proxy
```
docker run \\
-e STORE_MODEL_IN_DB=True \\
-p 4000:4000 \\
ghcr.io/berriai/litellm:main-{release_version}
```
"""
print("docker run command: ", docker_run_command)
new_release_body = (
existing_release_body
+ docker_run_command
+ "\n\n"
+ "### Don't want to maintain your internal proxy? get in touch 🎉"
+ "\nHosted Proxy Alpha: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat"


@@ -16,11 +16,11 @@ repos:
name: Check if files match
entry: python3 ci_cd/check_files_match.py
language: system
-# - repo: local
- repo: local
-# hooks:
hooks:
-# - id: mypy
- id: mypy
-# name: mypy
name: mypy
-# entry: python3 -m mypy --ignore-missing-imports
entry: python3 -m mypy --ignore-missing-imports
-# language: system
language: system
-# types: [python]
types: [python]
-# files: ^litellm/
files: ^litellm/


@@ -83,6 +83,7 @@ def completion(
top_p: Optional[float] = None,
n: Optional[int] = None,
stream: Optional[bool] = None,
stream_options: Optional[dict] = None,
stop=None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
@@ -139,6 +140,10 @@ def completion(
- `stream`: *boolean or null (optional)* - If set to true, it sends partial message deltas. Tokens will be sent as they become available, with the stream terminated by a [DONE] message.
- `stream_options`: *dict or null (optional)* - Options for the streaming response. Only set this when you set `stream: true`.
  - `include_usage`: *boolean (optional)* - If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
- `stop`: *string/ array/ null (optional)* - Up to 4 sequences where the API will stop generating further tokens.
- `max_tokens`: *integer (optional)* - The maximum number of tokens to generate in the chat completion.
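
For reference, a minimal sketch of how a caller would exercise the new parameter through `litellm.completion` (it mirrors the tests added later in this diff; assumes `OPENAI_API_KEY` is set in the environment):

```python
import litellm

# Ask for usage reporting on the final streamed chunk.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
    stream_options={"include_usage": True},
)

chunks = list(response)
# Per the docs above: every chunk carries a `usage` field, but only the extra
# chunk streamed before [DONE] has it populated; earlier chunks have usage=None.
print(chunks[-1].usage)
```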


@@ -0,0 +1,83 @@
# Region-based Routing
Route specific customers to EU-only models.
By specifying 'allowed_model_region' for a customer, LiteLLM will filter out any models in the model group that are not in the allowed region (e.g. 'eu').
[**See Code**](https://github.com/BerriAI/litellm/blob/5eb12e30cc5faa73799ebc7e48fc86ebf449c879/litellm/router.py#L2938)
### 1. Create customer with region-specification
Use the litellm 'end-user' object for this.
End-users can be tracked / identified by passing the 'user' param to litellm in an OpenAI chat completion/embedding call.
```bash
curl -X POST --location 'http://0.0.0.0:4000/end_user/new' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"user_id" : "ishaan-jaff-45",
"allowed_model_region": "eu", # 👈 SPECIFY ALLOWED REGION='eu'
}'
```
### 2. Add EU models to the model group
Add EU models to a model group. For Azure models, litellm can automatically infer the region (no need to set it).
```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/gpt-35-turbo-eu # 👈 EU azure model
      api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
      api_key: os.environ/AZURE_EUROPE_API_KEY
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/chatgpt-v-2
      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
      api_version: "2023-05-15"
      api_key: os.environ/AZURE_API_KEY

router_settings:
  enable_pre_call_checks: true # 👈 IMPORTANT
```
Start the proxy
```bash
litellm --config /path/to/config.yaml
```
### 3. Test it!
Make a simple chat completions call to the proxy, passing the end-user id from step 1 in the 'user' param. In the response headers, you should see the selected API base returned.
```bash
curl -X POST --location 'http://localhost:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data '{
  "model": "gpt-3.5-turbo",
  "messages": [
    {
      "role": "user",
      "content": "what is the meaning of the universe? 1234"
    }
  ],
  "user": "ishaan-jaff-45"
}'
```
Expected API Base in response headers
```
x-litellm-api-base: "https://my-endpoint-europe-berri-992.openai.azure.com/"
```
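
The same check can be done from Python by reading the raw response headers; a minimal sketch, assuming the proxy from step 2 is running on localhost:4000 and using the `openai` v1 SDK's `with_raw_response` helper:

```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

# `with_raw_response` exposes the underlying HTTP response so the routing
# headers added by the proxy can be inspected.
raw = client.chat.completions.with_raw_response.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what is the meaning of the universe? 1234"}],
    user="ishaan-jaff-45",  # end-user created in step 1
)
print(raw.headers.get("x-litellm-api-base"))  # expect the EU endpoint
completion = raw.parse()  # the regular ChatCompletion object
```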
### FAQ
**What happens if there are no available models for that region?**
Since the router filters out models that are not in the specified region, it returns an error to the user if no models in that region are available.


@@ -50,6 +50,7 @@ const sidebars = {
items: ["proxy/logging", "proxy/streaming_logging"],
},
"proxy/team_based_routing",
"proxy/customer_routing",
"proxy/ui",
"proxy/cost_tracking",
"proxy/token_auth",


@@ -1,3 +1,6 @@
### Hide pydantic namespace conflict warnings globally ###
import warnings
warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
### INIT VARIABLES ###
import threading, requests, os
from typing import Callable, List, Optional, Dict, Union, Any, Literal
@@ -71,9 +74,11 @@ maritalk_key: Optional[str] = None
ai21_key: Optional[str] = None
ollama_key: Optional[str] = None
openrouter_key: Optional[str] = None
predibase_key: Optional[str] = None
huggingface_key: Optional[str] = None
vertex_project: Optional[str] = None
vertex_location: Optional[str] = None
predibase_tenant_id: Optional[str] = None
togetherai_api_key: Optional[str] = None
cloudflare_api_key: Optional[str] = None
baseten_key: Optional[str] = None
@@ -532,6 +537,7 @@ provider_list: List = [
"xinference",
"fireworks_ai",
"watsonx",
"predibase",
"custom", # custom apis
]
@@ -644,6 +650,7 @@ from .utils import (
)
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig
from .llms.predibase import PredibaseConfig
from .llms.anthropic_text import AnthropicTextConfig
from .llms.replicate import ReplicateConfig
from .llms.cohere import CohereConfig


@@ -322,9 +322,9 @@ class Huggingface(BaseLLM):
encoding,
api_key,
logging_obj,
optional_params: dict,
custom_prompt_dict={},
acompletion: bool = False,
-optional_params=None,
litellm_params=None,
logger_fn=None,
):
@@ -399,10 +399,11 @@ class Huggingface(BaseLLM):
data = {
"inputs": prompt,
"parameters": optional_params,
-"stream": (
"stream": (  # type: ignore
True
if "stream" in optional_params
-and optional_params["stream"] == True
and isinstance(optional_params["stream"], bool)
and optional_params["stream"] == True  # type: ignore
else False
),
}
@@ -433,7 +434,7 @@ class Huggingface(BaseLLM):
data = {
"inputs": prompt,
"parameters": inference_params,
-"stream": (
"stream": (  # type: ignore
True
if "stream" in optional_params
and optional_params["stream"] == True


@@ -530,6 +530,7 @@ class OpenAIChatCompletion(BaseLLM):
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
)
return streamwrapper
@@ -579,6 +580,7 @@ class OpenAIChatCompletion(BaseLLM):
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
)
return streamwrapper
except (
@@ -1203,6 +1205,7 @@ class OpenAITextCompletion(BaseLLM):
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
)
for chunk in streamwrapper:
@@ -1241,6 +1244,7 @@ class OpenAITextCompletion(BaseLLM):
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
)
async for transformed_chunk in streamwrapper:

litellm/llms/predibase.py (new file, 520 lines)

@@ -0,0 +1,520 @@
# What is this?
## Controller file for Predibase Integration - https://predibase.com/
import os, types
import json
from enum import Enum
import requests, copy # type: ignore
import time
from typing import Callable, Optional, List, Literal, Union
from litellm.utils import (
ModelResponse,
Usage,
map_finish_reason,
CustomStreamWrapper,
Message,
Choices,
)
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx # type: ignore
class PredibaseError(Exception):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
self.status_code = status_code
self.message = message
if request is not None:
self.request = request
else:
self.request = httpx.Request(
method="POST",
url="https://docs.predibase.com/user-guide/inference/rest_api",
)
if response is not None:
self.response = response
else:
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class PredibaseConfig:
"""
Reference: https://docs.predibase.com/user-guide/inference/rest_api
"""
adapter_id: Optional[str] = None
adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
best_of: Optional[int] = None
decoder_input_details: Optional[bool] = None
details: bool = True # enables returning logprobs + best of
max_new_tokens: int = (
256 # openai default - requests hang if max_new_tokens not given
)
repetition_penalty: Optional[float] = None
return_full_text: Optional[bool] = (
False # by default don't return the input as part of the output
)
seed: Optional[int] = None
stop: Optional[List[str]] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[int] = None
truncate: Optional[int] = None
typical_p: Optional[float] = None
watermark: Optional[bool] = None
def __init__(
self,
best_of: Optional[int] = None,
decoder_input_details: Optional[bool] = None,
details: Optional[bool] = None,
max_new_tokens: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = None,
seed: Optional[int] = None,
stop: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: Optional[bool] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
class PredibaseChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict:
if api_key is None:
raise ValueError(
"Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
)
headers = {
"content-type": "application/json",
"Authorization": "Bearer {}".format(api_key),
}
if user_headers is not None and isinstance(user_headers, dict):
headers = {**headers, **user_headers}
return headers
def output_parser(self, generated_text: str):
"""
Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
"""
chat_template_tokens = [
"<|assistant|>",
"<|system|>",
"<|user|>",
"<s>",
"</s>",
]
for token in chat_template_tokens:
if generated_text.strip().startswith(token):
generated_text = generated_text.replace(token, "", 1)
if generated_text.endswith(token):
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
return generated_text
def process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: ModelResponse,
stream: bool,
logging_obj: litellm.utils.Logging,
optional_params: dict,
api_key: str,
data: dict,
messages: list,
print_verbose,
encoding,
) -> ModelResponse:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
raise PredibaseError(
message=response.text, status_code=response.status_code
)
if "error" in completion_response:
raise PredibaseError(
message=str(completion_response["error"]),
status_code=response.status_code,
)
else:
if (
not isinstance(completion_response, dict)
or "generated_text" not in completion_response
):
raise PredibaseError(
status_code=422,
message=f"response is not in expected format - {completion_response}",
)
if len(completion_response["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = self.output_parser(
completion_response["generated_text"]
)
## GETTING LOGPROBS + FINISH REASON
if (
"details" in completion_response
and "tokens" in completion_response["details"]
):
model_response.choices[0].finish_reason = completion_response[
"details"
]["finish_reason"]
sum_logprob = 0
for token in completion_response["details"]["tokens"]:
if token["logprob"] != None:
sum_logprob += token["logprob"]
model_response["choices"][0][
"message"
]._logprob = (
sum_logprob # [TODO] move this to using the actual logprobs
)
if "best_of" in optional_params and optional_params["best_of"] > 1:
if (
"details" in completion_response
and "best_of_sequences" in completion_response["details"]
):
choices_list = []
for idx, item in enumerate(
completion_response["details"]["best_of_sequences"]
):
sum_logprob = 0
for token in item["tokens"]:
if token["logprob"] != None:
sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0:
message_obj = Message(
content=self.output_parser(item["generated_text"]),
logprobs=sum_logprob,
)
else:
message_obj = Message(content=None)
choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response["choices"].extend(choices_list)
## CALCULATING USAGE
prompt_tokens = 0
try:
prompt_tokens = len(
encoding.encode(model_response["choices"][0]["message"]["content"])
) ##[TODO] use a model-specific tokenizer here
except:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
output_text = model_response["choices"][0]["message"].get("content", "")
if output_text is not None and len(output_text) > 0:
completion_tokens = 0
try:
completion_tokens = len(
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
)
) ##[TODO] use a model-specific tokenizer
except:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
else:
completion_tokens = 0
total_tokens = prompt_tokens + completion_tokens
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
model_response.usage = usage # type: ignore
return model_response
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key: str,
logging_obj,
optional_params: dict,
tenant_id: str,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers: dict = {},
) -> Union[ModelResponse, CustomStreamWrapper]:
headers = self._validate_environment(api_key, headers)
completion_url = ""
input_text = ""
base_url = "https://serving.app.predibase.com"
if "https" in model:
completion_url = model
elif api_base:
base_url = api_base
elif "PREDIBASE_API_BASE" in os.environ:
base_url = os.getenv("PREDIBASE_API_BASE", "")
completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"
if optional_params.get("stream", False) == True:
completion_url += "/generate_stream"
else:
completion_url += "/generate"
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
## Load Config
config = litellm.PredibaseConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
stream = optional_params.pop("stream", False)
data = {
"inputs": prompt,
"parameters": optional_params,
}
input_text = prompt
## LOGGING
logging_obj.pre_call(
input=input_text,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
"acompletion": acompletion,
},
)
## COMPLETION CALL
if acompletion is True:
### ASYNC STREAMING
if stream == True:
return self.async_streaming(
model=model,
messages=messages,
data=data,
api_base=completion_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
) # type: ignore
else:
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data,
api_base=completion_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
stream=False,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
) # type: ignore
### SYNC STREAMING
if stream == True:
response = requests.post(
completion_url,
headers=headers,
data=json.dumps(data),
stream=stream,
)
_response = CustomStreamWrapper(
response.iter_lines(),
model,
custom_llm_provider="predibase",
logging_obj=logging_obj,
)
return _response
### SYNC COMPLETION
else:
response = requests.post(
url=completion_url,
headers=headers,
data=json.dumps(data),
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=optional_params.get("stream", False),
logging_obj=logging_obj, # type: ignore
optional_params=optional_params,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
async def async_completion(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
stream,
data: dict,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
) -> ModelResponse:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
response = await self.async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
async def async_streaming(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
data: dict,
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
) -> CustomStreamWrapper:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
data["stream"] = True
response = await self.async_handler.post(
url=api_base,
headers=headers,
data=json.dumps(data),
stream=True,
)
if response.status_code != 200:
raise PredibaseError(
status_code=response.status_code, message=response.text
)
completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="predibase",
logging_obj=logging_obj,
)
return streamwrapper
def embedding(self, *args, **kwargs):
pass
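
For context, the tests added later in this diff drive this handler through the top-level `litellm.completion` entry point; a condensed sketch (the tenant id and model name are the placeholder values used in those tests):

```python
import os
import litellm

# Routed to PredibaseChatCompletion via the new "predibase" provider entry.
# tenant_id can also be supplied via litellm.predibase_tenant_id or the
# PREDIBASE_TENANT_ID environment variable.
response = litellm.completion(
    model="predibase/llama-3-8b-instruct",
    tenant_id="c4768f95",
    api_key=os.getenv("PREDIBASE_API_KEY"),
    messages=[{"role": "user", "content": "What is the meaning of life?"}],
)
print(response.choices[0].message.content)
```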


@@ -487,7 +487,7 @@ def format_prompt_togetherai(messages, prompt_format, chat_template):
def ibm_granite_pt(messages: list):
"""
-IBM's Granite chat models uses the template:
IBM's Granite models use the template:
<|system|> {system_message} <|user|> {user_message} <|assistant|> {assistant_message}
See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models
@@ -503,13 +503,12 @@ def ibm_granite_pt(messages: list):
"pre_message": "<|user|>\n",
"post_message": "\n",
},
-'assistant': {
"assistant": {
-'pre_message': '<|assistant|>\n',
"pre_message": "<|assistant|>\n",
-'post_message': '\n',
"post_message": "\n",
},
},
-final_prompt_value='<|assistant|>\n',
-)
).strip()
### ANTHROPIC ###
@@ -1525,9 +1524,24 @@ def prompt_factory(
return mistral_instruct_pt(messages=messages)
elif "meta-llama/llama-3" in model and "instruct" in model:
# https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
-return hf_chat_template(
return custom_prompt(
-model="meta-llama/Meta-Llama-3-8B-Instruct",
role_dict={
"system": {
"pre_message": "<|start_header_id|>system<|end_header_id|>\n",
"post_message": "<|eot_id|>",
},
"user": {
"pre_message": "<|start_header_id|>user<|end_header_id|>\n",
"post_message": "<|eot_id|>",
},
"assistant": {
"pre_message": "<|start_header_id|>assistant<|end_header_id|>\n",
"post_message": "<|eot_id|>",
},
},
messages=messages,
initial_prompt_value="<|begin_of_text|>",
final_prompt_value="<|start_header_id|>assistant<|end_header_id|>\n",
)
try:
if "meta-llama/llama-2" in model and "chat" in model:


@@ -451,9 +451,6 @@ class IBMWatsonXAI(BaseLLM):
return streamwrapper
# create the function to manage the request to watsonx.ai
-# manage_request = self._make_request_manager(
-# async_=(acompletion is True), logging_obj=logging_obj
-# )
self.request_manager = RequestManager(logging_obj)
def handle_text_request(request_params: dict) -> ModelResponse:
@@ -576,9 +573,6 @@ class IBMWatsonXAI(BaseLLM):
"json": payload,
"params": request_params,
}
-# manage_request = self._make_request_manager(
-# async_=(aembedding is True), logging_obj=logging_obj
-# )
request_manager = RequestManager(logging_obj)
def process_embedding_response(json_resp: dict) -> ModelResponse:
@@ -654,143 +648,12 @@ class IBMWatsonXAI(BaseLLM):
request_params = dict(version=api_params["api_version"])
url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.AVAILABLE_MODELS
req_params = dict(method="GET", url=url, headers=headers, params=request_params)
-# manage_request = self._make_request_manager(async_=False, logging_obj=None)
with RequestManager(logging_obj=None).request(req_params) as resp:
json_resp = resp.json()
if not ids_only:
return json_resp
return [res["model_id"] for res in json_resp["resources"]]
def _make_request_manager(
self, async_: bool, logging_obj=None
) -> Callable[
...,
Union[ContextManager[requests.Response], AsyncContextManager[httpx.Response]],
]:
"""
Returns a context manager that manages the response from the request.
if async_ is True, returns an async context manager, otherwise returns a regular context manager.
Usage:
```python
manage_request = self._make_request_manager(async_=True, logging_obj=logging_obj)
async with manage_request(request_params) as resp:
...
# or
manage_request = self._make_request_manager(async_=False, logging_obj=logging_obj)
with manage_request(request_params) as resp:
...
```
"""
def pre_call(
request_params: dict,
input: Optional[Any] = None,
):
if logging_obj is None:
return
request_str = (
f"response = {'await ' if async_ else ''}{request_params['method']}(\n"
f"\turl={request_params['url']},\n"
f"\tjson={request_params.get('json')},\n"
f")"
)
logging_obj.pre_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
additional_args={
"complete_input_dict": request_params.get("json"),
"request_str": request_str,
},
)
def post_call(resp, request_params):
if logging_obj is None:
return
logging_obj.post_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
original_response=json.dumps(resp.json()),
additional_args={
"status_code": resp.status_code,
"complete_input_dict": request_params.get(
"data", request_params.get("json")
),
},
)
@contextmanager
def _manage_request(
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> Generator[requests.Response, None, None]:
"""
Returns a context manager that yields the response from the request.
"""
pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
resp = requests.request(**request_params)
if not resp.ok:
raise WatsonXAIError(
status_code=resp.status_code,
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
)
yield resp
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
post_call(resp, request_params)
@asynccontextmanager
async def _manage_request_async(
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> AsyncGenerator[httpx.Response, None]:
pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
# async with AsyncHTTPHandler(timeout=timeout) as client:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(
timeout=request_params.pop("timeout", 600.0), connect=5.0
),
)
# async_handler.client.verify = False
if "json" in request_params:
request_params["data"] = json.dumps(request_params.pop("json", {}))
method = request_params.pop("method")
if method.upper() == "POST":
resp = await self.async_handler.post(**request_params)
else:
resp = await self.async_handler.get(**request_params)
if resp.status_code not in [200, 201]:
raise WatsonXAIError(
status_code=resp.status_code,
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
)
yield resp
# await async_handler.close()
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
post_call(resp, request_params)
if async_:
return _manage_request_async
else:
return _manage_request
class RequestManager:
"""
Returns a context manager that manages the response from the request.


@@ -14,6 +14,7 @@ import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
import httpx
import litellm
from ._logging import verbose_logger
from litellm import ( # type: ignore
client,
@@ -73,7 +74,7 @@ from .llms.azure_text import AzureTextCompletion
from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
-from .llms.watsonx import IBMWatsonXAI
from .llms.predibase import PredibaseChatCompletion
from .llms.prompt_templates.factory import (
prompt_factory,
custom_prompt,
@@ -110,7 +111,7 @@ anthropic_text_completions = AnthropicTextCompletion()
azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion()
huggingface = Huggingface()
-watsonxai = IBMWatsonXAI()
predibase_chat_completions = PredibaseChatCompletion()
####### COMPLETION ENDPOINTS ################
@@ -189,6 +190,7 @@ async def acompletion(
top_p: Optional[float] = None,
n: Optional[int] = None,
stream: Optional[bool] = None,
stream_options: Optional[dict] = None,
stop=None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
@@ -208,6 +210,7 @@ async def acompletion(
api_version: Optional[str] = None,
api_key: Optional[str] = None,
model_list: Optional[list] = None, # pass in a list of api_base,keys, etc.
extra_headers: Optional[dict] = None,
# Optional liteLLM function params
**kwargs,
):
@@ -225,6 +228,7 @@ async def acompletion(
top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
n (int, optional): The number of completions to generate (default is 1).
stream (bool, optional): If True, return a streaming response (default is False).
stream_options (dict, optional): A dictionary containing options for the streaming response. Only use this if stream is True.
stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
@@ -262,6 +266,7 @@ async def acompletion(
"top_p": top_p,
"n": n,
"stream": stream,
"stream_options": stream_options,
"stop": stop,
"max_tokens": max_tokens,
"presence_penalty": presence_penalty,
@@ -315,7 +320,7 @@ async def acompletion(
or custom_llm_provider == "gemini"
or custom_llm_provider == "sagemaker"
or custom_llm_provider == "anthropic"
-or custom_llm_provider == "watsonx"
or custom_llm_provider == "predibase"
or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
@@ -460,6 +465,7 @@ def completion(
top_p: Optional[float] = None,
n: Optional[int] = None,
stream: Optional[bool] = None,
stream_options: Optional[dict] = None,
stop=None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
@@ -499,6 +505,7 @@ def completion(
top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
n (int, optional): The number of completions to generate (default is 1).
stream (bool, optional): If True, return a streaming response (default is False).
stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true.
stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.
@@ -576,6 +583,7 @@ def completion(
"top_p",
"n",
"stream",
"stream_options",
"stop",
"max_tokens",
"presence_penalty",
@@ -788,6 +796,7 @@ def completion(
top_p=top_p,
n=n,
stream=stream,
stream_options=stream_options,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
@@ -1779,6 +1788,52 @@ def completion(
)
return response
response = model_response
elif custom_llm_provider == "predibase":
tenant_id = (
optional_params.pop("tenant_id", None)
or optional_params.pop("predibase_tenant_id", None)
or litellm.predibase_tenant_id
or get_secret("PREDIBASE_TENANT_ID")
)
api_base = (
optional_params.pop("api_base", None)
or optional_params.pop("base_url", None)
or litellm.api_base
or get_secret("PREDIBASE_API_BASE")
)
api_key = (
api_key
or litellm.api_key
or litellm.predibase_key
or get_secret("PREDIBASE_API_KEY")
)
_model_response = predibase_chat_completions.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
acompletion=acompletion,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
api_key=api_key,
tenant_id=tenant_id,
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and acompletion == False
):
return _model_response
response = _model_response
elif custom_llm_provider == "ai21": elif custom_llm_provider == "ai21":
custom_llm_provider = "ai21" custom_llm_provider = "ai21"
ai21_key = ( ai21_key = (
@ -1911,7 +1966,7 @@ def completion(
response = response response = response
elif custom_llm_provider == "watsonx": elif custom_llm_provider == "watsonx":
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
response = watsonxai.completion( response = watsonx.IBMWatsonXAI().completion(
model=model, model=model,
messages=messages, messages=messages,
custom_prompt_dict=custom_prompt_dict, custom_prompt_dict=custom_prompt_dict,
@ -1922,8 +1977,7 @@ def completion(
logger_fn=logger_fn, logger_fn=logger_fn,
encoding=encoding, encoding=encoding,
logging_obj=logging, logging_obj=logging,
acompletion=acompletion, timeout=timeout, # type: ignore
timeout=timeout,
) )
if ( if (
"stream" in optional_params "stream" in optional_params
@ -2576,7 +2630,6 @@ async def aembedding(*args, **kwargs):
or custom_llm_provider == "fireworks_ai" or custom_llm_provider == "fireworks_ai"
or custom_llm_provider == "ollama" or custom_llm_provider == "ollama"
or custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "watsonx"
): # currently implemented aiohttp calls for just azure and openai, soon all. ): # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally # Await normally
init_response = await loop.run_in_executor(None, func_with_context) init_response = await loop.run_in_executor(None, func_with_context)
@ -3034,14 +3087,13 @@ def embedding(
aembedding=aembedding, aembedding=aembedding,
) )
elif custom_llm_provider == "watsonx": elif custom_llm_provider == "watsonx":
response = watsonxai.embedding( response = watsonx.IBMWatsonXAI().embedding(
model=model, model=model,
input=input, input=input,
encoding=encoding, encoding=encoding,
logging_obj=logging, logging_obj=logging,
optional_params=optional_params, optional_params=optional_params,
model_response=EmbeddingResponse(), model_response=EmbeddingResponse(),
aembedding=aembedding,
) )
else: else:
args = locals() args = locals()
@ -3197,6 +3249,7 @@ def text_completion(
Union[str, List[str]] Union[str, List[str]]
] = None, # Optional: Sequences where the API will stop generating further tokens. ] = None, # Optional: Sequences where the API will stop generating further tokens.
stream: Optional[bool] = None, # Optional: Whether to stream back partial progress. stream: Optional[bool] = None, # Optional: Whether to stream back partial progress.
stream_options: Optional[dict] = None,
suffix: Optional[ suffix: Optional[
str str
] = None, # Optional: The suffix that comes after a completion of inserted text. ] = None, # Optional: The suffix that comes after a completion of inserted text.
@ -3274,6 +3327,8 @@ def text_completion(
optional_params["stop"] = stop optional_params["stop"] = stop
if stream is not None: if stream is not None:
optional_params["stream"] = stream optional_params["stream"] = stream
if stream_options is not None:
optional_params["stream_options"] = stream_options
if suffix is not None: if suffix is not None:
optional_params["suffix"] = suffix optional_params["suffix"] = suffix
if temperature is not None: if temperature is not None:
@ -3384,7 +3439,9 @@ def text_completion(
if kwargs.get("acompletion", False) == True: if kwargs.get("acompletion", False) == True:
return response return response
if stream == True or kwargs.get("stream", False) == True: if stream == True or kwargs.get("stream", False) == True:
response = TextCompletionStreamWrapper(completion_stream=response, model=model) response = TextCompletionStreamWrapper(
completion_stream=response, model=model, stream_options=stream_options
)
return response return response
transformed_logprobs = None transformed_logprobs = None
# only supported for TGI models # only supported for TGI models


@@ -206,11 +206,9 @@ async def get_end_user_object(
if end_user_id is None:
return None
_key = "end_user_id:{}".format(end_user_id)
# check if in cache
-cached_user_obj = user_api_key_cache.async_get_cache(
-key="end_user_id:{}".format(end_user_id)
-)
cached_user_obj = await user_api_key_cache.async_get_cache(key=_key)
if cached_user_obj is not None:
if isinstance(cached_user_obj, dict):
return LiteLLM_EndUserTable(**cached_user_obj)


@@ -1086,9 +1086,7 @@ async def user_api_key_auth(
user_id_information, list
):
_user = user_id_information[0]
-user_role = _user.get("user_role", {}).get(
-"user_role", "unknown"
-)
user_role = _user.get("user_role", "unknown")
user_id = _user.get("user_id", "unknown")
raise Exception(
f"Only proxy admin can be used to generate, delete, update info for new keys/users/teams. Route={route}. Your role={user_role}. Your user_id={user_id}"
@@ -1834,6 +1832,9 @@ async def update_cache(
)
async def _update_end_user_cache():
if end_user_id is None or response_cost is None:
return
_id = "end_user_id:{}".format(end_user_id)
try:
# Fetch the existing cost for the given user
@@ -1846,7 +1847,7 @@ async def update_cache(
if litellm.max_end_user_budget is not None:
max_end_user_budget = litellm.max_end_user_budget
existing_spend_obj = LiteLLM_EndUserTable(
-user_id=_id,
user_id=end_user_id,
spend=0,
blocked=False,
litellm_budget_table=LiteLLM_BudgetTable(
@@ -1874,7 +1875,7 @@ async def update_cache(
existing_spend_obj.spend = new_spend
user_api_key_cache.set_cache(key=_id, value=existing_spend_obj.json())
except Exception as e:
-verbose_proxy_logger.debug(
verbose_proxy_logger.error(
f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
)
@@ -7310,6 +7311,43 @@ async def unblock_team(
return record
@router.get(
"/team/list", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
)
async def list_team(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
[Admin-only] List all available teams
```
curl --location --request GET 'http://0.0.0.0:4000/team/list' \
--header 'Authorization: Bearer sk-1234'
```
"""
global prisma_client
if user_api_key_dict.user_role != "proxy_admin":
raise HTTPException(
status_code=401,
detail={
"error": "Admin-only endpoint. Your user role={}".format(
user_api_key_dict.user_role
)
},
)
if prisma_client is None:
raise HTTPException(
status_code=400,
detail={"error": CommonProxyErrors.db_not_connected_error.value},
)
response = await prisma_client.db.litellm_teamtable.find_many()
return response
#### ORGANIZATION MANAGEMENT ####

File diff suppressed because it is too large.


@@ -1,5 +1,6 @@
import pytest
from litellm import acompletion
from litellm import completion
def test_acompletion_params():
@@ -7,17 +8,29 @@ def test_acompletion_params():
from litellm.types.completion import CompletionRequest
acompletion_params_odict = inspect.signature(acompletion).parameters
-acompletion_params = {name: param.annotation for name, param in acompletion_params_odict.items()}
completion_params_dict = inspect.signature(completion).parameters
-completion_params = {field_name: field_type for field_name, field_type in CompletionRequest.__annotations__.items()}
-# remove kwargs
-acompletion_params.pop("kwargs", None)
acompletion_params = {
name: param.annotation for name, param in acompletion_params_odict.items()
}
completion_params = {
name: param.annotation for name, param in completion_params_dict.items()
}
keys_acompletion = set(acompletion_params.keys())
keys_completion = set(completion_params.keys())
print(keys_acompletion)
print("\n\n\n")
print(keys_completion)
print("diff=", keys_completion - keys_acompletion)
# Assert that the parameters are the same
if keys_acompletion != keys_completion:
-pytest.fail("The parameters of the acompletion function and the CompletionRequest class are not the same.")
pytest.fail(
"The parameters of the litellm.acompletion function and litellm.completion are not the same."
)
# test_acompletion_params()


@@ -85,6 +85,42 @@ def test_completion_azure_command_r():
pytest.fail(f"Error occurred: {e}")
# @pytest.mark.skip(reason="local test")
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_completion_predibase(sync_mode):
try:
litellm.set_verbose = True
if sync_mode:
response = completion(
model="predibase/llama-3-8b-instruct",
tenant_id="c4768f95",
api_base="https://serving.app.predibase.com",
api_key=os.getenv("PREDIBASE_API_KEY"),
messages=[{"role": "user", "content": "What is the meaning of life?"}],
)
print(response)
else:
response = await litellm.acompletion(
model="predibase/llama-3-8b-instruct",
tenant_id="c4768f95",
api_base="https://serving.app.predibase.com",
api_key=os.getenv("PREDIBASE_API_KEY"),
messages=[{"role": "user", "content": "What is the meaning of life?"}],
)
print(response)
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_predibase()
def test_completion_claude():
litellm.set_verbose = True
litellm.cache = None


@@ -418,9 +418,16 @@ def test_call_with_user_over_budget(prisma_client):
print(vars(e))
def test_end_user_cache_write_unit_test():
"""
assert end user object is being written to cache as expected
"""
pass
def test_call_with_end_user_over_budget(prisma_client):
# Test if a user passed to /chat/completions is tracked & fails when they cross their budget
-# we only check this when litellm.max_user_budget is set
# we only check this when litellm.max_end_user_budget is set
import random
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)


@@ -150,9 +150,9 @@ async def test_router_atext_completion_streaming():
{
"model_name": "azure-model",
"litellm_params": {
-"model": "azure/gpt-35-turbo",
"model": "azure/gpt-turbo",
-"api_key": "os.environ/AZURE_EUROPE_API_KEY",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
-"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
"api_base": "https://openai-france-1234.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 2},
@@ -160,9 +160,9 @@ async def test_router_atext_completion_streaming():
{
"model_name": "azure-model",
"litellm_params": {
-"model": "azure/gpt-35-turbo",
"model": "azure/gpt-turbo",
-"api_key": "os.environ/AZURE_CANADA_API_KEY",
"api_key": "os.environ/AZURE_FRANCE_API_KEY",
-"api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
"api_base": "https://openai-france-1234.openai.azure.com",
"rpm": 6,
},
"model_info": {"id": 3},
@@ -193,7 +193,7 @@ async def test_router_atext_completion_streaming():
## check if calls equally distributed
cache_dict = router.cache.get_cache(key=cache_key)
for k, v in cache_dict.items():
-assert v == 1
assert v == 1, f"Failed. K={k} called v={v} times, cache_dict={cache_dict}"
# asyncio.run(test_router_atext_completion_streaming())


@@ -16,7 +16,7 @@ litellm.set_verbose = True
model_alias_map = {"good-model": "anyscale/meta-llama/Llama-2-7b-chat-hf"}
-def test_model_alias_map():
def test_model_alias_map(caplog):
try:
litellm.model_alias_map = model_alias_map
response = completion(
@@ -27,9 +27,15 @@ def test_model_alias_map():
max_tokens=10,
)
print(response.model)
captured_logs = [rec.levelname for rec in caplog.records]
for log in captured_logs:
assert "ERROR" not in log
assert "Llama-2-7b-chat-hf" in response.model
except Exception as e:
pytest.fail(f"Error occurred: {e}")
-test_model_alias_map()
# test_model_alias_map()


@@ -5,6 +5,7 @@ import sys, os, asyncio
import traceback
import time, pytest
from pydantic import BaseModel
from typing import Tuple
sys.path.insert(
0, os.path.abspath("../..")
@@ -142,7 +143,7 @@ def validate_last_format(chunk):
), "'finish_reason' should be a string."
-def streaming_format_tests(idx, chunk):
def streaming_format_tests(idx, chunk) -> Tuple[str, bool]:
extracted_chunk = ""
finished = False
print(f"chunk: {chunk}")
@@ -306,6 +307,70 @@ def test_completion_azure_stream():
# test_completion_azure_stream()
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_completion_predibase_streaming(sync_mode):
try:
litellm.set_verbose = True
if sync_mode:
response = completion(
model="predibase/llama-3-8b-instruct",
tenant_id="c4768f95",
api_base="https://serving.app.predibase.com",
api_key=os.getenv("PREDIBASE_API_KEY"),
messages=[{"role": "user", "content": "What is the meaning of life?"}],
stream=True,
)
complete_response = ""
for idx, init_chunk in enumerate(response):
chunk, finished = streaming_format_tests(idx, init_chunk)
complete_response += chunk
custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"]
print(f"custom_llm_provider: {custom_llm_provider}")
assert custom_llm_provider == "predibase"
if finished:
assert isinstance(
init_chunk.choices[0], litellm.utils.StreamingChoices
)
break
if complete_response.strip() == "":
raise Exception("Empty response received")
else:
response = await litellm.acompletion(
model="predibase/llama-3-8b-instruct",
tenant_id="c4768f95",
api_base="https://serving.app.predibase.com",
api_key=os.getenv("PREDIBASE_API_KEY"),
messages=[{"role": "user", "content": "What is the meaning of life?"}],
stream=True,
)
# await response
complete_response = ""
idx = 0
async for init_chunk in response:
chunk, finished = streaming_format_tests(idx, init_chunk)
complete_response += chunk
custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"]
print(f"custom_llm_provider: {custom_llm_provider}")
assert custom_llm_provider == "predibase"
idx += 1
if finished:
assert isinstance(
init_chunk.choices[0], litellm.utils.StreamingChoices
)
break
if complete_response.strip() == "":
raise Exception("Empty response received")
print(f"complete_response: {complete_response}")
except litellm.Timeout as e:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_completion_azure_function_calling_stream():
@@ -1501,6 +1566,70 @@ def test_openai_chat_completion_complete_response_call():
# test_openai_chat_completion_complete_response_call()
def test_openai_stream_options_call():
    litellm.set_verbose = False
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "system", "content": "say GM - we're going to make it "}],
        stream=True,
        stream_options={"include_usage": True},
        max_tokens=10,
    )
    usage = None
    chunks = []
    for chunk in response:
        print("chunk: ", chunk)
        chunks.append(chunk)

    last_chunk = chunks[-1]
    print("last chunk: ", last_chunk)

    """
    Assert that:
    - Last Chunk includes Usage
    - All chunks prior to last chunk have usage=None
    """

    assert last_chunk.usage is not None
    assert last_chunk.usage.total_tokens > 0
    assert last_chunk.usage.prompt_tokens > 0
    assert last_chunk.usage.completion_tokens > 0

    # assert all non last chunks have usage=None
    assert all(chunk.usage is None for chunk in chunks[:-1])
def test_openai_stream_options_call_text_completion():
    litellm.set_verbose = False
    response = litellm.text_completion(
        model="gpt-3.5-turbo-instruct",
        prompt="say GM - we're going to make it ",
        stream=True,
        stream_options={"include_usage": True},
        max_tokens=10,
    )
    usage = None
    chunks = []
    for chunk in response:
        print("chunk: ", chunk)
        chunks.append(chunk)

    last_chunk = chunks[-1]
    print("last chunk: ", last_chunk)

    """
    Assert that:
    - Last Chunk includes Usage
    - All chunks prior to last chunk have usage=None
    """

    assert last_chunk.usage is not None
    assert last_chunk.usage.total_tokens > 0
    assert last_chunk.usage.prompt_tokens > 0
    assert last_chunk.usage.completion_tokens > 0

    # assert all non last chunks have usage=None
    assert all(chunk.usage is None for chunk in chunks[:-1])
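Outside the test suite, the same `stream_options={"include_usage": True}` flag can be used to read token counts off the final chunk of a stream. A minimal consumer sketch, assuming an `OPENAI_API_KEY` is set in the environment:

```python
import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "say GM - we're going to make it "}],
    stream=True,
    stream_options={"include_usage": True},
    max_tokens=10,
)

usage = None
for chunk in response:
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")
    # per the tests above, only the last chunk carries a populated usage object
    if getattr(chunk, "usage", None) is not None:
        usage = chunk.usage

print()
if usage is not None:
    print(f"prompt={usage.prompt_tokens}, completion={usage.completion_tokens}, total={usage.total_tokens}")
```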
def test_openai_text_completion_call():

View file

@@ -369,7 +369,7 @@ class ChatCompletionMessageToolCall(OpenAIObject):
class Message(OpenAIObject):
    def __init__(
        self,
-        content="default",
+        content: Optional[str] = "default",
        role="assistant",
        logprobs=None,
        function_call=None,
@@ -612,6 +612,7 @@ class ModelResponse(OpenAIObject):
        system_fingerprint=None,
        usage=None,
        stream=None,
        stream_options=None,
        response_ms=None,
        hidden_params=None,
        **params,
@@ -658,6 +659,12 @@ class ModelResponse(OpenAIObject):
            usage = usage
        elif stream is None or stream == False:
            usage = Usage()
        elif (
            stream == True
            and stream_options is not None
            and stream_options.get("include_usage") == True
        ):
            usage = Usage()
        if hidden_params:
            self._hidden_params = hidden_params
@@ -4839,6 +4846,7 @@ def get_optional_params(
    top_p=None,
    n=None,
    stream=False,
    stream_options=None,
    stop=None,
    max_tokens=None,
    presence_penalty=None,
@@ -4908,6 +4916,7 @@ def get_optional_params(
        "top_p": None,
        "n": None,
        "stream": None,
        "stream_options": None,
        "stop": None,
        "max_tokens": None,
        "presence_penalty": None,
@@ -5779,6 +5788,8 @@ def get_optional_params(
            optional_params["n"] = n
        if stream is not None:
            optional_params["stream"] = stream
        if stream_options is not None:
            optional_params["stream_options"] = stream_options
        if stop is not None:
            optional_params["stop"] = stop
        if max_tokens is not None:
@@ -5927,13 +5938,15 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
            model=model, **optional_params
        )  # convert to pydantic object
    except Exception as e:
-        verbose_logger.error("Error occurred in getting api base - {}".format(str(e)))
+        verbose_logger.debug("Error occurred in getting api base - {}".format(str(e)))
        return None
    # get llm provider
    if _optional_params.api_base is not None:
        return _optional_params.api_base

    if litellm.model_alias_map and model in litellm.model_alias_map:
        model = litellm.model_alias_map[model]
    try:
        model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
            get_llm_provider(
@@ -6083,6 +6096,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
            "top_p",
            "n",
            "stream",
            "stream_options",
            "stop",
            "max_tokens",
            "presence_penalty",
@@ -9500,7 +9514,12 @@ def get_secret(
# replicate/anthropic/cohere
class CustomStreamWrapper:
-    def __init__(
-        self, completion_stream, model, custom_llm_provider=None, logging_obj=None
-    ):
+    def __init__(
+        self,
+        completion_stream,
+        model,
+        custom_llm_provider=None,
+        logging_obj=None,
+        stream_options=None,
+    ):
        self.model = model
        self.custom_llm_provider = custom_llm_provider
@@ -9526,6 +9545,7 @@ class CustomStreamWrapper:
        self.response_id = None
        self.logging_loop = None
        self.rules = Rules()
        self.stream_options = stream_options

    def __iter__(self):
        return self
@@ -9737,6 +9757,50 @@ class CustomStreamWrapper:
            "finish_reason": finish_reason,
        }
    def handle_predibase_chunk(self, chunk):
        try:
            if type(chunk) != str:
                chunk = chunk.decode(
                    "utf-8"
                )  # DO NOT REMOVE this: This is required for HF inference API + Streaming
            text = ""
            is_finished = False
            finish_reason = ""
            print_verbose(f"chunk: {chunk}")
            if chunk.startswith("data:"):
                data_json = json.loads(chunk[5:])
                print_verbose(f"data json: {data_json}")
                if "token" in data_json and "text" in data_json["token"]:
                    text = data_json["token"]["text"]
                if data_json.get("details", False) and data_json["details"].get(
                    "finish_reason", False
                ):
                    is_finished = True
                    finish_reason = data_json["details"]["finish_reason"]
                elif data_json.get(
                    "generated_text", False
                ):  # if full generated text exists, then stream is complete
                    text = ""  # don't return the final bos token
                    is_finished = True
                    finish_reason = "stop"
                elif data_json.get("error", False):
                    raise Exception(data_json.get("error"))
                return {
                    "text": text,
                    "is_finished": is_finished,
                    "finish_reason": finish_reason,
                }
            elif "error" in chunk:
                raise ValueError(chunk)
            return {
                "text": text,
                "is_finished": is_finished,
                "finish_reason": finish_reason,
            }
        except Exception as e:
            traceback.print_exc()
            raise e
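For context, `handle_predibase_chunk` assumes Predibase streams text-generation-inference-style server-sent events: a `data:` prefix followed by JSON carrying a per-token `token.text`, plus `details.finish_reason` (and `generated_text`) on the final event. The snippet below walks a made-up sample through the same parsing steps; the payloads are illustrative, not captured from a real Predibase response:

```python
import json

# Illustrative SSE lines in the shape the handler above expects.
sample_chunks = [
    'data: {"token": {"text": "Hello"}}',
    'data: {"token": {"text": "!"}, "details": {"finish_reason": "stop"}, "generated_text": "Hello!"}',
]

for raw in sample_chunks:
    data_json = json.loads(raw[len("data:"):])  # strip the "data:" prefix, as the handler does
    text = data_json.get("token", {}).get("text", "")
    details = data_json.get("details") or {}
    finished = bool(details.get("finish_reason"))
    print({"text": text, "is_finished": finished, "finish_reason": details.get("finish_reason", "")})
```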
    def handle_huggingface_chunk(self, chunk):
        try:
            if type(chunk) != str:
@@ -9966,6 +10030,7 @@ class CustomStreamWrapper:
            is_finished = False
            finish_reason = None
            logprobs = None
            usage = None
            original_chunk = None  # this is used for function/tool calling
            if len(str_line.choices) > 0:
                if (
@@ -10000,12 +10065,15 @@ class CustomStreamWrapper:
                else:
                    logprobs = None

            usage = getattr(str_line, "usage", None)

            return {
                "text": text,
                "is_finished": is_finished,
                "finish_reason": finish_reason,
                "logprobs": logprobs,
                "original_chunk": str_line,
                "usage": usage,
            }
        except Exception as e:
            traceback.print_exc()
@@ -10038,16 +10106,19 @@ class CustomStreamWrapper:
            text = ""
            is_finished = False
            finish_reason = None
            usage = None
            choices = getattr(chunk, "choices", [])
            if len(choices) > 0:
                text = choices[0].text
                if choices[0].finish_reason is not None:
                    is_finished = True
                    finish_reason = choices[0].finish_reason
            usage = getattr(chunk, "usage", None)

            return {
                "text": text,
                "is_finished": is_finished,
                "finish_reason": finish_reason,
                "usage": usage,
            }

        except Exception as e:
@@ -10308,7 +10379,9 @@ class CustomStreamWrapper:
            raise e

    def model_response_creator(self):
-        model_response = ModelResponse(stream=True, model=self.model)
+        model_response = ModelResponse(
+            stream=True, model=self.model, stream_options=self.stream_options
+        )
        if self.response_id is not None:
            model_response.id = self.response_id
        else:
@@ -10365,6 +10438,11 @@ class CustomStreamWrapper:
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider and self.custom_llm_provider == "predibase":
                response_obj = self.handle_predibase_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
            elif (
                self.custom_llm_provider and self.custom_llm_provider == "baseten"
            ):  # baseten doesn't provide streaming
@@ -10567,18 +10645,6 @@ class CustomStreamWrapper:
            elif self.custom_llm_provider == "watsonx":
                response_obj = self.handle_watsonx_stream(chunk)
                completion_obj["content"] = response_obj["text"]
-                print_verbose(f"completion obj content: {completion_obj['content']}")
-                if getattr(model_response, "usage", None) is None:
-                    model_response.usage = Usage()
-                if response_obj.get("prompt_tokens") is not None:
-                    prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
-                    model_response.usage.prompt_tokens = (prompt_token_count + response_obj["prompt_tokens"])
-                if response_obj.get("completion_tokens") is not None:
-                    model_response.usage.completion_tokens = response_obj["completion_tokens"]
-                model_response.usage.total_tokens = (
-                    getattr(model_response.usage, "prompt_tokens", 0)
-                    + getattr(model_response.usage, "completion_tokens", 0)
-                )
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "text-completion-openai":
@@ -10587,6 +10653,11 @@ class CustomStreamWrapper:
                print_verbose(f"completion obj content: {completion_obj['content']}")
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
                if (
                    self.stream_options
                    and self.stream_options.get("include_usage", False) == True
                ):
                    model_response.usage = response_obj["usage"]
            elif self.custom_llm_provider == "azure_text":
                response_obj = self.handle_azure_text_completion_chunk(chunk)
                completion_obj["content"] = response_obj["text"]
@@ -10640,6 +10711,12 @@ class CustomStreamWrapper:
                if response_obj["logprobs"] is not None:
                    model_response.choices[0].logprobs = response_obj["logprobs"]

                if (
                    self.stream_options is not None
                    and self.stream_options["include_usage"] == True
                ):
                    model_response.usage = response_obj["usage"]

                model_response.model = self.model
                print_verbose(
                    f"model_response finish reason 3: {self.received_finish_reason}; response_obj={response_obj}"
@@ -10727,6 +10804,11 @@ class CustomStreamWrapper:
                    except Exception as e:
                        model_response.choices[0].delta = Delta()
                else:
                    if (
                        self.stream_options is not None
                        and self.stream_options["include_usage"] == True
                    ):
                        return model_response
                    return
            print_verbose(
                f"model_response.choices[0].delta: {model_response.choices[0].delta}; completion_obj: {completion_obj}"
@@ -10983,7 +11065,7 @@ class CustomStreamWrapper:
                or self.custom_llm_provider == "sagemaker"
                or self.custom_llm_provider == "gemini"
                or self.custom_llm_provider == "cached_response"
-                or self.custom_llm_provider == "watsonx"
+                or self.custom_llm_provider == "predibase"
                or self.custom_llm_provider in litellm.openai_compatible_endpoints
            ):
                async for chunk in self.completion_stream:
@@ -11106,9 +11188,10 @@ class CustomStreamWrapper:

class TextCompletionStreamWrapper:
-    def __init__(self, completion_stream, model):
+    def __init__(self, completion_stream, model, stream_options: Optional[dict] = None):
        self.completion_stream = completion_stream
        self.model = model
        self.stream_options = stream_options

    def __iter__(self):
        return self
@@ -11132,6 +11215,14 @@ class TextCompletionStreamWrapper:
            text_choices["index"] = chunk["choices"][0]["index"]
            text_choices["finish_reason"] = chunk["choices"][0]["finish_reason"]
            response["choices"] = [text_choices]

            # only pass usage when stream_options["include_usage"] is True
            if (
                self.stream_options
                and self.stream_options.get("include_usage", False) == True
            ):
                response["usage"] = chunk.get("usage", None)

            return response
        except Exception as e:
            raise Exception(

View file

@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
-version = "1.36.4"
+version = "1.37.0"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"
@@ -80,7 +80,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"

[tool.commitizen]
-version = "1.36.4"
+version = "1.37.0"
version_files = [
    "pyproject.toml:^version"
]