Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

Merge branch 'main' into litellm_msft_group_assignment

Commit b2b82ecd66: 84 changed files with 1302 additions and 223 deletions
```diff
@@ -610,6 +610,8 @@ jobs:
           name: Install Dependencies
           command: |
             python -m pip install --upgrade pip
+            pip install wheel
+            pip install --upgrade pip wheel setuptools
             python -m pip install -r requirements.txt
             pip install "pytest==7.3.1"
             pip install "respx==0.21.1"
```
```diff
@@ -2,6 +2,10 @@ apiVersion: v1
 kind: Service
 metadata:
   name: {{ include "litellm.fullname" . }}
+  {{- with .Values.service.annotations }}
+  annotations:
+    {{- toYaml . | nindent 4 }}
+  {{- end }}
   labels:
     {{- include "litellm.labels" . | nindent 4 }}
 spec:
```
````diff
@@ -438,6 +438,179 @@ assert isinstance(
 ```
 
+
+### Google Search Tool
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from litellm import completion
+import os
+
+os.environ["GEMINI_API_KEY"] = ".."
+
+tools = [{"googleSearch": {}}] # 👈 ADD GOOGLE SEARCH
+
+response = completion(
+    model="gemini/gemini-2.0-flash",
+    messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
+    tools=tools,
+)
+
+print(response)
+```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Setup config.yaml
+```yaml
+model_list:
+  - model_name: gemini-2.0-flash
+    litellm_params:
+      model: gemini/gemini-2.0-flash
+      api_key: os.environ/GEMINI_API_KEY
+```
+
+2. Start Proxy
+```bash
+$ litellm --config /path/to/config.yaml
+```
+
+3. Make Request!
+```bash
+curl -X POST 'http://0.0.0.0:4000/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-d '{
+  "model": "gemini-2.0-flash",
+  "messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
+  "tools": [{"googleSearch": {}}]
+}
+'
+```
+
+</TabItem>
+</Tabs>
+
+### Google Search Retrieval
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from litellm import completion
+import os
+
+os.environ["GEMINI_API_KEY"] = ".."
+
+tools = [{"googleSearchRetrieval": {}}] # 👈 ADD GOOGLE SEARCH RETRIEVAL
+
+response = completion(
+    model="gemini/gemini-2.0-flash",
+    messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
+    tools=tools,
+)
+
+print(response)
+```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Setup config.yaml
+```yaml
+model_list:
+  - model_name: gemini-2.0-flash
+    litellm_params:
+      model: gemini/gemini-2.0-flash
+      api_key: os.environ/GEMINI_API_KEY
+```
+
+2. Start Proxy
+```bash
+$ litellm --config /path/to/config.yaml
+```
+
+3. Make Request!
+```bash
+curl -X POST 'http://0.0.0.0:4000/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-d '{
+  "model": "gemini-2.0-flash",
+  "messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
+  "tools": [{"googleSearchRetrieval": {}}]
+}
+'
+```
+
+</TabItem>
+</Tabs>
+
+
+### Code Execution Tool
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from litellm import completion
+import os
+
+os.environ["GEMINI_API_KEY"] = ".."
+
+tools = [{"codeExecution": {}}] # 👈 ADD CODE EXECUTION
+
+response = completion(
+    model="gemini/gemini-2.0-flash",
+    messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
+    tools=tools,
+)
+
+print(response)
+```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Setup config.yaml
+```yaml
+model_list:
+  - model_name: gemini-2.0-flash
+    litellm_params:
+      model: gemini/gemini-2.0-flash
+      api_key: os.environ/GEMINI_API_KEY
+```
+
+2. Start Proxy
+```bash
+$ litellm --config /path/to/config.yaml
+```
+
+3. Make Request!
+```bash
+curl -X POST 'http://0.0.0.0:4000/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-d '{
+  "model": "gemini-2.0-flash",
+  "messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
+  "tools": [{"codeExecution": {}}]
+}
+'
+```
+
+</TabItem>
+</Tabs>
+
+
 ## JSON Mode
 
 <Tabs>
````
```diff
@@ -398,6 +398,8 @@ curl http://localhost:4000/v1/chat/completions \
 </TabItem>
 </Tabs>
 
+You can also use the `enterpriseWebSearch` tool for an [enterprise compliant search](https://cloud.google.com/vertex-ai/generative-ai/docs/grounding/web-grounding-enterprise).
+
 #### **Moving from Vertex AI SDK to LiteLLM (GROUNDING)**
 
```
```diff
@@ -110,5 +110,8 @@ def get_litellm_params(
         "azure_password": kwargs.get("azure_password"),
         "max_retries": max_retries,
         "timeout": kwargs.get("timeout"),
+        "bucket_name": kwargs.get("bucket_name"),
+        "vertex_credentials": kwargs.get("vertex_credentials"),
+        "vertex_project": kwargs.get("vertex_project"),
     }
     return litellm_params
```
```diff
@@ -2,7 +2,10 @@
 Common utility functions used for translating messages across providers
 """
 
-from typing import Dict, List, Literal, Optional, Union, cast
+import io
+import mimetypes
+from os import PathLike
+from typing import Dict, List, Literal, Mapping, Optional, Union, cast
 
 from litellm.types.llms.openai import (
     AllMessageValues,
@@ -10,7 +13,13 @@ from litellm.types.llms.openai import (
     ChatCompletionFileObject,
     ChatCompletionUserMessage,
 )
-from litellm.types.utils import Choices, ModelResponse, StreamingChoices
+from litellm.types.utils import (
+    Choices,
+    ExtractedFileData,
+    FileTypes,
+    ModelResponse,
+    StreamingChoices,
+)
 
 DEFAULT_USER_CONTINUE_MESSAGE = ChatCompletionUserMessage(
     content="Please continue.", role="user"
@@ -350,6 +359,68 @@ def update_messages_with_model_file_ids(
     return messages
 
 
+def extract_file_data(file_data: FileTypes) -> ExtractedFileData:
+    """
+    Extracts and processes file data from various input formats.
+
+    Args:
+        file_data: Can be a tuple of (filename, content, [content_type], [headers]) or direct file content
+
+    Returns:
+        ExtractedFileData containing:
+        - filename: Name of the file if provided
+        - content: The file content in bytes
+        - content_type: MIME type of the file
+        - headers: Any additional headers
+    """
+    # Parse the file_data based on its type
+    filename = None
+    file_content = None
+    content_type = None
+    file_headers: Mapping[str, str] = {}
+
+    if isinstance(file_data, tuple):
+        if len(file_data) == 2:
+            filename, file_content = file_data
+        elif len(file_data) == 3:
+            filename, file_content, content_type = file_data
+        elif len(file_data) == 4:
+            filename, file_content, content_type, file_headers = file_data
+    else:
+        file_content = file_data
+    # Convert content to bytes
+    if isinstance(file_content, (str, PathLike)):
+        # If it's a path, open and read the file
+        with open(file_content, "rb") as f:
+            content = f.read()
+    elif isinstance(file_content, io.IOBase):
+        # If it's a file-like object
+        content = file_content.read()
+
+        if isinstance(content, str):
+            content = content.encode("utf-8")
+        # Reset file pointer to beginning
+        file_content.seek(0)
+    elif isinstance(file_content, bytes):
+        content = file_content
+    else:
+        raise ValueError(f"Unsupported file content type: {type(file_content)}")
+
+    # Use provided content type or guess based on filename
+    if not content_type:
+        content_type = (
+            mimetypes.guess_type(filename)[0]
+            if filename
+            else "application/octet-stream"
+        )
+
+    return ExtractedFileData(
+        filename=filename,
+        content=content,
+        content_type=content_type,
+        headers=file_headers,
+    )
+
+
 def unpack_defs(schema, defs):
     properties = schema.get("properties", None)
     if properties is None:
@@ -381,3 +452,4 @@ def unpack_defs(schema, defs):
             unpack_defs(ref, defs)
             value["items"] = ref
             continue
+
```
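For reference, a minimal usage sketch of the new helper (not part of the diff; the file names and payloads are invented). It assumes `ExtractedFileData` is a TypedDict-style mapping, consistent with how the handlers later in this commit index into it:

```python
# Hypothetical usage of extract_file_data, exercising two of the four
# accepted input shapes. The import path matches the module patched above.
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    extract_file_data,
)

# (filename, content, content_type) tuple form
extracted = extract_file_data(
    ("batch.jsonl", b'{"custom_id": "1"}', "application/jsonl")
)
assert extracted["filename"] == "batch.jsonl"
assert extracted["content_type"] == "application/jsonl"

# Raw-bytes form: no filename, so the MIME type falls back to the default
extracted = extract_file_data(b"raw bytes")
assert extracted["content_type"] == "application/octet-stream"
```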
```diff
@@ -50,6 +50,7 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```

```diff
@@ -301,6 +301,7 @@ class AnthropicChatCompletion(BaseLLM):
             model=model,
             messages=messages,
             optional_params={**optional_params, "is_vertex_request": is_vertex_request},
+            litellm_params=litellm_params,
         )
 
         config = ProviderConfigManager.get_provider_chat_config(
```

```diff
@@ -868,6 +868,7 @@ class AnthropicConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> Dict:
```

```diff
@@ -87,6 +87,7 @@ class AnthropicTextConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```

```diff
@@ -293,6 +293,7 @@ class AzureOpenAIConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```

```diff
@@ -39,6 +39,7 @@ class AzureAIStudioConfig(OpenAIConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```

```diff
@@ -262,6 +262,7 @@ class BaseConfig(ABC):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```
```diff
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import TYPE_CHECKING, Any, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional, Union
 
 import httpx
 
@@ -33,23 +33,22 @@ class BaseFilesConfig(BaseConfig):
     ) -> List[OpenAICreateFileRequestOptionalParams]:
         pass
 
-    def get_complete_url(
+    def get_complete_file_url(
         self,
         api_base: Optional[str],
         api_key: Optional[str],
         model: str,
         optional_params: dict,
         litellm_params: dict,
-        stream: Optional[bool] = None,
-    ) -> str:
-        """
-        OPTIONAL
-
-        Get the complete url for the request
-
-        Some providers need `model` in `api_base`
-        """
-        return api_base or ""
+        data: CreateFileRequest,
+    ):
+        return self.get_complete_url(
+            api_base=api_base,
+            api_key=api_key,
+            model=model,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+        )
 
     @abstractmethod
     def transform_create_file_request(
@@ -58,7 +57,7 @@ class BaseFilesConfig(BaseConfig):
         create_file_data: CreateFileRequest,
         optional_params: dict,
         litellm_params: dict,
-    ) -> dict:
+    ) -> Union[dict, str, bytes]:
         pass
 
     @abstractmethod
```
```diff
@@ -65,6 +65,7 @@ class BaseImageVariationConfig(BaseConfig, ABC):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```

```diff
@@ -831,6 +831,7 @@ class AmazonConverseConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```

```diff
@@ -442,6 +442,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```

```diff
@@ -118,6 +118,7 @@ class ClarifaiConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```

```diff
@@ -60,6 +60,7 @@ class CloudflareChatConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```

```diff
@@ -118,6 +118,7 @@ class CohereChatConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```

```diff
@@ -101,6 +101,7 @@ class CohereTextConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```
```diff
@@ -229,6 +229,7 @@ class BaseLLMAIOHTTPHandler:
             model=model,
             messages=messages,
             optional_params=optional_params,
+            litellm_params=litellm_params,
             api_base=api_base,
         )
 
@@ -498,6 +499,7 @@ class BaseLLMAIOHTTPHandler:
             model=model,
             messages=[{"role": "user", "content": "test"}],
             optional_params=optional_params,
+            litellm_params=litellm_params,
             api_base=api_base,
         )
 
```
```diff
@@ -192,7 +192,7 @@ class AsyncHTTPHandler:
     async def post(
         self,
         url: str,
-        data: Optional[Union[dict, str]] = None,  # type: ignore
+        data: Optional[Union[dict, str, bytes]] = None,  # type: ignore
         json: Optional[dict] = None,
         params: Optional[dict] = None,
         headers: Optional[dict] = None,
@@ -427,7 +427,7 @@ class AsyncHTTPHandler:
         self,
         url: str,
         client: httpx.AsyncClient,
-        data: Optional[Union[dict, str]] = None,  # type: ignore
+        data: Optional[Union[dict, str, bytes]] = None,  # type: ignore
         json: Optional[dict] = None,
         params: Optional[dict] = None,
         headers: Optional[dict] = None,
@@ -527,7 +527,7 @@ class HTTPHandler:
     def post(
         self,
         url: str,
-        data: Optional[Union[dict, str]] = None,
+        data: Optional[Union[dict, str, bytes]] = None,
         json: Optional[Union[dict, str, List]] = None,
         params: Optional[dict] = None,
         headers: Optional[dict] = None,
@@ -573,7 +573,6 @@
                 setattr(e, "text", error_text)
 
                 setattr(e, "status_code", e.response.status_code)
-
                 raise e
         except Exception as e:
             raise e
```
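Widening `data` to accept `bytes` is what lets the file handlers below send raw upload bodies through these clients. A minimal sketch (not from the diff; the URL and payload are placeholders):

```python
# Hypothetical call showing the widened `data` parameter: raw bytes are now
# a valid request body for HTTPHandler.post (and the async variants above).
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()
response = client.post(
    url="https://example.com/upload",  # placeholder endpoint
    data=b"\x00\x01binary-file-body",  # bytes, per the new Union[dict, str, bytes]
    headers={"Content-Type": "application/octet-stream"},
)
print(response.status_code)
```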
```diff
@@ -247,6 +247,7 @@ class BaseLLMHTTPHandler:
             messages=messages,
             optional_params=optional_params,
             api_base=api_base,
+            litellm_params=litellm_params,
         )
 
         api_base = provider_config.get_complete_url(
@@ -625,6 +626,7 @@ class BaseLLMHTTPHandler:
             model=model,
             messages=[],
             optional_params=optional_params,
+            litellm_params=litellm_params,
         )
 
         api_base = provider_config.get_complete_url(
@@ -896,6 +898,7 @@ class BaseLLMHTTPHandler:
             model=model,
             messages=[],
             optional_params=optional_params,
+            litellm_params=litellm_params,
         )
 
         if client is None or not isinstance(client, HTTPHandler):
@@ -1228,15 +1231,19 @@
             model="",
             messages=[],
             optional_params={},
+            litellm_params=litellm_params,
         )
 
-        api_base = provider_config.get_complete_url(
+        api_base = provider_config.get_complete_file_url(
             api_base=api_base,
             api_key=api_key,
             model="",
             optional_params={},
             litellm_params=litellm_params,
+            data=create_file_data,
         )
+        if api_base is None:
+            raise ValueError("api_base is required for create_file")
 
         # Get the transformed request data for both steps
         transformed_request = provider_config.transform_create_file_request(
```
```diff
@@ -1263,48 +1270,57 @@
         else:
             sync_httpx_client = client
 
-        try:
-            # Step 1: Initial request to get upload URL
-            initial_response = sync_httpx_client.post(
-                url=api_base,
-                headers={
-                    **headers,
-                    **transformed_request["initial_request"]["headers"],
-                },
-                data=json.dumps(transformed_request["initial_request"]["data"]),
-                timeout=timeout,
-            )
-
-            # Extract upload URL from response headers
-            upload_url = initial_response.headers.get("X-Goog-Upload-URL")
-
-            if not upload_url:
-                raise ValueError("Failed to get upload URL from initial request")
-
-            # Step 2: Upload the actual file
-            upload_response = sync_httpx_client.post(
-                url=upload_url,
-                headers=transformed_request["upload_request"]["headers"],
-                data=transformed_request["upload_request"]["data"],
-                timeout=timeout,
-            )
-
-            return provider_config.transform_create_file_response(
-                model=None,
-                raw_response=upload_response,
-                logging_obj=logging_obj,
-                litellm_params=litellm_params,
-            )
-
-        except Exception as e:
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )
+        if isinstance(transformed_request, str) or isinstance(
+            transformed_request, bytes
+        ):
+            upload_response = sync_httpx_client.post(
+                url=api_base,
+                headers=headers,
+                data=transformed_request,
+                timeout=timeout,
+            )
+        else:
+            try:
+                # Step 1: Initial request to get upload URL
+                initial_response = sync_httpx_client.post(
+                    url=api_base,
+                    headers={
+                        **headers,
+                        **transformed_request["initial_request"]["headers"],
+                    },
+                    data=json.dumps(transformed_request["initial_request"]["data"]),
+                    timeout=timeout,
+                )
+
+                # Extract upload URL from response headers
+                upload_url = initial_response.headers.get("X-Goog-Upload-URL")
+
+                if not upload_url:
+                    raise ValueError("Failed to get upload URL from initial request")
+
+                # Step 2: Upload the actual file
+                upload_response = sync_httpx_client.post(
+                    url=upload_url,
+                    headers=transformed_request["upload_request"]["headers"],
+                    data=transformed_request["upload_request"]["data"],
+                    timeout=timeout,
+                )
+            except Exception as e:
+                raise self._handle_error(
+                    e=e,
+                    provider_config=provider_config,
+                )
+
+        return provider_config.transform_create_file_response(
+            model=None,
+            raw_response=upload_response,
+            logging_obj=logging_obj,
+            litellm_params=litellm_params,
+        )
 
     async def async_create_file(
         self,
-        transformed_request: dict,
+        transformed_request: Union[bytes, str, dict],
         litellm_params: dict,
         provider_config: BaseFilesConfig,
         headers: dict,
```
```diff
@@ -1323,45 +1339,54 @@
         else:
             async_httpx_client = client
 
-        try:
-            # Step 1: Initial request to get upload URL
-            initial_response = await async_httpx_client.post(
-                url=api_base,
-                headers={
-                    **headers,
-                    **transformed_request["initial_request"]["headers"],
-                },
-                data=json.dumps(transformed_request["initial_request"]["data"]),
-                timeout=timeout,
-            )
-
-            # Extract upload URL from response headers
-            upload_url = initial_response.headers.get("X-Goog-Upload-URL")
-
-            if not upload_url:
-                raise ValueError("Failed to get upload URL from initial request")
-
-            # Step 2: Upload the actual file
-            upload_response = await async_httpx_client.post(
-                url=upload_url,
-                headers=transformed_request["upload_request"]["headers"],
-                data=transformed_request["upload_request"]["data"],
-                timeout=timeout,
-            )
-
-            return provider_config.transform_create_file_response(
-                model=None,
-                raw_response=upload_response,
-                logging_obj=logging_obj,
-                litellm_params=litellm_params,
-            )
-
-        except Exception as e:
-            verbose_logger.exception(f"Error creating file: {e}")
-            raise self._handle_error(
-                e=e,
-                provider_config=provider_config,
-            )
+        if isinstance(transformed_request, str) or isinstance(
+            transformed_request, bytes
+        ):
+            upload_response = await async_httpx_client.post(
+                url=api_base,
+                headers=headers,
+                data=transformed_request,
+                timeout=timeout,
+            )
+        else:
+            try:
+                # Step 1: Initial request to get upload URL
+                initial_response = await async_httpx_client.post(
+                    url=api_base,
+                    headers={
+                        **headers,
+                        **transformed_request["initial_request"]["headers"],
+                    },
+                    data=json.dumps(transformed_request["initial_request"]["data"]),
+                    timeout=timeout,
+                )
+
+                # Extract upload URL from response headers
+                upload_url = initial_response.headers.get("X-Goog-Upload-URL")
+
+                if not upload_url:
+                    raise ValueError("Failed to get upload URL from initial request")
+
+                # Step 2: Upload the actual file
+                upload_response = await async_httpx_client.post(
+                    url=upload_url,
+                    headers=transformed_request["upload_request"]["headers"],
+                    data=transformed_request["upload_request"]["data"],
+                    timeout=timeout,
+                )
+            except Exception as e:
+                verbose_logger.exception(f"Error creating file: {e}")
+                raise self._handle_error(
+                    e=e,
+                    provider_config=provider_config,
+                )
+
+        return provider_config.transform_create_file_response(
+            model=None,
+            raw_response=upload_response,
+            logging_obj=logging_obj,
+            litellm_params=litellm_params,
+        )
 
     def list_files(self):
         """
```
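Taken together, the two hunks above split file creation into two paths. A condensed, standalone sketch of that dispatch (the function and parameter names are invented; the real methods also handle logging and client selection):

```python
# Condensed sketch of the new upload dispatch in create_file /
# async_create_file. `post` stands in for the sync or async HTTP client.
import json
from typing import Union


def upload_file(
    transformed_request: Union[bytes, str, dict],
    post,  # callable: post(url=..., headers=..., data=...) -> response
    api_base: str,
    headers: dict,
):
    if isinstance(transformed_request, (str, bytes)):
        # New path: providers that return a raw body get one direct POST.
        return post(url=api_base, headers=headers, data=transformed_request)

    # Existing path: two-step resumable upload (Google AI Studio style).
    initial = post(
        url=api_base,
        headers={**headers, **transformed_request["initial_request"]["headers"]},
        data=json.dumps(transformed_request["initial_request"]["data"]),
    )
    upload_url = initial.headers.get("X-Goog-Upload-URL")
    if not upload_url:
        raise ValueError("Failed to get upload URL from initial request")
    return post(
        url=upload_url,
        headers=transformed_request["upload_request"]["headers"],
        data=transformed_request["upload_request"]["data"],
    )
```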
```diff
@@ -116,6 +116,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```

```diff
@@ -171,6 +171,7 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```

```diff
@@ -41,6 +41,7 @@ class FireworksAIMixin:
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```

```diff
@@ -20,6 +20,7 @@ class GeminiModelInfo(BaseLLMModelInfo):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```
```diff
@@ -4,11 +4,12 @@ Supports writing files to Google AI Studio Files API.
 For vertex ai, check out the vertex_ai/files/handler.py file.
 """
 import time
-from typing import List, Mapping, Optional
+from typing import List, Optional
 
 import httpx
 
 from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
 from litellm.llms.base_llm.files.transformation import (
     BaseFilesConfig,
     LiteLLMLoggingObj,
@@ -91,66 +92,28 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):
         if file_data is None:
             raise ValueError("File data is required")
 
-        # Parse the file_data based on its type
-        filename = None
-        file_content = None
-        content_type = None
-        file_headers: Mapping[str, str] = {}
-
-        if isinstance(file_data, tuple):
-            if len(file_data) == 2:
-                filename, file_content = file_data
-            elif len(file_data) == 3:
-                filename, file_content, content_type = file_data
-            elif len(file_data) == 4:
-                filename, file_content, content_type, file_headers = file_data
-        else:
-            file_content = file_data
-
-        # Handle the file content based on its type
-        import io
-        from os import PathLike
-
-        # Convert content to bytes
-        if isinstance(file_content, (str, PathLike)):
-            # If it's a path, open and read the file
-            with open(file_content, "rb") as f:
-                content = f.read()
-        elif isinstance(file_content, io.IOBase):
-            # If it's a file-like object
-            content = file_content.read()
-            if isinstance(content, str):
-                content = content.encode("utf-8")
-        elif isinstance(file_content, bytes):
-            content = file_content
-        else:
-            raise ValueError(f"Unsupported file content type: {type(file_content)}")
+        # Use the common utility function to extract file data
+        extracted_data = extract_file_data(file_data)
 
         # Get file size
-        file_size = len(content)
-
-        # Use provided content type or guess based on filename
-        if not content_type:
-            import mimetypes
-
-            content_type = (
-                mimetypes.guess_type(filename)[0]
-                if filename
-                else "application/octet-stream"
-            )
+        file_size = len(extracted_data["content"])
 
         # Step 1: Initial resumable upload request
         headers = {
             "X-Goog-Upload-Protocol": "resumable",
             "X-Goog-Upload-Command": "start",
             "X-Goog-Upload-Header-Content-Length": str(file_size),
-            "X-Goog-Upload-Header-Content-Type": content_type,
+            "X-Goog-Upload-Header-Content-Type": extracted_data["content_type"],
             "Content-Type": "application/json",
         }
-        headers.update(file_headers)  # Add any custom headers
+        headers.update(extracted_data["headers"])  # Add any custom headers
 
         # Initial metadata request body
-        initial_data = {"file": {"display_name": filename or str(int(time.time()))}}
+        initial_data = {
+            "file": {
+                "display_name": extracted_data["filename"] or str(int(time.time()))
+            }
+        }
 
         # Step 2: Actual file upload data
         upload_headers = {
@@ -161,7 +124,10 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):
 
         return {
             "initial_request": {"headers": headers, "data": initial_data},
-            "upload_request": {"headers": upload_headers, "data": content},
+            "upload_request": {
+                "headers": upload_headers,
+                "data": extracted_data["content"],
+            },
         }
 
     def transform_create_file_response(
```
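For orientation, a sketch of the two-part structure this transformation returns (all values invented; only header names shown above are used). The handler POSTs `initial_request` to open a resumable session, then sends `upload_request` to the `X-Goog-Upload-URL` Google returns:

```python
# Hypothetical return value of transform_create_file_request for a small
# JSONL file (18 bytes), mirroring the dict built in the hunk above.
transformed_request = {
    "initial_request": {
        "headers": {
            "X-Goog-Upload-Protocol": "resumable",
            "X-Goog-Upload-Command": "start",
            "X-Goog-Upload-Header-Content-Length": "18",
            "X-Goog-Upload-Header-Content-Type": "application/jsonl",
            "Content-Type": "application/json",
        },
        "data": {"file": {"display_name": "batch.jsonl"}},
    },
    "upload_request": {
        "headers": {},  # upload-step headers are elided in this sketch
        "data": b'{"custom_id": "1"}',
    },
}
```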
```diff
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import TYPE_CHECKING, Any, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 import httpx
 
@@ -18,7 +18,6 @@ from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from ...openai.chat.gpt_transformation import OpenAIGPTConfig
 from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping
 
-
 logger = logging.getLogger(__name__)
 
 BASE_URL = "https://router.huggingface.co"
@@ -34,7 +33,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
         headers: dict,
         model: str,
         messages: List[AllMessageValues],
-        optional_params: dict,
+        optional_params: Dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
@@ -51,7 +51,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
     def get_error_class(
         self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
     ) -> BaseLLMException:
-        return HuggingFaceError(status_code=status_code, message=error_message, headers=headers)
+        return HuggingFaceError(
+            status_code=status_code, message=error_message, headers=headers
+        )
 
     def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
         """
@@ -82,7 +84,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
         if api_base is not None:
             complete_url = api_base
         elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"):
-            complete_url = str(os.getenv("HF_API_BASE")) or str(os.getenv("HUGGINGFACE_API_BASE"))
+            complete_url = str(os.getenv("HF_API_BASE")) or str(
+                os.getenv("HUGGINGFACE_API_BASE")
+            )
         elif model.startswith(("http://", "https://")):
             complete_url = model
         # 4. Default construction with provider
@@ -138,4 +142,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
         )
         mapped_model = provider_mapping["providerId"]
         messages = self._transform_messages(messages=messages, model=mapped_model)
-        return dict(ChatCompletionRequest(model=mapped_model, messages=messages, **optional_params))
+        return dict(
+            ChatCompletionRequest(
+                model=mapped_model, messages=messages, **optional_params
+            )
+        )
```
```diff
@@ -1,15 +1,6 @@
 import json
 import os
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Union,
-    get_args,
-)
+from typing import Any, Callable, Dict, List, Literal, Optional, Union, get_args
 
 import httpx
 
@@ -35,8 +26,9 @@ hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://hug
 ]
 
 
-def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
+def get_hf_task_embedding_for_model(
+    model: str, task_type: Optional[str], api_base: str
+) -> Optional[str]:
     if task_type is not None:
         if task_type in get_args(hf_tasks_embeddings):
             return task_type
@@ -57,7 +49,9 @@ def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_ba
     return pipeline_tag
 
 
-async def async_get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
+async def async_get_hf_task_embedding_for_model(
+    model: str, task_type: Optional[str], api_base: str
+) -> Optional[str]:
     if task_type is not None:
         if task_type in get_args(hf_tasks_embeddings):
             return task_type
@@ -116,7 +110,9 @@ class HuggingFaceEmbedding(BaseLLM):
         input: List,
         optional_params: dict,
     ) -> dict:
-        hf_task = await async_get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
+        hf_task = await async_get_hf_task_embedding_for_model(
+            model=model, task_type=task_type, api_base=HF_HUB_URL
+        )
 
         data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
 
@@ -173,7 +169,9 @@ class HuggingFaceEmbedding(BaseLLM):
         task_type = optional_params.pop("input_type", None)
 
         if call_type == "sync":
-            hf_task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
+            hf_task = get_hf_task_embedding_for_model(
+                model=model, task_type=task_type, api_base=HF_HUB_URL
+            )
         elif call_type == "async":
             return self._async_transform_input(
                 model=model, task_type=task_type, embed_url=embed_url, input=input
@@ -325,6 +323,7 @@ class HuggingFaceEmbedding(BaseLLM):
         input: list,
         model_response: EmbeddingResponse,
         optional_params: dict,
+        litellm_params: dict,
         logging_obj: LiteLLMLoggingObj,
         encoding: Callable,
         api_key: Optional[str] = None,
@@ -341,9 +340,12 @@ class HuggingFaceEmbedding(BaseLLM):
             model=model,
             optional_params=optional_params,
             messages=[],
+            litellm_params=litellm_params,
         )
         task_type = optional_params.pop("input_type", None)
-        task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
+        task = get_hf_task_embedding_for_model(
+            model=model, task_type=task_type, api_base=HF_HUB_URL
+        )
         # print_verbose(f"{model}, {task}")
         embed_url = ""
         if "https" in model:
@@ -355,7 +357,9 @@ class HuggingFaceEmbedding(BaseLLM):
         elif "HUGGINGFACE_API_BASE" in os.environ:
             embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
         else:
-            embed_url = f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
+            embed_url = (
+                f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
+            )
 
         ## ROUTING ##
         if aembedding is True:
```
```diff
@@ -355,6 +355,7 @@ class HuggingFaceEmbeddingConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: Dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> Dict:
```

```diff
@@ -36,6 +36,7 @@ def completion(
         model=model,
         messages=messages,
         optional_params=optional_params,
+        litellm_params=litellm_params,
     )
 
     ## Load Config
```

```diff
@@ -93,6 +93,7 @@ class NLPCloudConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```

```diff
@@ -353,6 +353,7 @@ class OllamaConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```
```diff
@@ -32,6 +32,7 @@ def completion(
         model=model,
         messages=messages,
         optional_params=optional_params,
+        litellm_params=litellm_params,
     )
     if "https" in model:
         completion_url = model
@@ -123,6 +124,7 @@ def embedding(
         model=model,
         messages=[],
         optional_params=optional_params,
+        litellm_params={},
     )
     response = litellm.module_level_client.post(
         embeddings_url, headers=headers, json=data
```

```diff
@@ -88,6 +88,7 @@ class OobaboogaConfig(OpenAIGPTConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```
```diff
@@ -321,6 +321,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```

```diff
@@ -286,6 +286,7 @@ class OpenAIConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```

```diff
@@ -53,6 +53,7 @@ class OpenAIWhisperAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```

```diff
@@ -131,6 +131,7 @@ class PetalsConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```
```diff
@@ -228,10 +228,10 @@ class PredibaseChatCompletion:
         api_key: str,
         logging_obj,
         optional_params: dict,
+        litellm_params: dict,
         tenant_id: str,
         timeout: Union[float, httpx.Timeout],
         acompletion=None,
-        litellm_params=None,
         logger_fn=None,
         headers: dict = {},
     ) -> Union[ModelResponse, CustomStreamWrapper]:
@@ -241,6 +241,7 @@ class PredibaseChatCompletion:
             messages=messages,
             optional_params=optional_params,
             model=model,
+            litellm_params=litellm_params,
         )
         completion_url = ""
         input_text = ""
```

```diff
@@ -164,6 +164,7 @@ class PredibaseConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```
```diff
@@ -141,6 +141,7 @@ def completion(
         model=model,
         messages=messages,
         optional_params=optional_params,
+        litellm_params=litellm_params,
     )
     # Start a prediction and get the prediction URL
     version_id = replicate_config.model_to_version_id(model)
```

```diff
@@ -312,6 +312,7 @@ class ReplicateConfig(BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
```
```diff
@@ -96,6 +96,7 @@ class SagemakerLLM(BaseAWSLLM):
         model: str,
         data: dict,
         messages: List[AllMessageValues],
+        litellm_params: dict,
         optional_params: dict,
         aws_region_name: str,
         extra_headers: Optional[dict] = None,
@@ -122,6 +123,7 @@ class SagemakerLLM(BaseAWSLLM):
             model=model,
             messages=messages,
             optional_params=optional_params,
+            litellm_params=litellm_params,
         )
         request = AWSRequest(
             method="POST", url=api_base, data=encoded_data, headers=headers
@@ -198,6 +200,7 @@ class SagemakerLLM(BaseAWSLLM):
             data=data,
             messages=messages,
             optional_params=optional_params,
+            litellm_params=litellm_params,
             credentials=credentials,
             aws_region_name=aws_region_name,
         )
@@ -274,6 +277,7 @@ class SagemakerLLM(BaseAWSLLM):
             "model": model,
             "data": _data,
             "optional_params": optional_params,
+            "litellm_params": litellm_params,
             "credentials": credentials,
             "aws_region_name": aws_region_name,
             "messages": messages,
@@ -426,6 +430,7 @@ class SagemakerLLM(BaseAWSLLM):
             "model": model,
             "data": data,
             "optional_params": optional_params,
+            "litellm_params": litellm_params,
             "credentials": credentials,
             "aws_region_name": aws_region_name,
             "messages": messages,
@@ -496,6 +501,7 @@ class SagemakerLLM(BaseAWSLLM):
             "model": model,
             "data": data,
             "optional_params": optional_params,
+            "litellm_params": litellm_params,
             "credentials": credentials,
             "aws_region_name": aws_region_name,
             "messages": messages,
```
@ -263,6 +263,7 @@ class SagemakerConfig(BaseConfig):
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
|
@ -92,6 +92,7 @@ class SnowflakeConfig(OpenAIGPTConfig):
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
|
@ -37,6 +37,7 @@ class TopazImageVariationConfig(BaseImageVariationConfig):
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
|
@ -48,6 +48,7 @@ class TritonConfig(BaseConfig):
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: Dict,
|
optional_params: Dict,
|
||||||
|
litellm_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
|
|
|
@ -42,6 +42,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig):
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[AllMessageValues],
|
messages: List[AllMessageValues],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
litellm_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
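The change repeated across these provider configs is mechanical: `validate_environment` gains a required `litellm_params: dict` argument, and every call site threads it through. A minimal sketch of the new call shape, mirroring the updated HuggingFace unit test near the end of this diff (the `TritonConfig` import path and all argument values are illustrative assumptions, not taken from these hunks):

```python
# Sketch only: any provider config touched by this change follows the same shape.
# The import path below is assumed for illustration.
from litellm.llms.triton.completion.transformation import TritonConfig

headers = TritonConfig().validate_environment(
    headers={},
    model="triton/llama-3-8b",  # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
    optional_params={},
    litellm_params={},  # the newly required parameter
    api_key="test_api_key",
)
print(headers)
```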
@@ -1,3 +1,4 @@
+import asyncio
 from typing import Any, Coroutine, Optional, Union
 
 import httpx
@@ -11,9 +12,9 @@ from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
 from litellm.types.llms.openai import CreateFileRequest, OpenAIFileObject
 from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
 
-from .transformation import VertexAIFilesTransformation
+from .transformation import VertexAIJsonlFilesTransformation
 
-vertex_ai_files_transformation = VertexAIFilesTransformation()
+vertex_ai_files_transformation = VertexAIJsonlFilesTransformation()
 
 
 class VertexAIFilesHandler(GCSBucketBase):
@@ -92,5 +93,15 @@ class VertexAIFilesHandler(GCSBucketBase):
                 timeout=timeout,
                 max_retries=max_retries,
             )
-        return None  # type: ignore
+        else:
+            return asyncio.run(
+                self.async_create_file(
+                    create_file_data=create_file_data,
+                    api_base=api_base,
+                    vertex_credentials=vertex_credentials,
+                    vertex_project=vertex_project,
+                    vertex_location=vertex_location,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                )
+            )

@@ -1,7 +1,17 @@
 import json
+import os
+import time
 import uuid
 from typing import Any, Dict, List, Optional, Tuple, Union
 
+from httpx import Headers, Response
+
+from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.base_llm.files.transformation import (
+    BaseFilesConfig,
+    LiteLLMLoggingObj,
+)
 from litellm.llms.vertex_ai.common_utils import (
     _convert_vertex_datetime_to_openai_datetime,
 )
@@ -10,14 +20,317 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
     VertexGeminiConfig,
 )
 from litellm.types.llms.openai import (
+    AllMessageValues,
     CreateFileRequest,
     FileTypes,
+    OpenAICreateFileRequestOptionalParams,
     OpenAIFileObject,
     PathLike,
 )
+from litellm.types.llms.vertex_ai import GcsBucketResponse
+from litellm.types.utils import ExtractedFileData, LlmProviders
+
+from ..common_utils import VertexAIError
+from ..vertex_llm_base import VertexBase
+
+
-class VertexAIFilesTransformation(VertexGeminiConfig):
+class VertexAIFilesConfig(VertexBase, BaseFilesConfig):
+    """
+    Config for VertexAI Files
+    """
+
+    def __init__(self):
+        self.jsonl_transformation = VertexAIJsonlFilesTransformation()
+        super().__init__()
+
+    @property
+    def custom_llm_provider(self) -> LlmProviders:
+        return LlmProviders.VERTEX_AI
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        if not api_key:
+            api_key, _ = self.get_access_token(
+                credentials=litellm_params.get("vertex_credentials"),
+                project_id=litellm_params.get("vertex_project"),
+            )
+            if not api_key:
+                raise ValueError("api_key is required")
+        headers["Authorization"] = f"Bearer {api_key}"
+        return headers
+
+    def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
+        """
+        Helper to extract content from various OpenAI file types and return as string.
+
+        Handles:
+        - Direct content (str, bytes, IO[bytes])
+        - Tuple formats: (filename, content, [content_type], [headers])
+        - PathLike objects
+        """
+        content: Union[str, bytes] = b""
+        # Extract file content from tuple if necessary
+        if isinstance(openai_file_content, tuple):
+            # Take the second element which is always the file content
+            file_content = openai_file_content[1]
+        else:
+            file_content = openai_file_content
+
+        # Handle different file content types
+        if isinstance(file_content, str):
+            # String content can be used directly
+            content = file_content
+        elif isinstance(file_content, bytes):
+            # Bytes content can be decoded
+            content = file_content
+        elif isinstance(file_content, PathLike):  # PathLike
+            with open(str(file_content), "rb") as f:
+                content = f.read()
+        elif hasattr(file_content, "read"):  # IO[bytes]
+            # File-like objects need to be read
+            content = file_content.read()
+
+        # Ensure content is string
+        if isinstance(content, bytes):
+            content = content.decode("utf-8")
+
+        return content
+
+    def _get_gcs_object_name_from_batch_jsonl(
+        self,
+        openai_jsonl_content: List[Dict[str, Any]],
+    ) -> str:
+        """
+        Gets a unique GCS object name for the VertexAI batch prediction job
+
+        named as: litellm-vertex-{model}-{uuid}
+        """
+        _model = openai_jsonl_content[0].get("body", {}).get("model", "")
+        if "publishers/google/models" not in _model:
+            _model = f"publishers/google/models/{_model}"
+        object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}"
+        return object_name
+
+    def get_object_name(
+        self, extracted_file_data: ExtractedFileData, purpose: str
+    ) -> str:
+        """
+        Get the object name for the request
+        """
+        extracted_file_data_content = extracted_file_data.get("content")
+
+        if extracted_file_data_content is None:
+            raise ValueError("file content is required")
+
+        if purpose == "batch":
+            ## 1. If jsonl, check if there's a model name
+            file_content = self._get_content_from_openai_file(
+                extracted_file_data_content
+            )
+
+            # Split into lines and parse each line as JSON
+            openai_jsonl_content = [
+                json.loads(line) for line in file_content.splitlines() if line.strip()
+            ]
+            if len(openai_jsonl_content) > 0:
+                return self._get_gcs_object_name_from_batch_jsonl(openai_jsonl_content)
+
+        ## 2. If not jsonl, return the filename
+        filename = extracted_file_data.get("filename")
+        if filename:
+            return filename
+        ## 3. If no file name, return timestamp
+        return str(int(time.time()))
+
+    def get_complete_file_url(
+        self,
+        api_base: Optional[str],
+        api_key: Optional[str],
+        model: str,
+        optional_params: Dict,
+        litellm_params: Dict,
+        data: CreateFileRequest,
+    ) -> str:
+        """
+        Get the complete url for the request
+        """
+        bucket_name = litellm_params.get("bucket_name") or os.getenv("GCS_BUCKET_NAME")
+        if not bucket_name:
+            raise ValueError("GCS bucket_name is required")
+        file_data = data.get("file")
+        purpose = data.get("purpose")
+        if file_data is None:
+            raise ValueError("file is required")
+        if purpose is None:
+            raise ValueError("purpose is required")
+        extracted_file_data = extract_file_data(file_data)
+        object_name = self.get_object_name(extracted_file_data, purpose)
+        endpoint = (
+            f"upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
+        )
+        api_base = api_base or "https://storage.googleapis.com"
+        if not api_base:
+            raise ValueError("api_base is required")
+
+        return f"{api_base}/{endpoint}"
+
+    def get_supported_openai_params(
+        self, model: str
+    ) -> List[OpenAICreateFileRequestOptionalParams]:
+        return []
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        return optional_params
+
+    def _map_openai_to_vertex_params(
+        self,
+        openai_request_body: Dict[str, Any],
+    ) -> Dict[str, Any]:
+        """
+        wrapper to call VertexGeminiConfig.map_openai_params
+        """
+        from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
+            VertexGeminiConfig,
+        )
+
+        config = VertexGeminiConfig()
+        _model = openai_request_body.get("model", "")
+        vertex_params = config.map_openai_params(
+            model=_model,
+            non_default_params=openai_request_body,
+            optional_params={},
+            drop_params=False,
+        )
+        return vertex_params
+
+    def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
+        self, openai_jsonl_content: List[Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        """
+        Transforms OpenAI JSONL content to VertexAI JSONL content
+
+        jsonl body for vertex is {"request": <request_body>}
+        Example Vertex jsonl
+        {"request":{"contents": [{"role": "user", "parts": [{"text": "What is the relation between the following video and image samples?"}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/animals.mp4", "mimeType": "video/mp4"}}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/image/cricket.jpeg", "mimeType": "image/jpeg"}}]}]}}
+        {"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}}
+        """
+
+        vertex_jsonl_content = []
+        for _openai_jsonl_content in openai_jsonl_content:
+            openai_request_body = _openai_jsonl_content.get("body") or {}
+            vertex_request_body = _transform_request_body(
+                messages=openai_request_body.get("messages", []),
+                model=openai_request_body.get("model", ""),
+                optional_params=self._map_openai_to_vertex_params(openai_request_body),
+                custom_llm_provider="vertex_ai",
+                litellm_params={},
+                cached_content=None,
+            )
+            vertex_jsonl_content.append({"request": vertex_request_body})
+        return vertex_jsonl_content
+
+    def transform_create_file_request(
+        self,
+        model: str,
+        create_file_data: CreateFileRequest,
+        optional_params: dict,
+        litellm_params: dict,
+    ) -> Union[bytes, str, dict]:
+        """
+        2 Cases:
+        1. Handle basic file upload
+        2. Handle batch file upload (.jsonl)
+        """
+        file_data = create_file_data.get("file")
+        if file_data is None:
+            raise ValueError("file is required")
+        extracted_file_data = extract_file_data(file_data)
+        extracted_file_data_content = extracted_file_data.get("content")
+        if (
+            create_file_data.get("purpose") == "batch"
+            and extracted_file_data.get("content_type") == "application/jsonl"
+            and extracted_file_data_content is not None
+        ):
+            ## 1. If jsonl, check if there's a model name
+            file_content = self._get_content_from_openai_file(
+                extracted_file_data_content
+            )
+
+            # Split into lines and parse each line as JSON
+            openai_jsonl_content = [
+                json.loads(line) for line in file_content.splitlines() if line.strip()
+            ]
+            vertex_jsonl_content = (
+                self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
+                    openai_jsonl_content
+                )
+            )
+            return json.dumps(vertex_jsonl_content)
+        elif isinstance(extracted_file_data_content, bytes):
+            return extracted_file_data_content
+        else:
+            raise ValueError("Unsupported file content type")
+
+    def transform_create_file_response(
+        self,
+        model: Optional[str],
+        raw_response: Response,
+        logging_obj: LiteLLMLoggingObj,
+        litellm_params: dict,
+    ) -> OpenAIFileObject:
+        """
+        Transform VertexAI File upload response into OpenAI-style FileObject
+        """
+        response_json = raw_response.json()
+
+        try:
+            response_object = GcsBucketResponse(**response_json)  # type: ignore
+        except Exception as e:
+            raise VertexAIError(
+                status_code=raw_response.status_code,
+                message=f"Error reading GCS bucket response: {e}",
+                headers=raw_response.headers,
+            )
+
+        gcs_id = response_object.get("id", "")
+        # Remove the last numeric ID from the path
+        gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else ""
+
+        return OpenAIFileObject(
+            purpose=response_object.get("purpose", "batch"),
+            id=f"gs://{gcs_id}",
+            filename=response_object.get("name", ""),
+            created_at=_convert_vertex_datetime_to_openai_datetime(
+                vertex_datetime=response_object.get("timeCreated", "")
+            ),
+            status="uploaded",
+            bytes=int(response_object.get("size", 0)),
+            object="file",
+        )
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[Dict, Headers]
+    ) -> BaseLLMException:
+        return VertexAIError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+
+class VertexAIJsonlFilesTransformation(VertexGeminiConfig):
     """
     Transforms OpenAI /v1/files/* requests to VertexAI /v1/files/* requests
     """
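Together with the `ProviderConfigManager` hunk further down, this `VertexAIFilesConfig` lets OpenAI-style file uploads target Vertex AI by writing into a GCS bucket. A minimal sketch modeled on the e2e tests later in this diff; the bucket name is a placeholder, and working Vertex credentials are assumed:

```python
import os
import litellm

os.environ["GCS_BUCKET_NAME"] = "my-litellm-bucket"  # placeholder bucket

# Batch .jsonl uploads are rewritten line-by-line into Vertex's
# {"request": ...} JSONL format before landing in the bucket.
file_obj = litellm.create_file(
    file=open("example.jsonl", "rb"),
    purpose="batch",
    custom_llm_provider="vertex_ai",
)
print(file_obj.id)  # a gs:// URI pointing at the uploaded object
```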
@@ -240,6 +240,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
         gtool_func_declarations = []
         googleSearch: Optional[dict] = None
         googleSearchRetrieval: Optional[dict] = None
+        enterpriseWebSearch: Optional[dict] = None
         code_execution: Optional[dict] = None
         # remove 'additionalProperties' from tools
         value = _remove_additional_properties(value)
@@ -273,6 +274,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
                 googleSearch = tool["googleSearch"]
             elif tool.get("googleSearchRetrieval", None) is not None:
                 googleSearchRetrieval = tool["googleSearchRetrieval"]
+            elif tool.get("enterpriseWebSearch", None) is not None:
+                enterpriseWebSearch = tool["enterpriseWebSearch"]
             elif tool.get("code_execution", None) is not None:
                 code_execution = tool["code_execution"]
             elif openai_function_object is not None:
@@ -299,6 +302,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
             _tools["googleSearch"] = googleSearch
         if googleSearchRetrieval is not None:
             _tools["googleSearchRetrieval"] = googleSearchRetrieval
+        if enterpriseWebSearch is not None:
+            _tools["enterpriseWebSearch"] = enterpriseWebSearch
         if code_execution is not None:
             _tools["code_execution"] = code_execution
         return [_tools]
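`enterpriseWebSearch` now rides through tool translation exactly like `googleSearch` and `googleSearchRetrieval`. A sketch of passing it from the SDK; enterprise web search is a Vertex AI feature, so the `vertex_ai/` model prefix here is an assumption, and the parametrized test at the end of this diff only exercises the tool mapping itself:

```python
from litellm import completion

# Sketch: pass the new tool key just like the existing search tools.
response = completion(
    model="vertex_ai/gemini-2.0-flash",  # assumed provider/model for illustration
    messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
    tools=[{"enterpriseWebSearch": {}}],  # new tool key added in this change
)
print(response)
```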
@@ -900,6 +905,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: Dict,
+        litellm_params: Dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> Dict:
@@ -1017,7 +1023,7 @@ class VertexLLM(VertexBase):
         logging_obj,
         stream,
         optional_params: dict,
-        litellm_params=None,
+        litellm_params: dict,
         logger_fn=None,
         api_base: Optional[str] = None,
         client: Optional[AsyncHTTPHandler] = None,
@@ -1058,6 +1064,7 @@ class VertexLLM(VertexBase):
             model=model,
             messages=messages,
             optional_params=optional_params,
+            litellm_params=litellm_params,
         )
 
         ## LOGGING
@@ -1144,6 +1151,7 @@ class VertexLLM(VertexBase):
             model=model,
             messages=messages,
             optional_params=optional_params,
+            litellm_params=litellm_params,
         )
 
         request_body = await async_transform_request_body(**data)  # type: ignore
@@ -1317,6 +1325,7 @@ class VertexLLM(VertexBase):
             model=model,
             messages=messages,
             optional_params=optional_params,
+            litellm_params=litellm_params,
         )
 
         ## TRANSFORMATION ##

@@ -94,6 +94,7 @@ class VertexMultimodalEmbedding(VertexLLM):
             optional_params=optional_params,
             api_key=auth_header,
             api_base=api_base,
+            litellm_params=litellm_params,
         )
 
         ## LOGGING

@@ -47,6 +47,7 @@ class VertexAIMultimodalEmbeddingConfig(BaseEmbeddingConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
 
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.asyncify import asyncify
-from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
 
@@ -22,7 +21,7 @@ else:
     GoogleCredentialsObject = Any
 
 
-class VertexBase(BaseLLM):
+class VertexBase:
     def __init__(self) -> None:
         super().__init__()
         self.access_token: Optional[str] = None

@@ -83,6 +83,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:

@@ -49,6 +49,7 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
             messages=messages,
             optional_params=optional_params,
             api_key=api_key,
+            litellm_params=litellm_params,
         )
 
         ## UPDATE PAYLOAD (optional params)

@@ -165,6 +165,7 @@ class IBMWatsonXMixin:
         model: str,
         messages: List[AllMessageValues],
         optional_params: Dict,
+        litellm_params: dict,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> Dict:

@@ -3616,6 +3616,7 @@ def embedding(  # noqa: PLR0915
             optional_params=optional_params,
             client=client,
             aembedding=aembedding,
+            litellm_params=litellm_params_dict,
         )
     elif custom_llm_provider == "bedrock":
         if isinstance(input, str):
@@ -2409,25 +2409,26 @@
         "max_tokens": 4096,
         "max_input_tokens": 131072,
         "max_output_tokens": 4096,
-        "input_cost_per_token": 0,
-        "output_cost_per_token": 0,
+        "input_cost_per_token": 0.000000075,
+        "output_cost_per_token": 0.0000003,
         "litellm_provider": "azure_ai",
         "mode": "chat",
        "supports_function_calling": true,
-        "source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft"
+        "source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
     },
     "azure_ai/Phi-4-multimodal-instruct": {
         "max_tokens": 4096,
         "max_input_tokens": 131072,
         "max_output_tokens": 4096,
-        "input_cost_per_token": 0,
-        "output_cost_per_token": 0,
+        "input_cost_per_token": 0.00000008,
+        "input_cost_per_audio_token": 0.000004,
+        "output_cost_per_token": 0.00032,
         "litellm_provider": "azure_ai",
         "mode": "chat",
         "supports_audio_input": true,
         "supports_function_calling": true,
         "supports_vision": true,
-        "source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft"
+        "source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
     },
     "azure_ai/Phi-4": {
         "max_tokens": 16384,
@@ -3467,7 +3468,7 @@
         "input_cost_per_token": 0.0000008,
         "output_cost_per_token": 0.000004,
         "cache_creation_input_token_cost": 0.000001,
-        "cache_read_input_token_cost": 0.0000008,
+        "cache_read_input_token_cost": 0.00000008,
         "litellm_provider": "anthropic",
         "mode": "chat",
         "supports_function_calling": true,
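Two pricing fixes land here: the azure_ai Phi-4 entries move from placeholder zeros to real rates, and the Anthropic entry's `cache_read_input_token_cost` gains a missing zero (0.0000008 to 0.00000008), bringing cache reads to a tenth of the $0.80-per-1M input rate rather than equal to it. The same pair of hunks appears again further down, since litellm ships a second, backup copy of the pricing map that is kept in lockstep. Converting the per-token rates to per-million-token prices as a sanity check (pure arithmetic, not part of the diff; the first Phi entry's name sits just above the hunk window):

```python
# Per-1M-token prices implied by the updated entries (sanity-check arithmetic).
phi_input = 0.000000075            # first azure_ai Phi entry above, input
phi_output = 0.0000003             # same entry, output
anthropic_cache_read = 0.00000008  # corrected cache-read rate

print(phi_input * 1_000_000)             # 0.075 -> $0.075 per 1M input tokens
print(phi_output * 1_000_000)            # 0.3   -> $0.30 per 1M output tokens
print(anthropic_cache_read * 1_000_000)  # 0.08  -> $0.08 per 1M cached-read tokens
```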
@@ -1625,6 +1625,7 @@ class LiteLLM_UserTable(LiteLLMPydanticObjectBase):
     model_max_budget: Optional[Dict] = {}
     model_spend: Optional[Dict] = {}
     user_email: Optional[str] = None
+    user_alias: Optional[str] = None
     models: list = []
     tpm_limit: Optional[int] = None
     rpm_limit: Optional[int] = None

@@ -4,16 +4,26 @@ import json
 import uuid
 from base64 import b64encode
 from datetime import datetime
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 from urllib.parse import parse_qs, urlencode, urlparse
 
 import httpx
-from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
+from fastapi import (
+    APIRouter,
+    Depends,
+    HTTPException,
+    Request,
+    Response,
+    UploadFile,
+    status,
+)
 from fastapi.responses import StreamingResponse
+from starlette.datastructures import UploadFile as StarletteUploadFile
 
 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
 from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
 from litellm.proxy._types import (
     ConfigFieldInfo,
@@ -358,6 +368,92 @@ class HttpPassThroughEndpointHelpers:
         )
         return response
 
+    @staticmethod
+    async def non_streaming_http_request_handler(
+        request: Request,
+        async_client: httpx.AsyncClient,
+        url: httpx.URL,
+        headers: dict,
+        requested_query_params: Optional[dict] = None,
+        _parsed_body: Optional[dict] = None,
+    ) -> httpx.Response:
+        """
+        Handle non-streaming HTTP requests
+
+        Handles special cases when GET requests, multipart/form-data requests, and generic httpx requests
+        """
+        if request.method == "GET":
+            response = await async_client.request(
+                method=request.method,
+                url=url,
+                headers=headers,
+                params=requested_query_params,
+            )
+        elif HttpPassThroughEndpointHelpers.is_multipart(request) is True:
+            return await HttpPassThroughEndpointHelpers.make_multipart_http_request(
+                request=request,
+                async_client=async_client,
+                url=url,
+                headers=headers,
+                requested_query_params=requested_query_params,
+            )
+        else:
+            # Generic httpx method
+            response = await async_client.request(
+                method=request.method,
+                url=url,
+                headers=headers,
+                params=requested_query_params,
+                json=_parsed_body,
+            )
+        return response
+
+    @staticmethod
+    def is_multipart(request: Request) -> bool:
+        """Check if the request is a multipart/form-data request"""
+        return "multipart/form-data" in request.headers.get("content-type", "")
+
+    @staticmethod
+    async def _build_request_files_from_upload_file(
+        upload_file: Union[UploadFile, StarletteUploadFile],
+    ) -> Tuple[Optional[str], bytes, Optional[str]]:
+        """Build a request files dict from an UploadFile object"""
+        file_content = await upload_file.read()
+        return (upload_file.filename, file_content, upload_file.content_type)
+
+    @staticmethod
+    async def make_multipart_http_request(
+        request: Request,
+        async_client: httpx.AsyncClient,
+        url: httpx.URL,
+        headers: dict,
+        requested_query_params: Optional[dict] = None,
+    ) -> httpx.Response:
+        """Process multipart/form-data requests, handling both files and form fields"""
+        form_data = await request.form()
+        files = {}
+        form_data_dict = {}
+
+        for field_name, field_value in form_data.items():
+            if isinstance(field_value, (StarletteUploadFile, UploadFile)):
+                files[field_name] = (
+                    await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
+                        upload_file=field_value
+                    )
+                )
+            else:
+                form_data_dict[field_name] = field_value
+
+        response = await async_client.request(
+            method=request.method,
+            url=url,
+            headers=headers,
+            params=requested_query_params,
+            files=files,
+            data=form_data_dict,
+        )
+        return response
+
+
 async def pass_through_request(  # noqa: PLR0915
     request: Request,
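What the multipart branch ultimately sends upstream: file fields become httpx-style `(filename, content, content_type)` tuples under `files=`, while plain form fields travel under `data=`. A standalone sketch of that outbound shape with placeholder values (the URL and field names are not from the diff):

```python
import httpx

# Placeholder values illustrating the outbound request shape that
# make_multipart_http_request builds from an incoming form.
files = {"file": ("test.txt", b"test file content", "text/plain")}
data = {"text_field": "test value"}

response = httpx.post(
    "http://localhost:8080/upload",  # placeholder upstream URL
    files=files,
    data=data,
)
print(response.status_code)
```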
@@ -424,7 +520,7 @@ async def pass_through_request(  # noqa: PLR0915
     start_time = datetime.now()
     logging_obj = Logging(
         model="unknown",
-        messages=[{"role": "user", "content": json.dumps(_parsed_body)}],
+        messages=[{"role": "user", "content": safe_dumps(_parsed_body)}],
         stream=False,
         call_type="pass_through_endpoint",
         start_time=start_time,
@@ -453,7 +549,6 @@ async def pass_through_request(  # noqa: PLR0915
     logging_obj.model_call_details["litellm_call_id"] = litellm_call_id
-
     # combine url with query params for logging
 
     requested_query_params: Optional[dict] = (
         query_params or request.query_params.__dict__
     )
@@ -474,7 +569,7 @@ async def pass_through_request(  # noqa: PLR0915
         logging_url = str(url) + "?" + requested_query_params_str
 
     logging_obj.pre_call(
-        input=[{"role": "user", "content": json.dumps(_parsed_body)}],
+        input=[{"role": "user", "content": safe_dumps(_parsed_body)}],
         api_key="",
         additional_args={
             "complete_input_dict": _parsed_body,
@@ -525,22 +620,16 @@ async def pass_through_request(  # noqa: PLR0915
         )
         verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
 
-        if request.method == "GET":
-            response = await async_client.request(
-                method=request.method,
-                url=url,
-                headers=headers,
-                params=requested_query_params,
-            )
-        else:
-            response = await async_client.request(
-                method=request.method,
-                url=url,
-                headers=headers,
-                params=requested_query_params,
-                json=_parsed_body,
-            )
+        response = (
+            await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler(
+                request=request,
+                async_client=async_client,
+                url=url,
+                headers=headers,
+                requested_query_params=requested_query_params,
+                _parsed_body=_parsed_body,
+            )
+        )
 
         verbose_proxy_logger.debug("response.headers= %s", response.headers)
 
         if _is_streaming_response(response) is True:
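`safe_dumps` replaces `json.dumps` for logging the pass-through body, which guards against payloads that plain `json.dumps` cannot serialize. The diff does not show its implementation; a plausible stand-in with the same call shape, for orientation only:

```python
import json
from typing import Any


def safe_dumps_sketch(obj: Any) -> str:
    """Hypothetical stand-in for litellm's safe_dumps: never raises,
    falling back to str() for non-JSON-serializable values."""
    try:
        return json.dumps(obj, default=str)
    except (TypeError, ValueError):
        # e.g. circular references
        return str(obj)
```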
@@ -187,6 +187,7 @@ class Tools(TypedDict, total=False):
     function_declarations: List[FunctionDeclaration]
     googleSearch: dict
     googleSearchRetrieval: dict
+    enterpriseWebSearch: dict
     code_execution: dict
     retrieval: Retrieval

@@ -497,6 +498,51 @@ class OutputConfig(TypedDict, total=False):
     gcsDestination: GcsDestination
 
 
+class GcsBucketResponse(TypedDict):
+    """
+    TypedDict for GCS bucket upload response
+
+    Attributes:
+        kind: The kind of item this is. For objects, this is always storage#object
+        id: The ID of the object
+        selfLink: The link to this object
+        mediaLink: The link to download the object
+        name: The name of the object
+        bucket: The name of the bucket containing this object
+        generation: The content generation of this object
+        metageneration: The metadata generation of this object
+        contentType: The content type of the object
+        storageClass: The storage class of the object
+        size: The size of the object in bytes
+        md5Hash: The MD5 hash of the object
+        crc32c: The CRC32c checksum of the object
+        etag: The ETag of the object
+        timeCreated: The creation time of the object
+        updated: The last update time of the object
+        timeStorageClassUpdated: The time the storage class was last updated
+        timeFinalized: The time the object was finalized
+    """
+
+    kind: Literal["storage#object"]
+    id: str
+    selfLink: str
+    mediaLink: str
+    name: str
+    bucket: str
+    generation: str
+    metageneration: str
+    contentType: str
+    storageClass: str
+    size: str
+    md5Hash: str
+    crc32c: str
+    etag: str
+    timeCreated: str
+    updated: str
+    timeStorageClassUpdated: str
+    timeFinalized: str
+
+
 class VertexAIBatchPredictionJob(TypedDict):
     displayName: str
     model: str
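This mirrors Google's `storage#object` resource; `transform_create_file_response` above validates the raw upload JSON against it. A fabricated subset for orientation (all field values are placeholders):

```python
from litellm.types.llms.vertex_ai import GcsBucketResponse

# Fabricated upload response in the storage#object shape (subset of fields).
raw = {
    "kind": "storage#object",
    "id": "my-bucket/litellm-vertex-files/some-object/1700000000000000",
    "name": "litellm-vertex-files/some-object",
    "bucket": "my-bucket",
    "size": "2048",
    "timeCreated": "2024-01-01T00:00:00.000Z",
}
response_object = GcsBucketResponse(**raw)  # type: ignore
print(response_object["bucket"], response_object["size"])
```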
@@ -2,7 +2,7 @@ import json
 import time
 import uuid
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union
 
 from aiohttp import FormData
 from openai._models import BaseModel as OpenAIObject

@@ -2170,3 +2170,20 @@ class CreateCredentialItem(CredentialBase):
         if not values.get("credential_values") and not values.get("model_id"):
             raise ValueError("Either credential_values or model_id must be set")
         return values
+
+
+class ExtractedFileData(TypedDict):
+    """
+    TypedDict for storing processed file data
+
+    Attributes:
+        filename: Name of the file if provided
+        content: The file content in bytes
+        content_type: MIME type of the file
+        headers: Any additional headers for the file
+    """
+
+    filename: Optional[str]
+    content: bytes
+    content_type: Optional[str]
+    headers: Mapping[str, str]
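`ExtractedFileData` is the normalized shape that `extract_file_data` hands to the Vertex files config earlier in this diff. A fabricated instance for orientation:

```python
from litellm.types.utils import ExtractedFileData

# Fabricated example of the normalized file payload.
extracted: ExtractedFileData = {
    "filename": "example.jsonl",
    "content": b'{"custom_id": "request-1"}\n',
    "content_type": "application/jsonl",
    "headers": {},
}
print(extracted["content_type"])
```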
@@ -6517,6 +6517,10 @@ class ProviderConfigManager:
             )
 
             return GoogleAIStudioFilesHandler()
+        elif LlmProviders.VERTEX_AI == provider:
+            from litellm.llms.vertex_ai.files.transformation import VertexAIFilesConfig
+
+            return VertexAIFilesConfig()
         return None

@@ -2409,25 +2409,26 @@
         "max_tokens": 4096,
         "max_input_tokens": 131072,
         "max_output_tokens": 4096,
-        "input_cost_per_token": 0,
-        "output_cost_per_token": 0,
+        "input_cost_per_token": 0.000000075,
+        "output_cost_per_token": 0.0000003,
         "litellm_provider": "azure_ai",
         "mode": "chat",
         "supports_function_calling": true,
-        "source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft"
+        "source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
     },
     "azure_ai/Phi-4-multimodal-instruct": {
         "max_tokens": 4096,
         "max_input_tokens": 131072,
         "max_output_tokens": 4096,
-        "input_cost_per_token": 0,
-        "output_cost_per_token": 0,
+        "input_cost_per_token": 0.00000008,
+        "input_cost_per_audio_token": 0.000004,
+        "output_cost_per_token": 0.00032,
         "litellm_provider": "azure_ai",
         "mode": "chat",
         "supports_audio_input": true,
         "supports_function_calling": true,
         "supports_vision": true,
-        "source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft"
+        "source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
     },
     "azure_ai/Phi-4": {
         "max_tokens": 16384,
@@ -3467,7 +3468,7 @@
         "input_cost_per_token": 0.0000008,
         "output_cost_per_token": 0.000004,
         "cache_creation_input_token_cost": 0.000001,
-        "cache_read_input_token_cost": 0.0000008,
+        "cache_read_input_token_cost": 0.00000008,
         "litellm_provider": "anthropic",
         "mode": "chat",
         "supports_function_calling": true,
@@ -423,25 +423,35 @@ mock_vertex_batch_response = {
 
 
 @pytest.mark.asyncio
-async def test_avertex_batch_prediction():
-    with patch(
+async def test_avertex_batch_prediction(monkeypatch):
+    monkeypatch.setenv("GCS_BUCKET_NAME", "litellm-local")
+    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+
+    client = AsyncHTTPHandler()
+
+    async def mock_side_effect(*args, **kwargs):
+        print("args", args, "kwargs", kwargs)
+        url = kwargs.get("url", "")
+        if "files" in url:
+            mock_response.json.return_value = mock_file_response
+        elif "batch" in url:
+            mock_response.json.return_value = mock_vertex_batch_response
+        mock_response.status_code = 200
+        return mock_response
+
+    with patch.object(
+        client, "post", side_effect=mock_side_effect
+    ) as mock_post, patch(
         "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post"
-    ) as mock_post:
+    ) as mock_global_post:
         # Configure mock responses
         mock_response = MagicMock()
         mock_response.raise_for_status.return_value = None
 
         # Set up different responses for different API calls
-        async def mock_side_effect(*args, **kwargs):
-            url = kwargs.get("url", "")
-            if "files" in url:
-                mock_response.json.return_value = mock_file_response
-            elif "batch" in url:
-                mock_response.json.return_value = mock_vertex_batch_response
-            mock_response.status_code = 200
-            return mock_response
-
         mock_post.side_effect = mock_side_effect
+        mock_global_post.side_effect = mock_side_effect
 
         # load_vertex_ai_credentials()
         litellm.set_verbose = True
@@ -455,6 +465,7 @@ async def test_avertex_batch_prediction():
             file=open(file_path, "rb"),
             purpose="batch",
             custom_llm_provider="vertex_ai",
+            client=client
         )
         print("Response from creating file=", file_obj)
@@ -0,0 +1,116 @@
+import json
+import os
+import sys
+from io import BytesIO
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+import pytest
+from fastapi import Request, UploadFile
+from fastapi.testclient import TestClient
+from starlette.datastructures import Headers
+from starlette.datastructures import UploadFile as StarletteUploadFile
+
+sys.path.insert(
+    0, os.path.abspath("../../..")
+)  # Adds the parent directory to the system path
+
+from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
+    HttpPassThroughEndpointHelpers,
+)
+
+
+# Test is_multipart
+def test_is_multipart():
+    # Test with multipart content type
+    request = MagicMock(spec=Request)
+    request.headers = Headers({"content-type": "multipart/form-data; boundary=123"})
+    assert HttpPassThroughEndpointHelpers.is_multipart(request) is True
+
+    # Test with non-multipart content type
+    request.headers = Headers({"content-type": "application/json"})
+    assert HttpPassThroughEndpointHelpers.is_multipart(request) is False
+
+    # Test with no content type
+    request.headers = Headers({})
+    assert HttpPassThroughEndpointHelpers.is_multipart(request) is False
+
+
+# Test _build_request_files_from_upload_file
+@pytest.mark.asyncio
+async def test_build_request_files_from_upload_file():
+    # Test with FastAPI UploadFile
+    file_content = b"test content"
+    file = BytesIO(file_content)
+    # Create SpooledTemporaryFile with content type headers
+    headers = {"content-type": "text/plain"}
+    upload_file = UploadFile(file=file, filename="test.txt", headers=headers)
+    upload_file.read = AsyncMock(return_value=file_content)
+
+    result = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
+        upload_file
+    )
+    assert result == ("test.txt", file_content, "text/plain")
+
+    # Test with Starlette UploadFile
+    file2 = BytesIO(file_content)
+    starlette_file = StarletteUploadFile(
+        file=file2,
+        filename="test2.txt",
+        headers=Headers({"content-type": "text/plain"}),
+    )
+    starlette_file.read = AsyncMock(return_value=file_content)
+
+    result = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
+        starlette_file
+    )
+    assert result == ("test2.txt", file_content, "text/plain")
+
+
+# Test make_multipart_http_request
+@pytest.mark.asyncio
+async def test_make_multipart_http_request():
+    # Mock request with file and form field
+    request = MagicMock(spec=Request)
+    request.method = "POST"
+
+    # Mock form data
+    file_content = b"test file content"
+    file = BytesIO(file_content)
+    # Create SpooledTemporaryFile with content type headers
+    headers = {"content-type": "text/plain"}
+    upload_file = UploadFile(file=file, filename="test.txt", headers=headers)
+    upload_file.read = AsyncMock(return_value=file_content)
+
+    form_data = {"file": upload_file, "text_field": "test value"}
+    request.form = AsyncMock(return_value=form_data)
+
+    # Mock httpx client
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {}
+
+    async_client = MagicMock()
+    async_client.request = AsyncMock(return_value=mock_response)
+
+    # Test the function
+    response = await HttpPassThroughEndpointHelpers.make_multipart_http_request(
+        request=request,
+        async_client=async_client,
+        url=httpx.URL("http://test.com"),
+        headers={},
+        requested_query_params=None,
+    )
+
+    # Verify the response
+    assert response == mock_response
+
+    # Verify the client call
+    async_client.request.assert_called_once()
+    call_args = async_client.request.call_args[1]
+
+    assert call_args["method"] == "POST"
+    assert str(call_args["url"]) == "http://test.com"
+    assert isinstance(call_args["files"], dict)
+    assert isinstance(call_args["data"], dict)
+    assert call_args["data"]["text_field"] == "test value"
@@ -323,7 +323,8 @@ class TestHuggingFace(BaseLLMChatTest):
             model="huggingface/fireworks-ai/meta-llama/Meta-Llama-3-8B-Instruct",
             messages=[{"role": "user", "content": "Hello"}],
             optional_params={},
-            api_key="test_api_key"
+            api_key="test_api_key",
+            litellm_params={}
         )
 
         assert headers["Authorization"] == "Bearer test_api_key"

@@ -141,6 +141,7 @@ def test_build_vertex_schema():
     [
         ([{"googleSearch": {}}], "googleSearch"),
         ([{"googleSearchRetrieval": {}}], "googleSearchRetrieval"),
+        ([{"enterpriseWebSearch": {}}], "enterpriseWebSearch"),
         ([{"code_execution": {}}], "code_execution"),
     ],
 )
tests/local_testing/example.jsonl (new file, +2 lines)
@@ -0,0 +1,2 @@
+{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello world!"}], "max_tokens": 10}}
+{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."}, {"role": "user", "content": "Hello world!"}], "max_tokens": 10}}
@@ -21,7 +21,7 @@ from litellm.integrations.gcs_bucket.gcs_bucket import (
     StandardLoggingPayload,
 )
 from litellm.types.utils import StandardCallbackDynamicParams
+from unittest.mock import patch
 
 verbose_logger.setLevel(logging.DEBUG)

@@ -687,3 +687,63 @@ async def test_basic_gcs_logger_with_folder_in_bucket_name():
     # clean up
     if old_bucket_name is not None:
         os.environ["GCS_BUCKET_NAME"] = old_bucket_name
+
+
+@pytest.mark.skip(reason="This test is flaky on ci/cd")
+def test_create_file_e2e():
+    """
+    Asserts 'create_file' is called with the correct arguments
+    """
+    load_vertex_ai_credentials()
+    test_file_content = b"test audio content"
+    test_file = ("test.wav", test_file_content, "audio/wav")
+
+    from litellm import create_file
+    response = create_file(
+        file=test_file,
+        purpose="user_data",
+        custom_llm_provider="vertex_ai",
+    )
+    print("response", response)
+    assert response is not None
+
+
+@pytest.mark.skip(reason="This test is flaky on ci/cd")
+def test_create_file_e2e_jsonl():
+    """
+    Asserts 'create_file' is called with the correct arguments
+    """
+    load_vertex_ai_credentials()
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+
+    example_jsonl = [{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}},{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}]
+
+    # Create and write to the file
+    file_path = "example.jsonl"
+    with open(file_path, "w") as f:
+        for item in example_jsonl:
+            f.write(json.dumps(item) + "\n")
+
+    # Verify file content
+    with open(file_path, "r") as f:
+        content = f.read()
+        print("File content:", content)
+        assert len(content) > 0, "File is empty"
+
+    from litellm import create_file
+    with patch.object(client, "post") as mock_create_file:
+        try:
+            response = create_file(
+                file=open(file_path, "rb"),
+                purpose="user_data",
+                custom_llm_provider="vertex_ai",
+                client=client,
+            )
+        except Exception as e:
+            print("error", e)
+
+        mock_create_file.assert_called_once()
+
+        print(f"kwargs: {mock_create_file.call_args.kwargs}")
+
+        assert mock_create_file.call_args.kwargs["data"] is not None and len(mock_create_file.call_args.kwargs["data"]) > 0
@@ -2,14 +2,31 @@ import pytest
 import openai
 import aiohttp
 import asyncio
+import tempfile
 from typing_extensions import override
 from openai import AssistantEventHandler


 client = openai.OpenAI(base_url="http://0.0.0.0:4000/openai", api_key="sk-1234")


+def test_pass_through_file_operations():
+    # Create a temporary file
+    with tempfile.NamedTemporaryFile(mode='w+', suffix='.txt', delete=False) as temp_file:
+        temp_file.write("This is a test file for the OpenAI Assistants API.")
+        temp_file.flush()
+
+        # create a file
+        file = client.files.create(
+            file=open(temp_file.name, "rb"),
+            purpose="assistants",
+        )
+        print("file created", file)
+
+        # delete the file
+        delete_file = client.files.delete(file.id)
+        print("file deleted", delete_file)
+
+
 def test_openai_assistants_e2e_operations():
     assistant = client.beta.assistants.create(
         name="Math Tutor",
         instructions="You are a personal math tutor. Write and run code to answer math questions.",
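The same `/openai` pass-through exercised by the test above should be reachable from other OpenAI SDKs as well, since the proxy forwards the request body unchanged. A minimal sketch with `openai-node`, assuming the same proxy address and test key as the test, and a local `test.txt` you create yourself (in newer SDK releases the delete method is named `files.delete` rather than `files.del`):

```typescript
import fs from "fs";
import OpenAI from "openai";

// Point the official SDK at the LiteLLM proxy's OpenAI pass-through route.
const client = new OpenAI({ baseURL: "http://0.0.0.0:4000/openai", apiKey: "sk-1234" });

async function main() {
  // Upload a file through the proxy...
  const file = await client.files.create({
    file: fs.createReadStream("test.txt"), // assumed to exist locally
    purpose: "assistants",
  });
  console.log("file created", file.id);

  // ...then delete it again.
  await client.files.del(file.id);
}

main();
```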
@@ -13,9 +13,12 @@ import { Organization, userListCall } from "./networking";
 import { createTeamSearchFunction } from "./key_team_helpers/team_search_fn";
 import { createOrgSearchFunction } from "./key_team_helpers/organization_search_fn";
 import { useFilterLogic } from "./key_team_helpers/filter_logic";
+import { Setter } from "@/types";
+import { updateExistingKeys } from "@/utils/dataUtils";

 interface AllKeysTableProps {
   keys: KeyResponse[];
+  setKeys: Setter<KeyResponse[]>;
   isLoading?: boolean;
   pagination: {
     currentPage: number;
@@ -87,6 +90,7 @@ const TeamFilter = ({
 */
 export function AllKeysTable({
   keys,
+  setKeys,
   isLoading = false,
   pagination,
   onPageChange,
@@ -364,6 +368,23 @@ export function AllKeysTable({
           keyId={selectedKeyId}
           onClose={() => setSelectedKeyId(null)}
           keyData={keys.find(k => k.token === selectedKeyId)}
+          onKeyDataUpdate={(updatedKeyData) => {
+            setKeys(keys => keys.map(key => {
+              if (key.token === updatedKeyData.token) {
+                // The shape of key differs from that of updatedKeyData
+                // (received from keyUpdateCall in networking.tsx), so
+                // replacing key wholesale with updatedKeyData could lead
+                // to unintended bugs/behaviors. Instead, only the fields
+                // present in both objects are updated.
+                return updateExistingKeys(key, updatedKeyData)
+              }
+
+              return key
+            }))
+          }}
+          onDelete={() => {
+            setKeys(keys => keys.filter(key => key.token !== selectedKeyId))
+          }}
           accessToken={accessToken}
           userID={userID}
           userRole={userRole}
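Both callbacks above use the functional form of `setKeys` (`setKeys(keys => ...)`) rather than computing a new array from the `keys` prop directly, so each update is applied to the latest state even when React batches several updates together. A standalone sketch of why the updater form matters — the `applySetter` helper and the reduced key shape here are illustrative stand-ins, not part of this PR:

```typescript
type Setter<T> = (newValueOrUpdater: T | ((previousValue: T) => T)) => void;

interface KeyLike {
  token: string;
}

// Stand-in for React's state dispatch: resolves either accepted form
// against the *current* value, which is what makes updaters batch-safe.
function applySetter<T>(current: T, update: T | ((prev: T) => T)): T {
  return typeof update === "function" ? (update as (prev: T) => T)(current) : update;
}

let state: KeyLike[] = [{ token: "sk-a" }, { token: "sk-b" }];

// Two queued updates; each one sees the result of the previous.
state = applySetter(state, prev => prev.filter(k => k.token !== "sk-a"));
state = applySetter(state, prev => [...prev, { token: "sk-c" }]);
console.log(state); // [{ token: "sk-b" }, { token: "sk-c" }]
```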
@@ -27,13 +27,15 @@ interface KeyInfoViewProps {
   keyId: string;
   onClose: () => void;
   keyData: KeyResponse | undefined;
+  onKeyDataUpdate?: (data: Partial<KeyResponse>) => void;
+  onDelete?: () => void;
   accessToken: string | null;
   userID: string | null;
   userRole: string | null;
   teams: any[] | null;
 }

-export default function KeyInfoView({ keyId, onClose, keyData, accessToken, userID, userRole, teams }: KeyInfoViewProps) {
+export default function KeyInfoView({ keyId, onClose, keyData, accessToken, userID, userRole, teams, onKeyDataUpdate, onDelete }: KeyInfoViewProps) {
   const [isEditing, setIsEditing] = useState(false);
   const [form] = Form.useForm();
   const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false);
@@ -93,6 +95,9 @@ export default function KeyInfoView({ keyId, onClose, keyData, accessToken, user
       }

       const newKeyValues = await keyUpdateCall(accessToken, formValues);
+      if (onKeyDataUpdate) {
+        onKeyDataUpdate(newKeyValues)
+      }
       message.success("Key updated successfully");
       setIsEditing(false);
       // Refresh key data here if needed
@@ -107,6 +112,9 @@ export default function KeyInfoView({ keyId, onClose, keyData, accessToken, user
       if (!accessToken) return;
       await keyDeleteCall(accessToken as string, keyData.token);
       message.success("Key deleted successfully");
+      if (onDelete) {
+        onDelete()
+      }
       onClose();
     } catch (error) {
       console.error("Error deleting the key:", error);
@@ -1,5 +1,6 @@
 import { useState, useEffect } from 'react';
 import { keyListCall, Organization } from '../networking';
+import { Setter } from '@/types';

 export interface Team {
   team_id: string;
@@ -94,13 +95,14 @@ totalPages: number;
   totalCount: number;
 }


 interface UseKeyListReturn {
   keys: KeyResponse[];
   isLoading: boolean;
   error: Error | null;
   pagination: PaginationData;
   refresh: (params?: Record<string, unknown>) => Promise<void>;
-  setKeys: (newKeysOrUpdater: KeyResponse[] | ((prevKeys: KeyResponse[]) => KeyResponse[])) => void;
+  setKeys: Setter<KeyResponse[]>;
 }

 const useKeyList = ({
@@ -4,6 +4,7 @@
 import { all_admin_roles } from "@/utils/roles";
 import { message } from "antd";
 import { TagNewRequest, TagUpdateRequest, TagDeleteRequest, TagInfoRequest, TagListResponse, TagInfoResponse } from "./tag_management/types";
+import { Team } from "./key_team_helpers/key_list";

 const isLocal = process.env.NODE_ENV === "development";
 export const proxyBaseUrl = isLocal ? "http://localhost:4000" : null;
@@ -2983,7 +2984,7 @@ export const teamUpdateCall = async (
       console.error("Error response from the server:", errorData);
       throw new Error("Network response was not ok");
     }
-    const data = await response.json();
+    const data = await response.json() as { data: Team, team_id: string };
     console.log("Update Team Response:", data);
     return data;
     // Handle success - you might want to update some state or UI based on the created key
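Worth noting: the `as { data: Team, team_id: string }` cast is purely a compile-time assertion; nothing verifies the payload at runtime. If the response shape ever needed checking, a minimal guard along these lines could be added — the field checks are an assumption mirroring the cast above, not a documented contract of `/team/update`, and the helper name is hypothetical (`Team` is assumed in scope, as it is in networking.tsx after this change):

```typescript
// Hypothetical helper: narrows an unknown payload to the shape the cast assumes.
function asTeamUpdateResponse(value: unknown): { data: Team; team_id: string } {
  const v = value as { data?: unknown; team_id?: unknown };
  if (typeof v?.team_id !== "string" || typeof v?.data !== "object" || v.data === null) {
    throw new Error("Unexpected /team/update response shape");
  }
  return value as { data: Team; team_id: string };
}
```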
@@ -30,6 +30,7 @@ import { PencilAltIcon, PlusIcon, TrashIcon } from "@heroicons/react/outline";
 import MemberModal from "./edit_membership";
 import UserSearchModal from "@/components/common_components/user_search_modal";
 import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key";
+import { Team } from "../key_team_helpers/key_list";


 interface TeamData {
@@ -69,6 +70,7 @@ interface TeamInfoProps {
   is_proxy_admin: boolean;
   userModels: string[];
   editTeam: boolean;
+  onUpdate?: (team: Team) => void
 }

 const TeamInfoView: React.FC<TeamInfoProps> = ({
@@ -78,7 +80,8 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
   is_team_admin,
   is_proxy_admin,
   userModels,
-  editTeam
+  editTeam,
+  onUpdate
 }) => {
   const [teamData, setTeamData] = useState<TeamData | null>(null);
   const [loading, setLoading] = useState(true);
@@ -199,7 +202,10 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
       };

       const response = await teamUpdateCall(accessToken, updateData);
+      if (onUpdate) {
+        onUpdate(response.data)
+      }

       message.success("Team settings updated successfully");
       setIsEditing(false);
       fetchTeamInfo();
@@ -84,6 +84,7 @@ import {
   modelAvailableCall,
   teamListCall
 } from "./networking";
+import { updateExistingKeys } from "@/utils/dataUtils";

 const getOrganizationModels = (organization: Organization | null, userModels: string[]) => {
   let tempModelsToPick = [];
@@ -321,6 +322,22 @@ const Teams: React.FC<TeamProps> = ({
         {selectedTeamId ? (
           <TeamInfoView
             teamId={selectedTeamId}
+            onUpdate={data => {
+              setTeams(teams => {
+                if (teams == null) {
+                  return teams;
+                }
+
+                return teams.map(team => {
+                  if (data.team_id === team.team_id) {
+                    return updateExistingKeys(team, data)
+                  }
+
+                  return team
+                })
+              })
+            }}
             onClose={() => {
               setSelectedTeamId(null);
               setEditTeam(false);
@@ -418,6 +418,7 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
       <div>
         <AllKeysTable
           keys={keys}
+          setKeys={setKeys}
           isLoading={isLoading}
           pagination={pagination}
           onPageChange={handlePageChange}
ui/litellm-dashboard/src/types.ts (new file, 1 line)
@@ -0,0 +1 @@
+export type Setter<T> = (newValueOrUpdater: T | ((previousValue: T) => T)) => void
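`Setter<T>` mirrors the signature of a React `useState` dispatch, so a prop typed this way accepts either a plain value or an updater function. A small usage sketch (the `setCount` parameter is illustrative):

```typescript
import { Setter } from "@/types";

// A useState dispatch such as `const [count, setCount] = useState(0)`
// satisfies Setter<number> directly.
function reset(setCount: Setter<number>) {
  setCount(0); // direct value
}

function increment(setCount: Setter<number>) {
  setCount(prev => prev + 1); // updater function, safe against stale reads
}
```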
ui/litellm-dashboard/src/utils/dataUtils.ts (new file, 14 lines)
@@ -0,0 +1,14 @@
+export function updateExistingKeys<Source extends Object>(
+  target: Source,
+  source: Object
+): Source {
+  const clonedTarget = structuredClone(target);
+
+  for (const [key, value] of Object.entries(source)) {
+    if (key in clonedTarget) {
+      (clonedTarget as any)[key] = value;
+    }
+  }
+
+  return clonedTarget;
+}
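A quick illustration of the merge semantics (the objects here are made up for the example): fields present on both sides are overwritten, fields that exist only on the source are dropped, and `structuredClone` means the original target is never mutated.

```typescript
import { updateExistingKeys } from "@/utils/dataUtils";

const key = { token: "sk-1", key_alias: "old-alias", spend: 10 };
const update = { token: "sk-1", key_alias: "new-alias", unrelated_field: true };

const merged = updateExistingKeys(key, update);
console.log(merged);        // { token: "sk-1", key_alias: "new-alias", spend: 10 }
console.log(key.key_alias); // "old-alias" — the target was cloned, not mutated
// `unrelated_field` was ignored because it does not exist on the target.
```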