forked from phoenix/litellm-mirror

Merge branch 'main' into feature/watsonx-integration

commit e1372de9ee — 23 changed files with 8,026 additions and 271 deletions
.github/workflows/interpret_load_test.py (vendored) — 19 changes

@@ -64,6 +64,11 @@ if __name__ == "__main__":
     )  # Replace with your repository's username and name
     latest_release = repo.get_latest_release()
     print("got latest release: ", latest_release)
+    print(latest_release.title)
+    print(latest_release.tag_name)
+
+    release_version = latest_release.title
+
     print("latest release body: ", latest_release.body)
     print("markdown table: ", markdown_table)

@@ -74,8 +79,22 @@ if __name__ == "__main__":
     start_index = latest_release.body.find("Load Test LiteLLM Proxy Results")
     existing_release_body = latest_release.body[:start_index]
+
+    docker_run_command = f"""
+\n\n
+## Docker Run LiteLLM Proxy
+
+```
+docker run \\
+-e STORE_MODEL_IN_DB=True \\
+-p 4000:4000 \\
+ghcr.io/berriai/litellm:main-{release_version}
+```
+    """
+    print("docker run command: ", docker_run_command)
+
     new_release_body = (
         existing_release_body
+        + docker_run_command
         + "\n\n"
         + "### Don't want to maintain your internal proxy? get in touch 🎉"
         + "\nHosted Proxy Alpha: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat"
@@ -16,11 +16,11 @@ repos:
         name: Check if files match
         entry: python3 ci_cd/check_files_match.py
         language: system
-  # - repo: local
-  #   hooks:
-  #     - id: mypy
-  #       name: mypy
-  #       entry: python3 -m mypy --ignore-missing-imports
-  #       language: system
-  #       types: [python]
-  #       files: ^litellm/
+  - repo: local
+    hooks:
+      - id: mypy
+        name: mypy
+        entry: python3 -m mypy --ignore-missing-imports
+        language: system
+        types: [python]
+        files: ^litellm/
@@ -83,6 +83,7 @@ def completion(
     top_p: Optional[float] = None,
     n: Optional[int] = None,
     stream: Optional[bool] = None,
+    stream_options: Optional[dict] = None,
     stop=None,
     max_tokens: Optional[int] = None,
     presence_penalty: Optional[float] = None,

@@ -139,6 +140,10 @@ def completion(
 
 - `stream`: *boolean or null (optional)* - If set to true, it sends partial message deltas. Tokens will be sent as they become available, with the stream terminated by a [DONE] message.
 
+- `stream_options` *dict or null (optional)* - Options for streaming response. Only set this when you set `stream: true`
+
+- `include_usage` *boolean (optional)* - If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
+
 - `stop`: *string/ array/ null (optional)* - Up to 4 sequences where the API will stop generating further tokens.
 
 - `max_tokens`: *integer (optional)* - The maximum number of tokens to generate in the chat completion.
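A minimal sketch of how the new `stream_options` parameter is meant to be used; the final-chunk shape follows the OpenAI convention described above, so treat the exact attribute access as an assumption rather than documented behavior:

```python
# Sketch: request a trailing usage chunk on a streamed completion.
# Assumes an OpenAI-style model is configured and its API key is in the environment.
import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
    stream_options={"include_usage": True},  # ask for a final usage chunk
)

for chunk in response:
    # Per the docs above, the last chunk before [DONE] carries `usage`
    # and an empty `choices` list; earlier chunks carry deltas.
    if getattr(chunk, "usage", None):
        print("\ntoken usage:", chunk.usage)
    elif chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")
```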
docs/my-website/docs/proxy/customer_routing.md — new file (83 lines)

@@ -0,0 +1,83 @@
# Region-based Routing

Route specific customers to EU-only models.

By specifying 'allowed_model_region' for a customer, LiteLLM will filter out any models in a model group that are not in the allowed region (i.e. 'eu').

[**See Code**](https://github.com/BerriAI/litellm/blob/5eb12e30cc5faa73799ebc7e48fc86ebf449c879/litellm/router.py#L2938)

### 1. Create customer with region-specification

Use the litellm 'end-user' object for this.

End-users can be tracked / id'ed by passing the 'user' param to litellm in an openai chat completion/embedding call.

```bash
curl -X POST --location 'http://0.0.0.0:4000/end_user/new' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "user_id" : "ishaan-jaff-45",
    "allowed_model_region": "eu", # 👈 SPECIFY ALLOWED REGION='eu'
}'
```

### 2. Add EU models to model-group

Add EU models to a model group. For Azure models, litellm can automatically infer the region (no need to set it).

```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/gpt-35-turbo-eu # 👈 EU azure model
      api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
      api_key: os.environ/AZURE_EUROPE_API_KEY
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/chatgpt-v-2
      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
      api_version: "2023-05-15"
      api_key: os.environ/AZURE_API_KEY

router_settings:
  enable_pre_call_checks: true # 👈 IMPORTANT
```

Start the proxy:

```bash
litellm --config /path/to/config.yaml
```

### 3. Test it!

Make a simple chat completions call to the proxy. In the response headers, you should see the returned API base.

```bash
curl -X POST --location 'http://localhost:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-1234' \
--data '{
    "model": "gpt-3.5-turbo",
    "messages": [
        {
            "role": "user",
            "content": "what is the meaning of the universe? 1234"
        }],
    "user": "ishaan-jaff-45" # 👈 USER ID
}
'
```

Expected API base in the response headers:

```
x-litellm-api-base: "https://my-endpoint-europe-berri-992.openai.azure.com/"
```

### FAQ

**What happens if there are no available models for that region?**

Since the router filters out models not in the specified region, it will return an error to the user if no models in that region are available.
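As a companion to the curl test above, a sketch of the same request through the OpenAI Python client pointed at the proxy; the client setup and key are assumptions, the relevant part is passing the `user` id created in step 1:

```python
# Sketch: call the LiteLLM proxy with the end-user id from step 1.
# Assumes `pip install openai` and the proxy running on localhost:4000 with key sk-1234.
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what is the meaning of the universe? 1234"}],
    user="ishaan-jaff-45",  # routed only to deployments in the customer's allowed region
)
print(resp.choices[0].message.content)
# The selected deployment's api base is surfaced by the proxy in the
# `x-litellm-api-base` response header, as shown above.
```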
@@ -50,6 +50,7 @@ const sidebars = {
         items: ["proxy/logging", "proxy/streaming_logging"],
       },
       "proxy/team_based_routing",
+      "proxy/customer_routing",
       "proxy/ui",
       "proxy/cost_tracking",
       "proxy/token_auth",
@@ -1,3 +1,6 @@
+### Hide pydantic namespace conflict warnings globally ###
+import warnings
+warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
 ### INIT VARIABLES ###
 import threading, requests, os
 from typing import Callable, List, Optional, Dict, Union, Any, Literal

@@ -71,9 +74,11 @@ maritalk_key: Optional[str] = None
 ai21_key: Optional[str] = None
 ollama_key: Optional[str] = None
 openrouter_key: Optional[str] = None
+predibase_key: Optional[str] = None
 huggingface_key: Optional[str] = None
 vertex_project: Optional[str] = None
 vertex_location: Optional[str] = None
+predibase_tenant_id: Optional[str] = None
 togetherai_api_key: Optional[str] = None
 cloudflare_api_key: Optional[str] = None
 baseten_key: Optional[str] = None

@@ -532,6 +537,7 @@ provider_list: List = [
     "xinference",
     "fireworks_ai",
     "watsonx",
+    "predibase",
     "custom",  # custom apis
 ]

@@ -644,6 +650,7 @@ from .utils import (
 )
 from .llms.huggingface_restapi import HuggingfaceConfig
 from .llms.anthropic import AnthropicConfig
+from .llms.predibase import PredibaseConfig
 from .llms.anthropic_text import AnthropicTextConfig
 from .llms.replicate import ReplicateConfig
 from .llms.cohere import CohereConfig
@@ -322,9 +322,9 @@ class Huggingface(BaseLLM):
         encoding,
         api_key,
         logging_obj,
+        optional_params: dict,
         custom_prompt_dict={},
         acompletion: bool = False,
-        optional_params=None,
         litellm_params=None,
         logger_fn=None,
     ):

@@ -399,10 +399,11 @@ class Huggingface(BaseLLM):
                 data = {
                     "inputs": prompt,
                     "parameters": optional_params,
-                    "stream": (
+                    "stream": (  # type: ignore
                         True
                         if "stream" in optional_params
-                        and optional_params["stream"] == True
+                        and isinstance(optional_params["stream"], bool)
+                        and optional_params["stream"] == True  # type: ignore
                         else False
                     ),
                 }

@@ -433,7 +434,7 @@ class Huggingface(BaseLLM):
                 data = {
                     "inputs": prompt,
                     "parameters": inference_params,
-                    "stream": (
+                    "stream": (  # type: ignore
                         True
                         if "stream" in optional_params
                         and optional_params["stream"] == True
@@ -530,6 +530,7 @@ class OpenAIChatCompletion(BaseLLM):
                 model=model,
                 custom_llm_provider="openai",
                 logging_obj=logging_obj,
+                stream_options=data.get("stream_options", None),
             )
             return streamwrapper
 
@@ -579,6 +580,7 @@ class OpenAIChatCompletion(BaseLLM):
                 model=model,
                 custom_llm_provider="openai",
                 logging_obj=logging_obj,
+                stream_options=data.get("stream_options", None),
             )
             return streamwrapper
         except (

@@ -1203,6 +1205,7 @@ class OpenAITextCompletion(BaseLLM):
                 model=model,
                 custom_llm_provider="text-completion-openai",
                 logging_obj=logging_obj,
+                stream_options=data.get("stream_options", None),
             )
 
             for chunk in streamwrapper:

@@ -1241,6 +1244,7 @@ class OpenAITextCompletion(BaseLLM):
                 model=model,
                 custom_llm_provider="text-completion-openai",
                 logging_obj=logging_obj,
+                stream_options=data.get("stream_options", None),
             )
 
             async for transformed_chunk in streamwrapper:
litellm/llms/predibase.py — new file (520 lines)

@@ -0,0 +1,520 @@
# What is this?
## Controller file for Predibase Integration - https://predibase.com/

import os, types
import json
from enum import Enum
import requests, copy  # type: ignore
import time
from typing import Callable, Optional, List, Literal, Union
from litellm.utils import (
    ModelResponse,
    Usage,
    map_finish_reason,
    CustomStreamWrapper,
    Message,
    Choices,
)
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx  # type: ignore


class PredibaseError(Exception):
    def __init__(
        self,
        status_code,
        message,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
    ):
        self.status_code = status_code
        self.message = message
        if request is not None:
            self.request = request
        else:
            self.request = httpx.Request(
                method="POST",
                url="https://docs.predibase.com/user-guide/inference/rest_api",
            )
        if response is not None:
            self.response = response
        else:
            self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(self.message)  # Call the base class constructor with the parameters it needs


class PredibaseConfig:
    """
    Reference: https://docs.predibase.com/user-guide/inference/rest_api
    """

    adapter_id: Optional[str] = None
    adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
    best_of: Optional[int] = None
    decoder_input_details: Optional[bool] = None
    details: bool = True  # enables returning logprobs + best of
    max_new_tokens: int = 256  # openai default - requests hang if max_new_tokens not given
    repetition_penalty: Optional[float] = None
    return_full_text: Optional[bool] = False  # by default don't return the input as part of the output
    seed: Optional[int] = None
    stop: Optional[List[str]] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[int] = None
    truncate: Optional[int] = None
    typical_p: Optional[float] = None
    watermark: Optional[bool] = None

    def __init__(
        self,
        best_of: Optional[int] = None,
        decoder_input_details: Optional[bool] = None,
        details: Optional[bool] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: Optional[bool] = None,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[int] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: Optional[bool] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]


class PredibaseChatCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict:
        if api_key is None:
            raise ValueError(
                "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
            )
        headers = {
            "content-type": "application/json",
            "Authorization": "Bearer {}".format(api_key),
        }
        if user_headers is not None and isinstance(user_headers, dict):
            headers = {**headers, **user_headers}
        return headers

    def output_parser(self, generated_text: str):
        """
        Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.

        Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
        """
        chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
        for token in chat_template_tokens:
            if generated_text.strip().startswith(token):
                generated_text = generated_text.replace(token, "", 1)
            if generated_text.endswith(token):
                generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
        return generated_text

    def process_response(
        self,
        model: str,
        response: Union[requests.Response, httpx.Response],
        model_response: ModelResponse,
        stream: bool,
        logging_obj: litellm.utils.Logging,
        optional_params: dict,
        api_key: str,
        data: dict,
        messages: list,
        print_verbose,
        encoding,
    ) -> ModelResponse:
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )
        print_verbose(f"raw model_response: {response.text}")
        ## RESPONSE OBJECT
        try:
            completion_response = response.json()
        except:
            raise PredibaseError(message=response.text, status_code=response.status_code)
        if "error" in completion_response:
            raise PredibaseError(
                message=str(completion_response["error"]),
                status_code=response.status_code,
            )
        else:
            if (
                not isinstance(completion_response, dict)
                or "generated_text" not in completion_response
            ):
                raise PredibaseError(
                    status_code=422,
                    message=f"response is not in expected format - {completion_response}",
                )

            if len(completion_response["generated_text"]) > 0:
                model_response["choices"][0]["message"]["content"] = self.output_parser(
                    completion_response["generated_text"]
                )
            ## GETTING LOGPROBS + FINISH REASON
            if (
                "details" in completion_response
                and "tokens" in completion_response["details"]
            ):
                model_response.choices[0].finish_reason = completion_response["details"][
                    "finish_reason"
                ]
                sum_logprob = 0
                for token in completion_response["details"]["tokens"]:
                    if token["logprob"] != None:
                        sum_logprob += token["logprob"]
                model_response["choices"][0]["message"]._logprob = (
                    sum_logprob  # [TODO] move this to using the actual logprobs
                )
            if "best_of" in optional_params and optional_params["best_of"] > 1:
                if (
                    "details" in completion_response
                    and "best_of_sequences" in completion_response["details"]
                ):
                    choices_list = []
                    for idx, item in enumerate(
                        completion_response["details"]["best_of_sequences"]
                    ):
                        sum_logprob = 0
                        for token in item["tokens"]:
                            if token["logprob"] != None:
                                sum_logprob += token["logprob"]
                        if len(item["generated_text"]) > 0:
                            message_obj = Message(
                                content=self.output_parser(item["generated_text"]),
                                logprobs=sum_logprob,
                            )
                        else:
                            message_obj = Message(content=None)
                        choice_obj = Choices(
                            finish_reason=item["finish_reason"],
                            index=idx + 1,
                            message=message_obj,
                        )
                        choices_list.append(choice_obj)
                    model_response["choices"].extend(choices_list)

        ## CALCULATING USAGE
        prompt_tokens = 0
        try:
            prompt_tokens = len(
                encoding.encode(model_response["choices"][0]["message"]["content"])
            )  ##[TODO] use a model-specific tokenizer here
        except:
            # this should remain non blocking we should not block a response returning if calculating usage fails
            pass
        output_text = model_response["choices"][0]["message"].get("content", "")
        if output_text is not None and len(output_text) > 0:
            completion_tokens = 0
            try:
                completion_tokens = len(
                    encoding.encode(
                        model_response["choices"][0]["message"].get("content", "")
                    )
                )  ##[TODO] use a model-specific tokenizer
            except:
                # this should remain non blocking we should not block a response returning if calculating usage fails
                pass
        else:
            completion_tokens = 0

        total_tokens = prompt_tokens + completion_tokens

        model_response["created"] = int(time.time())
        model_response["model"] = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        )
        model_response.usage = usage  # type: ignore
        return model_response

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key: str,
        logging_obj,
        optional_params: dict,
        tenant_id: str,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers: dict = {},
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        headers = self._validate_environment(api_key, headers)
        completion_url = ""
        input_text = ""
        base_url = "https://serving.app.predibase.com"
        if "https" in model:
            completion_url = model
        elif api_base:
            base_url = api_base
        elif "PREDIBASE_API_BASE" in os.environ:
            base_url = os.getenv("PREDIBASE_API_BASE", "")

        completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"

        if optional_params.get("stream", False) == True:
            completion_url += "/generate_stream"
        else:
            completion_url += "/generate"

        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details["initial_prompt_value"],
                final_prompt_value=model_prompt_details["final_prompt_value"],
                messages=messages,
            )
        else:
            prompt = prompt_factory(model=model, messages=messages)

        ## Load Config
        config = litellm.PredibaseConfig.get_config()
        for k, v in config.items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                optional_params[k] = v

        stream = optional_params.pop("stream", False)

        data = {
            "inputs": prompt,
            "parameters": optional_params,
        }
        input_text = prompt
        ## LOGGING
        logging_obj.pre_call(
            input=input_text,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "headers": headers,
                "api_base": completion_url,
                "acompletion": acompletion,
            },
        )
        ## COMPLETION CALL
        if acompletion is True:
            ### ASYNC STREAMING
            if stream == True:
                return self.async_streaming(
                    model=model, messages=messages, data=data, api_base=completion_url,
                    model_response=model_response, print_verbose=print_verbose,
                    encoding=encoding, api_key=api_key, logging_obj=logging_obj,
                    optional_params=optional_params, litellm_params=litellm_params,
                    logger_fn=logger_fn, headers=headers,
                )  # type: ignore
            else:
                ### ASYNC COMPLETION
                return self.async_completion(
                    model=model, messages=messages, data=data, api_base=completion_url,
                    model_response=model_response, print_verbose=print_verbose,
                    encoding=encoding, api_key=api_key, logging_obj=logging_obj,
                    optional_params=optional_params, stream=False,
                    litellm_params=litellm_params, logger_fn=logger_fn, headers=headers,
                )  # type: ignore

        ### SYNC STREAMING
        if stream == True:
            response = requests.post(
                completion_url,
                headers=headers,
                data=json.dumps(data),
                stream=stream,
            )
            _response = CustomStreamWrapper(
                response.iter_lines(),
                model,
                custom_llm_provider="predibase",
                logging_obj=logging_obj,
            )
            return _response
        ### SYNC COMPLETION
        else:
            response = requests.post(
                url=completion_url,
                headers=headers,
                data=json.dumps(data),
            )

        return self.process_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=optional_params.get("stream", False),
            logging_obj=logging_obj,  # type: ignore
            optional_params=optional_params,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            encoding=encoding,
        )

    async def async_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        stream,
        data: dict,
        optional_params: dict,
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> ModelResponse:
        self.async_handler = AsyncHTTPHandler(
            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
        )
        response = await self.async_handler.post(
            api_base, headers=headers, data=json.dumps(data)
        )
        return self.process_response(
            model=model, response=response, model_response=model_response, stream=stream,
            logging_obj=logging_obj, api_key=api_key, data=data, messages=messages,
            print_verbose=print_verbose, optional_params=optional_params, encoding=encoding,
        )

    async def async_streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        data: dict,
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> CustomStreamWrapper:
        self.async_handler = AsyncHTTPHandler(
            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
        )
        data["stream"] = True
        response = await self.async_handler.post(
            url=api_base,
            headers=headers,
            data=json.dumps(data),
            stream=True,
        )

        if response.status_code != 200:
            raise PredibaseError(
                status_code=response.status_code, message=response.text
            )

        completion_stream = response.aiter_lines()

        streamwrapper = CustomStreamWrapper(
            completion_stream=completion_stream,
            model=model,
            custom_llm_provider="predibase",
            logging_obj=logging_obj,
        )
        return streamwrapper

    def embedding(self, *args, **kwargs):
        pass
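To show how this new provider is reached from the public API, a hedged sketch; the `predibase/<deployment>` model-name convention, the deployment id, and the keyword names follow the wiring added to `litellm/main.py` below, so treat them as assumptions rather than documented usage:

```python
# Sketch: calling the new Predibase integration through litellm.completion.
# Assumes PREDIBASE_API_KEY and PREDIBASE_TENANT_ID are exported, and that a
# deployment named "llama-3-8b-instruct" exists for this tenant (hypothetical).
import os
import litellm

response = litellm.completion(
    model="predibase/llama-3-8b-instruct",          # assumed deployment name
    messages=[{"role": "user", "content": "Hello"}],
    api_key=os.environ["PREDIBASE_API_KEY"],
    tenant_id=os.environ["PREDIBASE_TENANT_ID"],    # also read from env/litellm settings if omitted
    max_tokens=128,
)
print(response.choices[0].message.content)
```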
@@ -487,7 +487,7 @@ def format_prompt_togetherai(messages, prompt_format, chat_template):
 
 def ibm_granite_pt(messages: list):
     """
-    IBM's Granite chat models uses the template:
+    IBM's Granite models uses the template:
     <|system|> {system_message} <|user|> {user_message} <|assistant|> {assistant_message}
 
     See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models

@@ -503,13 +503,12 @@ def ibm_granite_pt(messages: list):
                 "pre_message": "<|user|>\n",
                 "post_message": "\n",
             },
-            'assistant': {
-                'pre_message': '<|assistant|>\n',
-                'post_message': '\n',
+            "assistant": {
+                "pre_message": "<|assistant|>\n",
+                "post_message": "\n",
             },
         },
-        final_prompt_value='<|assistant|>\n',
-    )
+    ).strip()
 
 
 ### ANTHROPIC ###

@@ -1525,9 +1524,24 @@ def prompt_factory(
         return mistral_instruct_pt(messages=messages)
     elif "meta-llama/llama-3" in model and "instruct" in model:
         # https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
-        return hf_chat_template(
-            model="meta-llama/Meta-Llama-3-8B-Instruct",
+        return custom_prompt(
+            role_dict={
+                "system": {
+                    "pre_message": "<|start_header_id|>system<|end_header_id|>\n",
+                    "post_message": "<|eot_id|>",
+                },
+                "user": {
+                    "pre_message": "<|start_header_id|>user<|end_header_id|>\n",
+                    "post_message": "<|eot_id|>",
+                },
+                "assistant": {
+                    "pre_message": "<|start_header_id|>assistant<|end_header_id|>\n",
+                    "post_message": "<|eot_id|>",
+                },
+            },
             messages=messages,
+            initial_prompt_value="<|begin_of_text|>",
+            final_prompt_value="<|start_header_id|>assistant<|end_header_id|>\n",
         )
     try:
         if "meta-llama/llama-2" in model and "chat" in model:
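For reference, a small sketch of what the rewritten Llama-3 branch renders for a single user turn; the output string is inferred from the `role_dict`, `initial_prompt_value`, and `final_prompt_value` above, so treat it as an approximation:

```python
# Sketch: inspect the llama-3 instruct prompt produced by the updated branch.
from litellm.llms.prompt_templates.factory import prompt_factory

messages = [{"role": "user", "content": "Hey"}]
prompt = prompt_factory(model="meta-llama/llama-3-8b-instruct", messages=messages)
print(prompt)
# Expected shape (approximate):
# <|begin_of_text|><|start_header_id|>user<|end_header_id|>
# Hey<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```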
@@ -451,9 +451,6 @@ class IBMWatsonXAI(BaseLLM):
             return streamwrapper
 
         # create the function to manage the request to watsonx.ai
-        # manage_request = self._make_request_manager(
-        #     async_=(acompletion is True), logging_obj=logging_obj
-        # )
         self.request_manager = RequestManager(logging_obj)
 
         def handle_text_request(request_params: dict) -> ModelResponse:

@@ -576,9 +573,6 @@ class IBMWatsonXAI(BaseLLM):
             "json": payload,
             "params": request_params,
         }
-        # manage_request = self._make_request_manager(
-        #     async_=(aembedding is True), logging_obj=logging_obj
-        # )
         request_manager = RequestManager(logging_obj)
 
         def process_embedding_response(json_resp: dict) -> ModelResponse:

@@ -654,143 +648,12 @@ class IBMWatsonXAI(BaseLLM):
         request_params = dict(version=api_params["api_version"])
         url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.AVAILABLE_MODELS
         req_params = dict(method="GET", url=url, headers=headers, params=request_params)
-        # manage_request = self._make_request_manager(async_=False, logging_obj=None)
         with RequestManager(logging_obj=None).request(req_params) as resp:
             json_resp = resp.json()
         if not ids_only:
             return json_resp
         return [res["model_id"] for res in json_resp["resources"]]
 
-    def _make_request_manager(
-        self, async_: bool, logging_obj=None
-    ) -> Callable[
-        ...,
-        Union[ContextManager[requests.Response], AsyncContextManager[httpx.Response]],
-    ]:
-        """
-        Returns a context manager that manages the response from the request.
-        if async_ is True, returns an async context manager, otherwise returns a regular context manager.
-
-        Usage:
-        ```python
-        manage_request = self._make_request_manager(async_=True, logging_obj=logging_obj)
-        async with manage_request(request_params) as resp:
-            ...
-        # or
-        manage_request = self._make_request_manager(async_=False, logging_obj=logging_obj)
-        with manage_request(request_params) as resp:
-            ...
-        ```
-        """
-
-        def pre_call(request_params: dict, input: Optional[Any] = None):
-            if logging_obj is None:
-                return
-            request_str = (
-                f"response = {'await ' if async_ else ''}{request_params['method']}(\n"
-                f"\turl={request_params['url']},\n"
-                f"\tjson={request_params.get('json')},\n"
-                f")"
-            )
-            logging_obj.pre_call(
-                input=input,
-                api_key=request_params["headers"].get("Authorization"),
-                additional_args={
-                    "complete_input_dict": request_params.get("json"),
-                    "request_str": request_str,
-                },
-            )
-
-        def post_call(resp, request_params):
-            if logging_obj is None:
-                return
-            logging_obj.post_call(
-                input=input,
-                api_key=request_params["headers"].get("Authorization"),
-                original_response=json.dumps(resp.json()),
-                additional_args={
-                    "status_code": resp.status_code,
-                    "complete_input_dict": request_params.get(
-                        "data", request_params.get("json")
-                    ),
-                },
-            )
-
-        @contextmanager
-        def _manage_request(
-            request_params: dict,
-            stream: bool = False,
-            input: Optional[Any] = None,
-            timeout=None,
-        ) -> Generator[requests.Response, None, None]:
-            """
-            Returns a context manager that yields the response from the request.
-            """
-            pre_call(request_params, input)
-            if timeout:
-                request_params["timeout"] = timeout
-            if stream:
-                request_params["stream"] = stream
-            try:
-                resp = requests.request(**request_params)
-                if not resp.ok:
-                    raise WatsonXAIError(
-                        status_code=resp.status_code,
-                        message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
-                    )
-                yield resp
-            except Exception as e:
-                raise WatsonXAIError(status_code=500, message=str(e))
-            if not stream:
-                post_call(resp, request_params)
-
-        @asynccontextmanager
-        async def _manage_request_async(
-            request_params: dict,
-            stream: bool = False,
-            input: Optional[Any] = None,
-            timeout=None,
-        ) -> AsyncGenerator[httpx.Response, None]:
-            pre_call(request_params, input)
-            if timeout:
-                request_params["timeout"] = timeout
-            if stream:
-                request_params["stream"] = stream
-            try:
-                # async with AsyncHTTPHandler(timeout=timeout) as client:
-                self.async_handler = AsyncHTTPHandler(
-                    timeout=httpx.Timeout(
-                        timeout=request_params.pop("timeout", 600.0), connect=5.0
-                    ),
-                )
-                # async_handler.client.verify = False
-                if "json" in request_params:
-                    request_params["data"] = json.dumps(request_params.pop("json", {}))
-                method = request_params.pop("method")
-                if method.upper() == "POST":
-                    resp = await self.async_handler.post(**request_params)
-                else:
-                    resp = await self.async_handler.get(**request_params)
-                if resp.status_code not in [200, 201]:
-                    raise WatsonXAIError(
-                        status_code=resp.status_code,
-                        message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
-                    )
-                yield resp
-                # await async_handler.close()
-            except Exception as e:
-                raise WatsonXAIError(status_code=500, message=str(e))
-            if not stream:
-                post_call(resp, request_params)
-
-        if async_:
-            return _manage_request_async
-        else:
-            return _manage_request
-
-
 class RequestManager:
     """
     Returns a context manager that manages the response from the request.
@@ -14,6 +14,7 @@ import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
 import httpx
 import litellm
+
 from ._logging import verbose_logger
 from litellm import (  # type: ignore
     client,

@@ -73,7 +74,7 @@ from .llms.azure_text import AzureTextCompletion
 from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
 from .llms.huggingface_restapi import Huggingface
-from .llms.watsonx import IBMWatsonXAI
+from .llms.predibase import PredibaseChatCompletion
 from .llms.prompt_templates.factory import (
     prompt_factory,
     custom_prompt,

@@ -110,7 +111,7 @@ anthropic_text_completions = AnthropicTextCompletion()
 azure_chat_completions = AzureChatCompletion()
 azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
-watsonxai = IBMWatsonXAI()
+predibase_chat_completions = PredibaseChatCompletion()
 ####### COMPLETION ENDPOINTS ################
 

@@ -189,6 +190,7 @@ async def acompletion(
     top_p: Optional[float] = None,
     n: Optional[int] = None,
     stream: Optional[bool] = None,
+    stream_options: Optional[dict] = None,
     stop=None,
     max_tokens: Optional[int] = None,
     presence_penalty: Optional[float] = None,

@@ -208,6 +210,7 @@ async def acompletion(
     api_version: Optional[str] = None,
     api_key: Optional[str] = None,
     model_list: Optional[list] = None,  # pass in a list of api_base,keys, etc.
+    extra_headers: Optional[dict] = None,
     # Optional liteLLM function params
     **kwargs,
 ):

@@ -225,6 +228,7 @@ async def acompletion(
         top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
         n (int, optional): The number of completions to generate (default is 1).
         stream (bool, optional): If True, return a streaming response (default is False).
+        stream_options (dict, optional): A dictionary containing options for the streaming response. Only use this if stream is True.
         stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
         max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
         presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.

@@ -262,6 +266,7 @@ async def acompletion(
         "top_p": top_p,
         "n": n,
         "stream": stream,
+        "stream_options": stream_options,
         "stop": stop,
         "max_tokens": max_tokens,
         "presence_penalty": presence_penalty,

@@ -315,7 +320,7 @@ async def acompletion(
             or custom_llm_provider == "gemini"
             or custom_llm_provider == "sagemaker"
             or custom_llm_provider == "anthropic"
-            or custom_llm_provider == "watsonx"
+            or custom_llm_provider == "predibase"
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)

@@ -460,6 +465,7 @@ def completion(
     top_p: Optional[float] = None,
     n: Optional[int] = None,
     stream: Optional[bool] = None,
+    stream_options: Optional[dict] = None,
     stop=None,
     max_tokens: Optional[int] = None,
     presence_penalty: Optional[float] = None,

@@ -499,6 +505,7 @@ def completion(
         top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0).
         n (int, optional): The number of completions to generate (default is 1).
         stream (bool, optional): If True, return a streaming response (default is False).
+        stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true.
         stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens.
         max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity).
         presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far.

@@ -576,6 +583,7 @@ def completion(
         "top_p",
         "n",
         "stream",
+        "stream_options",
         "stop",
         "max_tokens",
         "presence_penalty",

@@ -788,6 +796,7 @@ def completion(
             top_p=top_p,
             n=n,
             stream=stream,
+            stream_options=stream_options,
             stop=stop,
             max_tokens=max_tokens,
             presence_penalty=presence_penalty,

@@ -1779,6 +1788,52 @@ def completion(
                 )
                 return response
             response = model_response
+        elif custom_llm_provider == "predibase":
+            tenant_id = (
+                optional_params.pop("tenant_id", None)
+                or optional_params.pop("predibase_tenant_id", None)
+                or litellm.predibase_tenant_id
+                or get_secret("PREDIBASE_TENANT_ID")
+            )
+
+            api_base = (
+                optional_params.pop("api_base", None)
+                or optional_params.pop("base_url", None)
+                or litellm.api_base
+                or get_secret("PREDIBASE_API_BASE")
+            )
+
+            api_key = (
+                api_key
+                or litellm.api_key
+                or litellm.predibase_key
+                or get_secret("PREDIBASE_API_KEY")
+            )
+
+            _model_response = predibase_chat_completions.completion(
+                model=model,
+                messages=messages,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                encoding=encoding,
+                logging_obj=logging,
+                acompletion=acompletion,
+                api_base=api_base,
+                custom_prompt_dict=custom_prompt_dict,
+                api_key=api_key,
+                tenant_id=tenant_id,
+            )
+
+            if (
+                "stream" in optional_params
+                and optional_params["stream"] == True
+                and acompletion == False
+            ):
+                return _model_response
+            response = _model_response
         elif custom_llm_provider == "ai21":
             custom_llm_provider = "ai21"
             ai21_key = (

@@ -1911,7 +1966,7 @@ def completion(
             response = response
         elif custom_llm_provider == "watsonx":
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-            response = watsonxai.completion(
+            response = watsonx.IBMWatsonXAI().completion(
                 model=model,
                 messages=messages,
                 custom_prompt_dict=custom_prompt_dict,

@@ -1922,8 +1977,7 @@ def completion(
                 logger_fn=logger_fn,
                 encoding=encoding,
                 logging_obj=logging,
-                acompletion=acompletion,
-                timeout=timeout,
+                timeout=timeout,  # type: ignore
             )
             if (
                 "stream" in optional_params

@@ -2576,7 +2630,6 @@ async def aembedding(*args, **kwargs):
         or custom_llm_provider == "fireworks_ai"
         or custom_llm_provider == "ollama"
         or custom_llm_provider == "vertex_ai"
-        or custom_llm_provider == "watsonx"
     ):  # currently implemented aiohttp calls for just azure and openai, soon all.
         # Await normally
         init_response = await loop.run_in_executor(None, func_with_context)

@@ -3034,14 +3087,13 @@ def embedding(
             aembedding=aembedding,
         )
     elif custom_llm_provider == "watsonx":
-        response = watsonxai.embedding(
+        response = watsonx.IBMWatsonXAI().embedding(
             model=model,
             input=input,
             encoding=encoding,
             logging_obj=logging,
             optional_params=optional_params,
             model_response=EmbeddingResponse(),
-            aembedding=aembedding,
         )
     else:
         args = locals()

@@ -3197,6 +3249,7 @@ def text_completion(
         Union[str, List[str]]
     ] = None,  # Optional: Sequences where the API will stop generating further tokens.
     stream: Optional[bool] = None,  # Optional: Whether to stream back partial progress.
+    stream_options: Optional[dict] = None,
     suffix: Optional[
         str
     ] = None,  # Optional: The suffix that comes after a completion of inserted text.

@@ -3274,6 +3327,8 @@ def text_completion(
         optional_params["stop"] = stop
     if stream is not None:
         optional_params["stream"] = stream
+    if stream_options is not None:
+        optional_params["stream_options"] = stream_options
     if suffix is not None:
         optional_params["suffix"] = suffix
     if temperature is not None:

@@ -3384,7 +3439,9 @@ def text_completion(
     if kwargs.get("acompletion", False) == True:
         return response
     if stream == True or kwargs.get("stream", False) == True:
-        response = TextCompletionStreamWrapper(completion_stream=response, model=model)
+        response = TextCompletionStreamWrapper(
+            completion_stream=response, model=model, stream_options=stream_options
+        )
         return response
     transformed_logprobs = None
     # only supported for TGI models
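The async entrypoint gains the same `stream_options` passthrough plus an `extra_headers` parameter; a hedged sketch of the combined usage, where the header name and value are illustrative assumptions:

```python
# Sketch: acompletion with the newly added stream_options and extra_headers parameters.
# Assumes an OpenAI-style model is configured and its API key is in the environment.
import asyncio
import litellm

async def main():
    stream = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
        stream_options={"include_usage": True},
        extra_headers={"x-my-trace-id": "abc-123"},  # hypothetical custom header
    )
    async for chunk in stream:
        # The final chunk should carry the aggregate usage, per the docs change above.
        if getattr(chunk, "usage", None):
            print("usage:", chunk.usage)

asyncio.run(main())
```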
@@ -206,11 +206,9 @@ async def get_end_user_object(
 
     if end_user_id is None:
         return None
+    _key = "end_user_id:{}".format(end_user_id)
     # check if in cache
-    cached_user_obj = user_api_key_cache.async_get_cache(
-        key="end_user_id:{}".format(end_user_id)
-    )
+    cached_user_obj = await user_api_key_cache.async_get_cache(key=_key)
     if cached_user_obj is not None:
         if isinstance(cached_user_obj, dict):
             return LiteLLM_EndUserTable(**cached_user_obj)
@@ -1086,9 +1086,7 @@ async def user_api_key_auth(
                 user_id_information, list
             ):
                 _user = user_id_information[0]
-                user_role = _user.get("user_role", {}).get(
-                    "user_role", "unknown"
-                )
+                user_role = _user.get("user_role", "unknown")
                 user_id = _user.get("user_id", "unknown")
                 raise Exception(
                     f"Only proxy admin can be used to generate, delete, update info for new keys/users/teams. Route={route}. Your role={user_role}. Your user_id={user_id}"
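The old lookup chained a second `.get()` onto whatever the first `.get("user_role", ...)` returned, which only works when the stored value happens to be a dict. A minimal sketch of the difference, using a made-up user record:

```
# Hypothetical user record, for illustration only.
_user = {"user_role": "internal_user", "user_id": "u-123"}

# Old pattern: the first .get() returns the string "internal_user", so the chained
# .get() raises AttributeError ('str' object has no attribute 'get').
# user_role = _user.get("user_role", {}).get("user_role", "unknown")

# New pattern: read the role directly, with a default.
user_role = _user.get("user_role", "unknown")
print(user_role)  # -> internal_user
```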
@@ -1834,6 +1832,9 @@ async def update_cache(
         )
 
     async def _update_end_user_cache():
+        if end_user_id is None or response_cost is None:
+            return
+
         _id = "end_user_id:{}".format(end_user_id)
         try:
             # Fetch the existing cost for the given user
@@ -1846,7 +1847,7 @@ async def update_cache(
                 if litellm.max_end_user_budget is not None:
                     max_end_user_budget = litellm.max_end_user_budget
                 existing_spend_obj = LiteLLM_EndUserTable(
-                    user_id=_id,
+                    user_id=end_user_id,
                     spend=0,
                     blocked=False,
                     litellm_budget_table=LiteLLM_BudgetTable(
@@ -1874,7 +1875,7 @@ async def update_cache(
             existing_spend_obj.spend = new_spend
             user_api_key_cache.set_cache(key=_id, value=existing_spend_obj.json())
         except Exception as e:
-            verbose_proxy_logger.debug(
+            verbose_proxy_logger.error(
                 f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}"
            )
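The default budget object above is only seeded when `litellm.max_end_user_budget` is configured. A minimal sketch of setting it in code; the dollar amount is an arbitrary example, not a recommendation:

```
import litellm

# Sketch: with this set, end users tracked by the proxy get a default budget, and
# _update_end_user_cache() seeds their cache entry with it.
litellm.max_end_user_budget = 0.0001  # USD, example value
```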
@@ -7310,6 +7311,43 @@ async def unblock_team(
     return record
 
 
+@router.get(
+    "/team/list", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
+)
+async def list_team(
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    [Admin-only] List all available teams
+
+    ```
+    curl --location --request GET 'http://0.0.0.0:4000/team/list' \
+        --header 'Authorization: Bearer sk-1234'
+    ```
+    """
+    global prisma_client
+
+    if user_api_key_dict.user_role != "proxy_admin":
+        raise HTTPException(
+            status_code=401,
+            detail={
+                "error": "Admin-only endpoint. Your user role={}".format(
+                    user_api_key_dict.user_role
+                )
+            },
+        )
+
+    if prisma_client is None:
+        raise HTTPException(
+            status_code=400,
+            detail={"error": CommonProxyErrors.db_not_connected_error.value},
+        )
+
+    response = await prisma_client.db.litellm_teamtable.find_many()
+
+    return response
+
+
 #### ORGANIZATION MANAGEMENT ####
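Besides the curl call in the docstring, the same endpoint can be hit from Python. A minimal sketch using `requests`; the base URL and key are the same placeholders the docstring uses, and the call requires a proxy_admin key:

```
import requests

# Sketch: returns every row of litellm_teamtable when called with an admin key.
resp = requests.get(
    "http://0.0.0.0:4000/team/list",
    headers={"Authorization": "Bearer sk-1234"},
)
print(resp.status_code)
print(resp.json())
```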
File diff suppressed because it is too large
@@ -1,5 +1,6 @@
 import pytest
 from litellm import acompletion
+from litellm import completion
 
 
 def test_acompletion_params():
@@ -7,17 +8,29 @@ def test_acompletion_params():
     from litellm.types.completion import CompletionRequest
 
     acompletion_params_odict = inspect.signature(acompletion).parameters
-    acompletion_params = {name: param.annotation for name, param in acompletion_params_odict.items()}
-    completion_params = {field_name: field_type for field_name, field_type in CompletionRequest.__annotations__.items()}
-
-    # remove kwargs
-    acompletion_params.pop("kwargs", None)
+    completion_params_dict = inspect.signature(completion).parameters
+
+    acompletion_params = {
+        name: param.annotation for name, param in acompletion_params_odict.items()
+    }
+    completion_params = {
+        name: param.annotation for name, param in completion_params_dict.items()
+    }
 
     keys_acompletion = set(acompletion_params.keys())
     keys_completion = set(completion_params.keys())
+
+    print(keys_acompletion)
+    print("\n\n\n")
+    print(keys_completion)
+
+    print("diff=", keys_completion - keys_acompletion)
 
     # Assert that the parameters are the same
     if keys_acompletion != keys_completion:
-        pytest.fail("The parameters of the acompletion function and the CompletionRequest class are not the same.")
+        pytest.fail(
+            "The parameters of the litellm.acompletion function and litellm.completion are not the same."
+        )
 
 
 # test_acompletion_params()
@@ -85,6 +85,42 @@ def test_completion_azure_command_r():
         pytest.fail(f"Error occurred: {e}")
 
 
+# @pytest.mark.skip(reason="local test")
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_completion_predibase(sync_mode):
+    try:
+        litellm.set_verbose = True
+
+        if sync_mode:
+            response = completion(
+                model="predibase/llama-3-8b-instruct",
+                tenant_id="c4768f95",
+                api_base="https://serving.app.predibase.com",
+                api_key=os.getenv("PREDIBASE_API_KEY"),
+                messages=[{"role": "user", "content": "What is the meaning of life?"}],
+            )
+
+            print(response)
+        else:
+            response = await litellm.acompletion(
+                model="predibase/llama-3-8b-instruct",
+                tenant_id="c4768f95",
+                api_base="https://serving.app.predibase.com",
+                api_key=os.getenv("PREDIBASE_API_KEY"),
+                messages=[{"role": "user", "content": "What is the meaning of life?"}],
+            )
+
+            print(response)
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+# test_completion_predibase()
+
+
 def test_completion_claude():
     litellm.set_verbose = True
     litellm.cache = None
@@ -418,9 +418,16 @@ def test_call_with_user_over_budget(prisma_client):
         print(vars(e))
 
 
+def test_end_user_cache_write_unit_test():
+    """
+    assert end user object is being written to cache as expected
+    """
+    pass
+
+
 def test_call_with_end_user_over_budget(prisma_client):
     # Test if a user passed to /chat/completions is tracked & fails when they cross their budget
-    # we only check this when litellm.max_user_budget is set
+    # we only check this when litellm.max_end_user_budget is set
     import random
 
     setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
@@ -150,9 +150,9 @@ async def test_router_atext_completion_streaming():
         {
             "model_name": "azure-model",
             "litellm_params": {
-                "model": "azure/gpt-35-turbo",
-                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
-                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+                "model": "azure/gpt-turbo",
+                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+                "api_base": "https://openai-france-1234.openai.azure.com",
                 "rpm": 6,
             },
             "model_info": {"id": 2},
@@ -160,9 +160,9 @@ async def test_router_atext_completion_streaming():
         {
             "model_name": "azure-model",
             "litellm_params": {
-                "model": "azure/gpt-35-turbo",
-                "api_key": "os.environ/AZURE_CANADA_API_KEY",
-                "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
+                "model": "azure/gpt-turbo",
+                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+                "api_base": "https://openai-france-1234.openai.azure.com",
                 "rpm": 6,
             },
             "model_info": {"id": 3},
@@ -193,7 +193,7 @@ async def test_router_atext_completion_streaming():
     ## check if calls equally distributed
     cache_dict = router.cache.get_cache(key=cache_key)
     for k, v in cache_dict.items():
-        assert v == 1
+        assert v == 1, f"Failed. K={k} called v={v} times, cache_dict={cache_dict}"
 
 
 # asyncio.run(test_router_atext_completion_streaming())
@@ -16,7 +16,7 @@ litellm.set_verbose = True
 model_alias_map = {"good-model": "anyscale/meta-llama/Llama-2-7b-chat-hf"}
 
 
-def test_model_alias_map():
+def test_model_alias_map(caplog):
     try:
         litellm.model_alias_map = model_alias_map
         response = completion(
@@ -27,9 +27,15 @@ def test_model_alias_map():
             max_tokens=10,
         )
         print(response.model)
+
+        captured_logs = [rec.levelname for rec in caplog.records]
+
+        for log in captured_logs:
+            assert "ERROR" not in log
+
         assert "Llama-2-7b-chat-hf" in response.model
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
 
-test_model_alias_map()
+# test_model_alias_map()
@@ -5,6 +5,7 @@ import sys, os, asyncio
 import traceback
 import time, pytest
 from pydantic import BaseModel
+from typing import Tuple
 
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -142,7 +143,7 @@ def validate_last_format(chunk):
     ), "'finish_reason' should be a string."
 
 
-def streaming_format_tests(idx, chunk):
+def streaming_format_tests(idx, chunk) -> Tuple[str, bool]:
     extracted_chunk = ""
     finished = False
     print(f"chunk: {chunk}")
@@ -306,6 +307,70 @@ def test_completion_azure_stream():
 
 
 # test_completion_azure_stream()
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_completion_predibase_streaming(sync_mode):
+    try:
+        litellm.set_verbose = True
+
+        if sync_mode:
+            response = completion(
+                model="predibase/llama-3-8b-instruct",
+                tenant_id="c4768f95",
+                api_base="https://serving.app.predibase.com",
+                api_key=os.getenv("PREDIBASE_API_KEY"),
+                messages=[{"role": "user", "content": "What is the meaning of life?"}],
+                stream=True,
+            )
+
+            complete_response = ""
+            for idx, init_chunk in enumerate(response):
+                chunk, finished = streaming_format_tests(idx, init_chunk)
+                complete_response += chunk
+                custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"]
+                print(f"custom_llm_provider: {custom_llm_provider}")
+                assert custom_llm_provider == "predibase"
+                if finished:
+                    assert isinstance(
+                        init_chunk.choices[0], litellm.utils.StreamingChoices
+                    )
+                    break
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+        else:
+            response = await litellm.acompletion(
+                model="predibase/llama-3-8b-instruct",
+                tenant_id="c4768f95",
+                api_base="https://serving.app.predibase.com",
+                api_key=os.getenv("PREDIBASE_API_KEY"),
+                messages=[{"role": "user", "content": "What is the meaning of life?"}],
+                stream=True,
+            )
+
+            # await response
+
+            complete_response = ""
+            idx = 0
+            async for init_chunk in response:
+                chunk, finished = streaming_format_tests(idx, init_chunk)
+                complete_response += chunk
+                custom_llm_provider = init_chunk._hidden_params["custom_llm_provider"]
+                print(f"custom_llm_provider: {custom_llm_provider}")
+                assert custom_llm_provider == "predibase"
+                idx += 1
+                if finished:
+                    assert isinstance(
+                        init_chunk.choices[0], litellm.utils.StreamingChoices
+                    )
+                    break
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+
+        print(f"complete_response: {complete_response}")
+    except litellm.Timeout as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_azure_function_calling_stream():
@@ -1501,6 +1566,70 @@ def test_openai_chat_completion_complete_response_call():
 
 
 # test_openai_chat_completion_complete_response_call()
+def test_openai_stream_options_call():
+    litellm.set_verbose = False
+    response = litellm.completion(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "system", "content": "say GM - we're going to make it "}],
+        stream=True,
+        stream_options={"include_usage": True},
+        max_tokens=10,
+    )
+    usage = None
+    chunks = []
+    for chunk in response:
+        print("chunk: ", chunk)
+        chunks.append(chunk)
+
+    last_chunk = chunks[-1]
+    print("last chunk: ", last_chunk)
+
+    """
+    Assert that:
+    - Last Chunk includes Usage
+    - All chunks prior to last chunk have usage=None
+    """
+
+    assert last_chunk.usage is not None
+    assert last_chunk.usage.total_tokens > 0
+    assert last_chunk.usage.prompt_tokens > 0
+    assert last_chunk.usage.completion_tokens > 0
+
+    # assert all non last chunks have usage=None
+    assert all(chunk.usage is None for chunk in chunks[:-1])
+
+
+def test_openai_stream_options_call_text_completion():
+    litellm.set_verbose = False
+    response = litellm.text_completion(
+        model="gpt-3.5-turbo-instruct",
+        prompt="say GM - we're going to make it ",
+        stream=True,
+        stream_options={"include_usage": True},
+        max_tokens=10,
+    )
+    usage = None
+    chunks = []
+    for chunk in response:
+        print("chunk: ", chunk)
+        chunks.append(chunk)
+
+    last_chunk = chunks[-1]
+    print("last chunk: ", last_chunk)
+
+    """
+    Assert that:
+    - Last Chunk includes Usage
+    - All chunks prior to last chunk have usage=None
+    """
+
+    assert last_chunk.usage is not None
+    assert last_chunk.usage.total_tokens > 0
+    assert last_chunk.usage.prompt_tokens > 0
+    assert last_chunk.usage.completion_tokens > 0
+
+    # assert all non last chunks have usage=None
+    assert all(chunk.usage is None for chunk in chunks[:-1])
+
+
 def test_openai_text_completion_call():
127
litellm/utils.py
@@ -369,7 +369,7 @@ class ChatCompletionMessageToolCall(OpenAIObject):
 class Message(OpenAIObject):
     def __init__(
         self,
-        content="default",
+        content: Optional[str] = "default",
         role="assistant",
         logprobs=None,
         function_call=None,
@@ -612,6 +612,7 @@ class ModelResponse(OpenAIObject):
         system_fingerprint=None,
         usage=None,
         stream=None,
+        stream_options=None,
         response_ms=None,
         hidden_params=None,
         **params,
@@ -658,6 +659,12 @@ class ModelResponse(OpenAIObject):
             usage = usage
         elif stream is None or stream == False:
             usage = Usage()
+        elif (
+            stream == True
+            and stream_options is not None
+            and stream_options.get("include_usage") == True
+        ):
+            usage = Usage()
         if hidden_params:
             self._hidden_params = hidden_params
 
@@ -4839,6 +4846,7 @@ def get_optional_params(
     top_p=None,
     n=None,
     stream=False,
+    stream_options=None,
     stop=None,
     max_tokens=None,
     presence_penalty=None,
@@ -4908,6 +4916,7 @@ def get_optional_params(
         "top_p": None,
         "n": None,
         "stream": None,
+        "stream_options": None,
         "stop": None,
         "max_tokens": None,
         "presence_penalty": None,
@@ -5779,6 +5788,8 @@ def get_optional_params(
             optional_params["n"] = n
         if stream is not None:
             optional_params["stream"] = stream
+        if stream_options is not None:
+            optional_params["stream_options"] = stream_options
         if stop is not None:
             optional_params["stop"] = stop
         if max_tokens is not None:
@@ -5927,13 +5938,15 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
             model=model, **optional_params
         )  # convert to pydantic object
     except Exception as e:
-        verbose_logger.error("Error occurred in getting api base - {}".format(str(e)))
+        verbose_logger.debug("Error occurred in getting api base - {}".format(str(e)))
         return None
     # get llm provider
 
     if _optional_params.api_base is not None:
         return _optional_params.api_base
 
+    if litellm.model_alias_map and model in litellm.model_alias_map:
+        model = litellm.model_alias_map[model]
     try:
         model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
             get_llm_provider(
@@ -6083,6 +6096,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "top_p",
             "n",
             "stream",
+            "stream_options",
             "stop",
             "max_tokens",
             "presence_penalty",
@@ -9500,7 +9514,12 @@ def get_secret(
 # replicate/anthropic/cohere
 class CustomStreamWrapper:
     def __init__(
-        self, completion_stream, model, custom_llm_provider=None, logging_obj=None
+        self,
+        completion_stream,
+        model,
+        custom_llm_provider=None,
+        logging_obj=None,
+        stream_options=None,
     ):
         self.model = model
         self.custom_llm_provider = custom_llm_provider
@@ -9526,6 +9545,7 @@ class CustomStreamWrapper:
         self.response_id = None
         self.logging_loop = None
         self.rules = Rules()
+        self.stream_options = stream_options
 
     def __iter__(self):
         return self
@@ -9737,6 +9757,50 @@ class CustomStreamWrapper:
                 "finish_reason": finish_reason,
             }
 
+    def handle_predibase_chunk(self, chunk):
+        try:
+            if type(chunk) != str:
+                chunk = chunk.decode(
+                    "utf-8"
+                )  # DO NOT REMOVE this: This is required for HF inference API + Streaming
+            text = ""
+            is_finished = False
+            finish_reason = ""
+            print_verbose(f"chunk: {chunk}")
+            if chunk.startswith("data:"):
+                data_json = json.loads(chunk[5:])
+                print_verbose(f"data json: {data_json}")
+                if "token" in data_json and "text" in data_json["token"]:
+                    text = data_json["token"]["text"]
+                if data_json.get("details", False) and data_json["details"].get(
+                    "finish_reason", False
+                ):
+                    is_finished = True
+                    finish_reason = data_json["details"]["finish_reason"]
+                elif data_json.get(
+                    "generated_text", False
+                ):  # if full generated text exists, then stream is complete
+                    text = ""  # don't return the final bos token
+                    is_finished = True
+                    finish_reason = "stop"
+                elif data_json.get("error", False):
+                    raise Exception(data_json.get("error"))
+                return {
+                    "text": text,
+                    "is_finished": is_finished,
+                    "finish_reason": finish_reason,
+                }
+            elif "error" in chunk:
+                raise ValueError(chunk)
+            return {
+                "text": text,
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+            }
+        except Exception as e:
+            traceback.print_exc()
+            raise e
+
     def handle_huggingface_chunk(self, chunk):
         try:
             if type(chunk) != str:
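As the parser above suggests, Predibase streams text-generation-inference style server-sent events. A small sketch of the event shape `handle_predibase_chunk` expects; the field values are hypothetical, and the snippet only re-does the "data:" prefix handling for illustration:

```
import json

# Hypothetical Predibase SSE lines: a mid-stream token event, and a final event
# that carries details.finish_reason (and the full generated_text).
mid_stream = 'data: {"token": {"text": " 42"}}'
final = 'data: {"token": {"text": ""}, "details": {"finish_reason": "stop"}, "generated_text": "42"}'

for raw in (mid_stream, final):
    payload = json.loads(raw[len("data:"):])  # same "data:" prefix handling as above
    details = payload.get("details") or {}
    print(payload.get("token", {}).get("text", ""), details.get("finish_reason"))
```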
@@ -9966,6 +10030,7 @@ class CustomStreamWrapper:
             is_finished = False
             finish_reason = None
             logprobs = None
+            usage = None
             original_chunk = None  # this is used for function/tool calling
             if len(str_line.choices) > 0:
                 if (
@@ -10000,12 +10065,15 @@ class CustomStreamWrapper:
                 else:
                     logprobs = None
 
+            usage = getattr(str_line, "usage", None)
+
             return {
                 "text": text,
                 "is_finished": is_finished,
                 "finish_reason": finish_reason,
                 "logprobs": logprobs,
                 "original_chunk": str_line,
+                "usage": usage,
             }
         except Exception as e:
             traceback.print_exc()
@@ -10038,16 +10106,19 @@ class CustomStreamWrapper:
             text = ""
             is_finished = False
             finish_reason = None
+            usage = None
             choices = getattr(chunk, "choices", [])
             if len(choices) > 0:
                 text = choices[0].text
                 if choices[0].finish_reason is not None:
                     is_finished = True
                     finish_reason = choices[0].finish_reason
+            usage = getattr(chunk, "usage", None)
             return {
                 "text": text,
                 "is_finished": is_finished,
                 "finish_reason": finish_reason,
+                "usage": usage,
             }
 
         except Exception as e:
@@ -10308,7 +10379,9 @@ class CustomStreamWrapper:
             raise e
 
     def model_response_creator(self):
-        model_response = ModelResponse(stream=True, model=self.model)
+        model_response = ModelResponse(
+            stream=True, model=self.model, stream_options=self.stream_options
+        )
         if self.response_id is not None:
             model_response.id = self.response_id
         else:
@@ -10365,6 +10438,11 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
+            elif self.custom_llm_provider and self.custom_llm_provider == "predibase":
+                response_obj = self.handle_predibase_chunk(chunk)
+                completion_obj["content"] = response_obj["text"]
+                if response_obj["is_finished"]:
+                    self.received_finish_reason = response_obj["finish_reason"]
             elif (
                 self.custom_llm_provider and self.custom_llm_provider == "baseten"
             ):  # baseten doesn't provide streaming
@@ -10567,18 +10645,6 @@ class CustomStreamWrapper:
             elif self.custom_llm_provider == "watsonx":
                 response_obj = self.handle_watsonx_stream(chunk)
                 completion_obj["content"] = response_obj["text"]
-                print_verbose(f"completion obj content: {completion_obj['content']}")
-                if getattr(model_response, "usage", None) is None:
-                    model_response.usage = Usage()
-                if response_obj.get("prompt_tokens") is not None:
-                    prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
-                    model_response.usage.prompt_tokens = (prompt_token_count + response_obj["prompt_tokens"])
-                if response_obj.get("completion_tokens") is not None:
-                    model_response.usage.completion_tokens = response_obj["completion_tokens"]
-                model_response.usage.total_tokens = (
-                    getattr(model_response.usage, "prompt_tokens", 0)
-                    + getattr(model_response.usage, "completion_tokens", 0)
-                )
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "text-completion-openai":
@@ -10587,6 +10653,11 @@ class CustomStreamWrapper:
                 print_verbose(f"completion obj content: {completion_obj['content']}")
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
+                if (
+                    self.stream_options
+                    and self.stream_options.get("include_usage", False) == True
+                ):
+                    model_response.usage = response_obj["usage"]
             elif self.custom_llm_provider == "azure_text":
                 response_obj = self.handle_azure_text_completion_chunk(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -10640,6 +10711,12 @@ class CustomStreamWrapper:
                 if response_obj["logprobs"] is not None:
                     model_response.choices[0].logprobs = response_obj["logprobs"]
 
+                if (
+                    self.stream_options is not None
+                    and self.stream_options["include_usage"] == True
+                ):
+                    model_response.usage = response_obj["usage"]
+
                 model_response.model = self.model
                 print_verbose(
                     f"model_response finish reason 3: {self.received_finish_reason}; response_obj={response_obj}"
@@ -10727,6 +10804,11 @@ class CustomStreamWrapper:
                     except Exception as e:
                         model_response.choices[0].delta = Delta()
                 else:
+                    if (
+                        self.stream_options is not None
+                        and self.stream_options["include_usage"] == True
+                    ):
+                        return model_response
                     return
                 print_verbose(
                     f"model_response.choices[0].delta: {model_response.choices[0].delta}; completion_obj: {completion_obj}"
@@ -10983,7 +11065,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "sagemaker"
                 or self.custom_llm_provider == "gemini"
                 or self.custom_llm_provider == "cached_response"
-                or self.custom_llm_provider == "watsonx"
+                or self.custom_llm_provider == "predibase"
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
                 async for chunk in self.completion_stream:
@@ -11106,9 +11188,10 @@ class CustomStreamWrapper:
 
 
 class TextCompletionStreamWrapper:
-    def __init__(self, completion_stream, model):
+    def __init__(self, completion_stream, model, stream_options: Optional[dict] = None):
         self.completion_stream = completion_stream
         self.model = model
+        self.stream_options = stream_options
 
     def __iter__(self):
         return self
@@ -11132,6 +11215,14 @@ class TextCompletionStreamWrapper:
             text_choices["index"] = chunk["choices"][0]["index"]
             text_choices["finish_reason"] = chunk["choices"][0]["finish_reason"]
             response["choices"] = [text_choices]
+
+            # only pass usage when stream_options["include_usage"] is True
+            if (
+                self.stream_options
+                and self.stream_options.get("include_usage", False) == True
+            ):
+                response["usage"] = chunk.get("usage", None)
+
             return response
         except Exception as e:
             raise Exception(
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.36.4"
+version = "1.37.0"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -80,7 +80,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.commitizen]
-version = "1.36.4"
+version = "1.37.0"
 version_files = [
     "pyproject.toml:^version"
 ]