Merge branch 'main' into litellm_msft_group_assignment

This commit is contained in:
Ishaan Jaff 2025-04-09 15:34:12 -07:00
commit b2b82ecd66
84 changed files with 1302 additions and 223 deletions

View file

@@ -610,6 +610,8 @@ jobs:
name: Install Dependencies
command: |
python -m pip install --upgrade pip
+ pip install wheel
+ pip install --upgrade pip wheel setuptools
python -m pip install -r requirements.txt
pip install "pytest==7.3.1"
pip install "respx==0.21.1"

View file

@@ -2,6 +2,10 @@ apiVersion: v1
kind: Service
metadata:
name: {{ include "litellm.fullname" . }}
+ {{- with .Values.service.annotations }}
+ annotations:
+   {{- toYaml . | nindent 4 }}
+ {{- end }}
labels:
{{- include "litellm.labels" . | nindent 4 }}
spec:

View file

@@ -438,6 +438,179 @@ assert isinstance(
```
### Google Search Tool
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
os.environ["GEMINI_API_KEY"] = ".."
tools = [{"googleSearch": {}}] # 👈 ADD GOOGLE SEARCH
response = completion(
model="gemini/gemini-2.0-flash",
messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
tools=tools,
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: gemini-2.0-flash
litellm_params:
model: gemini/gemini-2.0-flash
api_key: os.environ/GEMINI_API_KEY
```
2. Start Proxy
```bash
$ litellm --config /path/to/config.yaml
```
3. Make Request!
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "gemini-2.0-flash",
"messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
"tools": [{"googleSearch": {}}]
}
'
```
</TabItem>
</Tabs>
### Google Search Retrieval
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
os.environ["GEMINI_API_KEY"] = ".."
tools = [{"googleSearchRetrieval": {}}] # 👈 ADD GOOGLE SEARCH
response = completion(
model="gemini/gemini-2.0-flash",
messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
tools=tools,
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: gemini-2.0-flash
litellm_params:
model: gemini/gemini-2.0-flash
api_key: os.environ/GEMINI_API_KEY
```
2. Start Proxy
```bash
$ litellm --config /path/to/config.yaml
```
3. Make Request!
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "gemini-2.0-flash",
"messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
"tools": [{"googleSearchRetrieval": {}}]
}
'
```
</TabItem>
</Tabs>
### Code Execution Tool
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import os
os.environ["GEMINI_API_KEY"] = ".."
tools = [{"codeExecution": {}}] # 👈 ADD GOOGLE SEARCH
response = completion(
model="gemini/gemini-2.0-flash",
messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
tools=tools,
)
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: gemini-2.0-flash
litellm_params:
model: gemini/gemini-2.0-flash
api_key: os.environ/GEMINI_API_KEY
```
2. Start Proxy
```bash
$ litellm --config /path/to/config.yaml
```
3. Make Request!
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "gemini-2.0-flash",
"messages": [{"role": "user", "content": "What is the weather in San Francisco?"}],
"tools": [{"codeExecution": {}}]
}
'
```
</TabItem>
</Tabs>
## JSON Mode
<Tabs>

View file

@@ -398,6 +398,8 @@ curl http://localhost:4000/v1/chat/completions \
</TabItem>
</Tabs>
You can also use the `enterpriseWebSearch` tool for an [enterprise compliant search](https://cloud.google.com/vertex-ai/generative-ai/docs/grounding/web-grounding-enterprise).
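A minimal SDK sketch of the same call, assuming `enterpriseWebSearch` follows the same pattern as the `googleSearch` examples above (the model name shown is illustrative):

```python
from litellm import completion

# assumes Vertex AI credentials/project are already configured in the environment
tools = [{"enterpriseWebSearch": {}}]  # 👈 ADD ENTERPRISE WEB SEARCH

response = completion(
    model="vertex_ai/gemini-2.0-flash",
    messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
    tools=tools,
)
print(response)
```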
#### **Moving from Vertex AI SDK to LiteLLM (GROUNDING)**

View file

@@ -110,5 +110,8 @@ def get_litellm_params(
"azure_password": kwargs.get("azure_password"),
"max_retries": max_retries,
"timeout": kwargs.get("timeout"),
+ "bucket_name": kwargs.get("bucket_name"),
+ "vertex_credentials": kwargs.get("vertex_credentials"),
+ "vertex_project": kwargs.get("vertex_project"),
}
return litellm_params

View file

@@ -2,7 +2,10 @@
Common utility functions used for translating messages across providers
"""
- from typing import Dict, List, Literal, Optional, Union, cast
+ import io
+ import mimetypes
+ from os import PathLike
+ from typing import Dict, List, Literal, Mapping, Optional, Union, cast
from litellm.types.llms.openai import (
@@ -10,7 +13,13 @@ from litellm.types.llms.openai import (
ChatCompletionFileObject,
ChatCompletionUserMessage,
)
- from litellm.types.utils import Choices, ModelResponse, StreamingChoices
+ from litellm.types.utils import (
+     Choices,
+     ExtractedFileData,
+     FileTypes,
+     ModelResponse,
+     StreamingChoices,
+ )
DEFAULT_USER_CONTINUE_MESSAGE = ChatCompletionUserMessage(
content="Please continue.", role="user"
@@ -350,6 +359,68 @@ def update_messages_with_model_file_ids(
return messages
def extract_file_data(file_data: FileTypes) -> ExtractedFileData:
"""
Extracts and processes file data from various input formats.
Args:
file_data: Can be a tuple of (filename, content, [content_type], [headers]) or direct file content
Returns:
ExtractedFileData containing:
- filename: Name of the file if provided
- content: The file content in bytes
- content_type: MIME type of the file
- headers: Any additional headers
"""
# Parse the file_data based on its type
filename = None
file_content = None
content_type = None
file_headers: Mapping[str, str] = {}
if isinstance(file_data, tuple):
if len(file_data) == 2:
filename, file_content = file_data
elif len(file_data) == 3:
filename, file_content, content_type = file_data
elif len(file_data) == 4:
filename, file_content, content_type, file_headers = file_data
else:
file_content = file_data
# Convert content to bytes
if isinstance(file_content, (str, PathLike)):
# If it's a path, open and read the file
with open(file_content, "rb") as f:
content = f.read()
elif isinstance(file_content, io.IOBase):
# If it's a file-like object
content = file_content.read()
if isinstance(content, str):
content = content.encode("utf-8")
# Reset file pointer to beginning
file_content.seek(0)
elif isinstance(file_content, bytes):
content = file_content
else:
raise ValueError(f"Unsupported file content type: {type(file_content)}")
# Use provided content type or guess based on filename
if not content_type:
content_type = (
mimetypes.guess_type(filename)[0]
if filename
else "application/octet-stream"
)
return ExtractedFileData(
filename=filename,
content=content,
content_type=content_type,
headers=file_headers,
)
def unpack_defs(schema, defs):
properties = schema.get("properties", None)
if properties is None:
@@ -381,3 +452,4 @@ def unpack_defs(schema, defs):
unpack_defs(ref, defs)
value["items"] = ref
continue

View file

@@ -50,6 +50,7 @@ class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -301,6 +301,7 @@ class AnthropicChatCompletion(BaseLLM):
model=model,
messages=messages,
optional_params={**optional_params, "is_vertex_request": is_vertex_request},
+ litellm_params=litellm_params,
)
config = ProviderConfigManager.get_provider_chat_config(

View file

@@ -868,6 +868,7 @@ class AnthropicConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:

View file

@@ -87,6 +87,7 @@ class AnthropicTextConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -293,6 +293,7 @@ class AzureOpenAIConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -39,6 +39,7 @@ class AzureAIStudioConfig(OpenAIConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -262,6 +262,7 @@ class BaseConfig(ABC):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -1,5 +1,5 @@
from abc import abstractmethod
- from typing import TYPE_CHECKING, Any, List, Optional
+ from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
@@ -33,23 +33,22 @@ class BaseFilesConfig(BaseConfig):
) -> List[OpenAICreateFileRequestOptionalParams]:
pass
- def get_complete_url(
+ def get_complete_file_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
- stream: Optional[bool] = None,
- ) -> str:
-     """
-     OPTIONAL
-     Get the complete url for the request
-     Some providers need `model` in `api_base`
-     """
-     return api_base or ""
+ data: CreateFileRequest,
+ ):
+     return self.get_complete_url(
+         api_base=api_base,
+         api_key=api_key,
+         model=model,
+         optional_params=optional_params,
+         litellm_params=litellm_params,
+     )
@abstractmethod
def transform_create_file_request(
@@ -58,7 +57,7 @@ class BaseFilesConfig(BaseConfig):
create_file_data: CreateFileRequest,
optional_params: dict,
litellm_params: dict,
- ) -> dict:
+ ) -> Union[dict, str, bytes]:
pass
@abstractmethod

View file

@@ -65,6 +65,7 @@ class BaseImageVariationConfig(BaseConfig, ABC):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -831,6 +831,7 @@ class AmazonConverseConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -442,6 +442,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -118,6 +118,7 @@ class ClarifaiConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -60,6 +60,7 @@ class CloudflareChatConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -118,6 +118,7 @@ class CohereChatConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -101,6 +101,7 @@ class CohereTextConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -229,6 +229,7 @@ class BaseLLMAIOHTTPHandler:
model=model,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
api_base=api_base,
)
@@ -498,6 +499,7 @@ class BaseLLMAIOHTTPHandler:
model=model,
messages=[{"role": "user", "content": "test"}],
optional_params=optional_params,
+ litellm_params=litellm_params,
api_base=api_base,
)

View file

@@ -192,7 +192,7 @@ class AsyncHTTPHandler:
async def post(
self,
url: str,
- data: Optional[Union[dict, str]] = None, # type: ignore
+ data: Optional[Union[dict, str, bytes]] = None, # type: ignore
json: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
@@ -427,7 +427,7 @@ class AsyncHTTPHandler:
self,
url: str,
client: httpx.AsyncClient,
- data: Optional[Union[dict, str]] = None, # type: ignore
+ data: Optional[Union[dict, str, bytes]] = None, # type: ignore
json: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
@@ -527,7 +527,7 @@ class HTTPHandler:
def post(
self,
url: str,
- data: Optional[Union[dict, str]] = None,
+ data: Optional[Union[dict, str, bytes]] = None,
json: Optional[Union[dict, str, List]] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
@@ -573,7 +573,6 @@ class HTTPHandler:
setattr(e, "text", error_text)
setattr(e, "status_code", e.response.status_code)
raise e
except Exception as e:
raise e

View file

@@ -247,6 +247,7 @@ class BaseLLMHTTPHandler:
messages=messages,
optional_params=optional_params,
api_base=api_base,
+ litellm_params=litellm_params,
)
api_base = provider_config.get_complete_url(
@@ -625,6 +626,7 @@ class BaseLLMHTTPHandler:
model=model,
messages=[],
optional_params=optional_params,
+ litellm_params=litellm_params,
)
api_base = provider_config.get_complete_url(
@@ -896,6 +898,7 @@ class BaseLLMHTTPHandler:
model=model,
messages=[],
optional_params=optional_params,
+ litellm_params=litellm_params,
)
if client is None or not isinstance(client, HTTPHandler):
@@ -1228,15 +1231,19 @@ class BaseLLMHTTPHandler:
model="",
messages=[],
optional_params={},
+ litellm_params=litellm_params,
)
- api_base = provider_config.get_complete_url(
+ api_base = provider_config.get_complete_file_url(
api_base=api_base,
api_key=api_key,
model="",
optional_params={},
litellm_params=litellm_params,
+ data=create_file_data,
)
+ if api_base is None:
+     raise ValueError("api_base is required for create_file")
# Get the transformed request data for both steps
transformed_request = provider_config.transform_create_file_request(
@@ -1263,48 +1270,57 @@ class BaseLLMHTTPHandler:
else:
sync_httpx_client = client
- try:
-     # Step 1: Initial request to get upload URL
-     initial_response = sync_httpx_client.post(
-         url=api_base,
-         headers={
-             **headers,
-             **transformed_request["initial_request"]["headers"],
-         },
-         data=json.dumps(transformed_request["initial_request"]["data"]),
-         timeout=timeout,
-     )
-     # Extract upload URL from response headers
-     upload_url = initial_response.headers.get("X-Goog-Upload-URL")
-     if not upload_url:
-         raise ValueError("Failed to get upload URL from initial request")
-     # Step 2: Upload the actual file
-     upload_response = sync_httpx_client.post(
-         url=upload_url,
-         headers=transformed_request["upload_request"]["headers"],
-         data=transformed_request["upload_request"]["data"],
-         timeout=timeout,
-     )
-     return provider_config.transform_create_file_response(
-         model=None,
-         raw_response=upload_response,
-         logging_obj=logging_obj,
-         litellm_params=litellm_params,
-     )
- except Exception as e:
-     raise self._handle_error(
-         e=e,
-         provider_config=provider_config,
-     )
+ if isinstance(transformed_request, str) or isinstance(
+     transformed_request, bytes
+ ):
+     upload_response = sync_httpx_client.post(
+         url=api_base,
+         headers=headers,
+         data=transformed_request,
+         timeout=timeout,
+     )
+ else:
+     try:
+         # Step 1: Initial request to get upload URL
+         initial_response = sync_httpx_client.post(
+             url=api_base,
+             headers={
+                 **headers,
+                 **transformed_request["initial_request"]["headers"],
+             },
+             data=json.dumps(transformed_request["initial_request"]["data"]),
+             timeout=timeout,
+         )
+         # Extract upload URL from response headers
+         upload_url = initial_response.headers.get("X-Goog-Upload-URL")
+         if not upload_url:
+             raise ValueError("Failed to get upload URL from initial request")
+         # Step 2: Upload the actual file
+         upload_response = sync_httpx_client.post(
+             url=upload_url,
+             headers=transformed_request["upload_request"]["headers"],
+             data=transformed_request["upload_request"]["data"],
+             timeout=timeout,
+         )
+     except Exception as e:
+         raise self._handle_error(
+             e=e,
+             provider_config=provider_config,
+         )
+ return provider_config.transform_create_file_response(
+     model=None,
+     raw_response=upload_response,
+     logging_obj=logging_obj,
+     litellm_params=litellm_params,
+ )
async def async_create_file(
self,
- transformed_request: dict,
+ transformed_request: Union[bytes, str, dict],
litellm_params: dict,
provider_config: BaseFilesConfig,
headers: dict,
@@ -1323,45 +1339,54 @@ class BaseLLMHTTPHandler:
else:
async_httpx_client = client
- try:
-     # Step 1: Initial request to get upload URL
-     initial_response = await async_httpx_client.post(
-         url=api_base,
-         headers={
-             **headers,
-             **transformed_request["initial_request"]["headers"],
-         },
-         data=json.dumps(transformed_request["initial_request"]["data"]),
-         timeout=timeout,
-     )
-     # Extract upload URL from response headers
-     upload_url = initial_response.headers.get("X-Goog-Upload-URL")
-     if not upload_url:
-         raise ValueError("Failed to get upload URL from initial request")
-     # Step 2: Upload the actual file
-     upload_response = await async_httpx_client.post(
-         url=upload_url,
-         headers=transformed_request["upload_request"]["headers"],
-         data=transformed_request["upload_request"]["data"],
-         timeout=timeout,
-     )
-     return provider_config.transform_create_file_response(
-         model=None,
-         raw_response=upload_response,
-         logging_obj=logging_obj,
-         litellm_params=litellm_params,
-     )
- except Exception as e:
-     verbose_logger.exception(f"Error creating file: {e}")
-     raise self._handle_error(
-         e=e,
-         provider_config=provider_config,
-     )
+ if isinstance(transformed_request, str) or isinstance(
+     transformed_request, bytes
+ ):
+     upload_response = await async_httpx_client.post(
+         url=api_base,
+         headers=headers,
+         data=transformed_request,
+         timeout=timeout,
+     )
+ else:
+     try:
+         # Step 1: Initial request to get upload URL
+         initial_response = await async_httpx_client.post(
+             url=api_base,
+             headers={
+                 **headers,
+                 **transformed_request["initial_request"]["headers"],
+             },
+             data=json.dumps(transformed_request["initial_request"]["data"]),
+             timeout=timeout,
+         )
+         # Extract upload URL from response headers
+         upload_url = initial_response.headers.get("X-Goog-Upload-URL")
+         if not upload_url:
+             raise ValueError("Failed to get upload URL from initial request")
+         # Step 2: Upload the actual file
+         upload_response = await async_httpx_client.post(
+             url=upload_url,
+             headers=transformed_request["upload_request"]["headers"],
+             data=transformed_request["upload_request"]["data"],
+             timeout=timeout,
+         )
+     except Exception as e:
+         verbose_logger.exception(f"Error creating file: {e}")
+         raise self._handle_error(
+             e=e,
+             provider_config=provider_config,
+         )
+ return provider_config.transform_create_file_response(
+     model=None,
+     raw_response=upload_response,
+     logging_obj=logging_obj,
+     litellm_params=litellm_params,
+ )
def list_files(self):
"""

View file

@@ -116,6 +116,7 @@ class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -171,6 +171,7 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -41,6 +41,7 @@ class FireworksAIMixin:
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -20,6 +20,7 @@ class GeminiModelInfo(BaseLLMModelInfo):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -4,11 +4,12 @@ Supports writing files to Google AI Studio Files API.
For vertex ai, check out the vertex_ai/files/handler.py file.
"""
import time
- from typing import List, Mapping, Optional
+ from typing import List, Optional
import httpx
from litellm._logging import verbose_logger
+ from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.llms.base_llm.files.transformation import (
BaseFilesConfig,
LiteLLMLoggingObj,
@@ -91,66 +92,28 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):
if file_data is None:
raise ValueError("File data is required")
- # Parse the file_data based on its type
- filename = None
- file_content = None
- content_type = None
- file_headers: Mapping[str, str] = {}
- if isinstance(file_data, tuple):
-     if len(file_data) == 2:
-         filename, file_content = file_data
-     elif len(file_data) == 3:
-         filename, file_content, content_type = file_data
-     elif len(file_data) == 4:
-         filename, file_content, content_type, file_headers = file_data
- else:
-     file_content = file_data
- # Handle the file content based on its type
- import io
- from os import PathLike
- # Convert content to bytes
- if isinstance(file_content, (str, PathLike)):
-     # If it's a path, open and read the file
-     with open(file_content, "rb") as f:
-         content = f.read()
- elif isinstance(file_content, io.IOBase):
-     # If it's a file-like object
-     content = file_content.read()
-     if isinstance(content, str):
-         content = content.encode("utf-8")
- elif isinstance(file_content, bytes):
-     content = file_content
- else:
-     raise ValueError(f"Unsupported file content type: {type(file_content)}")
+ # Use the common utility function to extract file data
+ extracted_data = extract_file_data(file_data)
# Get file size
- file_size = len(content)
+ file_size = len(extracted_data["content"])
- # Use provided content type or guess based on filename
- if not content_type:
-     import mimetypes
-     content_type = (
-         mimetypes.guess_type(filename)[0]
-         if filename
-         else "application/octet-stream"
-     )
# Step 1: Initial resumable upload request
headers = {
"X-Goog-Upload-Protocol": "resumable",
"X-Goog-Upload-Command": "start",
"X-Goog-Upload-Header-Content-Length": str(file_size),
- "X-Goog-Upload-Header-Content-Type": content_type,
+ "X-Goog-Upload-Header-Content-Type": extracted_data["content_type"],
"Content-Type": "application/json",
}
- headers.update(file_headers) # Add any custom headers
+ headers.update(extracted_data["headers"]) # Add any custom headers
# Initial metadata request body
- initial_data = {"file": {"display_name": filename or str(int(time.time()))}}
+ initial_data = {
+     "file": {
+         "display_name": extracted_data["filename"] or str(int(time.time()))
+     }
+ }
# Step 2: Actual file upload data
upload_headers = {
@@ -161,7 +124,10 @@ class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):
return {
"initial_request": {"headers": headers, "data": initial_data},
- "upload_request": {"headers": upload_headers, "data": content},
+ "upload_request": {
+     "headers": upload_headers,
+     "data": extracted_data["content"],
+ },
}
def transform_create_file_response(

View file

@@ -1,6 +1,6 @@
import logging
import os
- from typing import TYPE_CHECKING, Any, List, Optional, Union
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
@@ -18,7 +18,6 @@ from litellm.llms.base_llm.chat.transformation import BaseLLMException
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping
logger = logging.getLogger(__name__)
BASE_URL = "https://router.huggingface.co"
@@ -34,7 +33,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
headers: dict,
model: str,
messages: List[AllMessageValues],
- optional_params: dict,
+ optional_params: Dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
@@ -51,7 +51,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
- return HuggingFaceError(status_code=status_code, message=error_message, headers=headers)
+ return HuggingFaceError(
+     status_code=status_code, message=error_message, headers=headers
+ )
def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
"""
@@ -82,7 +84,9 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
if api_base is not None:
complete_url = api_base
elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"):
- complete_url = str(os.getenv("HF_API_BASE")) or str(os.getenv("HUGGINGFACE_API_BASE"))
+ complete_url = str(os.getenv("HF_API_BASE")) or str(
+     os.getenv("HUGGINGFACE_API_BASE")
+ )
elif model.startswith(("http://", "https://")):
complete_url = model
# 4. Default construction with provider
@@ -138,4 +142,8 @@ class HuggingFaceChatConfig(OpenAIGPTConfig):
)
mapped_model = provider_mapping["providerId"]
messages = self._transform_messages(messages=messages, model=mapped_model)
- return dict(ChatCompletionRequest(model=mapped_model, messages=messages, **optional_params))
+ return dict(
+     ChatCompletionRequest(
+         model=mapped_model, messages=messages, **optional_params
+     )
+ )

View file

@@ -1,15 +1,6 @@
import json
import os
- from typing import (
-     Any,
-     Callable,
-     Dict,
-     List,
-     Literal,
-     Optional,
-     Union,
-     get_args,
- )
+ from typing import Any, Callable, Dict, List, Literal, Optional, Union, get_args
import httpx
@@ -35,8 +26,9 @@ hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://hug
]
- def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
+ def get_hf_task_embedding_for_model(
+     model: str, task_type: Optional[str], api_base: str
+ ) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
@@ -57,7 +49,9 @@ def get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_ba
return pipeline_tag
- async def async_get_hf_task_embedding_for_model(model: str, task_type: Optional[str], api_base: str) -> Optional[str]:
+ async def async_get_hf_task_embedding_for_model(
+     model: str, task_type: Optional[str], api_base: str
+ ) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
@@ -116,7 +110,9 @@ class HuggingFaceEmbedding(BaseLLM):
input: List,
optional_params: dict,
) -> dict:
- hf_task = await async_get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
+ hf_task = await async_get_hf_task_embedding_for_model(
+     model=model, task_type=task_type, api_base=HF_HUB_URL
+ )
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
@@ -173,7 +169,9 @@ class HuggingFaceEmbedding(BaseLLM):
task_type = optional_params.pop("input_type", None)
if call_type == "sync":
- hf_task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
+ hf_task = get_hf_task_embedding_for_model(
+     model=model, task_type=task_type, api_base=HF_HUB_URL
+ )
elif call_type == "async":
return self._async_transform_input(
model=model, task_type=task_type, embed_url=embed_url, input=input
@@ -325,6 +323,7 @@ class HuggingFaceEmbedding(BaseLLM):
input: list,
model_response: EmbeddingResponse,
optional_params: dict,
+ litellm_params: dict,
logging_obj: LiteLLMLoggingObj,
encoding: Callable,
api_key: Optional[str] = None,
@@ -341,9 +340,12 @@ class HuggingFaceEmbedding(BaseLLM):
model=model,
optional_params=optional_params,
messages=[],
+ litellm_params=litellm_params,
)
task_type = optional_params.pop("input_type", None)
- task = get_hf_task_embedding_for_model(model=model, task_type=task_type, api_base=HF_HUB_URL)
+ task = get_hf_task_embedding_for_model(
+     model=model, task_type=task_type, api_base=HF_HUB_URL
+ )
# print_verbose(f"{model}, {task}")
embed_url = ""
if "https" in model:
@@ -355,7 +357,9 @@ class HuggingFaceEmbedding(BaseLLM):
elif "HUGGINGFACE_API_BASE" in os.environ:
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
- embed_url = f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
+ embed_url = (
+     f"https://router.huggingface.co/hf-inference/pipeline/{task}/{model}"
+ )
## ROUTING ##
if aembedding is True:

View file

@@ -355,6 +355,7 @@ class HuggingFaceEmbeddingConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:

View file

@@ -36,6 +36,7 @@ def completion(
model=model,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
)
## Load Config

View file

@@ -93,6 +93,7 @@ class NLPCloudConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -353,6 +353,7 @@ class OllamaConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -32,6 +32,7 @@ def completion(
model=model,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
)
if "https" in model:
completion_url = model
@@ -123,6 +124,7 @@ def embedding(
model=model,
messages=[],
optional_params=optional_params,
+ litellm_params={},
)
response = litellm.module_level_client.post(
embeddings_url, headers=headers, json=data

View file

@@ -88,6 +88,7 @@ class OobaboogaConfig(OpenAIGPTConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -321,6 +321,7 @@ class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -286,6 +286,7 @@ class OpenAIConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -53,6 +53,7 @@ class OpenAIWhisperAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -131,6 +131,7 @@ class PetalsConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -228,10 +228,10 @@ class PredibaseChatCompletion:
api_key: str,
logging_obj,
optional_params: dict,
+ litellm_params: dict,
tenant_id: str,
timeout: Union[float, httpx.Timeout],
acompletion=None,
- litellm_params=None,
logger_fn=None,
headers: dict = {},
) -> Union[ModelResponse, CustomStreamWrapper]:
@@ -241,6 +241,7 @@ class PredibaseChatCompletion:
messages=messages,
optional_params=optional_params,
model=model,
+ litellm_params=litellm_params,
)
completion_url = ""
input_text = ""

View file

@@ -164,6 +164,7 @@ class PredibaseConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -141,6 +141,7 @@ def completion(
model=model,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
)
# Start a prediction and get the prediction URL
version_id = replicate_config.model_to_version_id(model)

View file

@@ -312,6 +312,7 @@ class ReplicateConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -96,6 +96,7 @@ class SagemakerLLM(BaseAWSLLM):
model: str,
data: dict,
messages: List[AllMessageValues],
+ litellm_params: dict,
optional_params: dict,
aws_region_name: str,
extra_headers: Optional[dict] = None,
@@ -122,6 +123,7 @@ class SagemakerLLM(BaseAWSLLM):
model=model,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
)
request = AWSRequest(
method="POST", url=api_base, data=encoded_data, headers=headers
@@ -198,6 +200,7 @@ class SagemakerLLM(BaseAWSLLM):
data=data,
messages=messages,
optional_params=optional_params,
+ litellm_params=litellm_params,
credentials=credentials,
aws_region_name=aws_region_name,
)
@@ -274,6 +277,7 @@ class SagemakerLLM(BaseAWSLLM):
"model": model,
"data": _data,
"optional_params": optional_params,
+ "litellm_params": litellm_params,
"credentials": credentials,
"aws_region_name": aws_region_name,
"messages": messages,
@@ -426,6 +430,7 @@ class SagemakerLLM(BaseAWSLLM):
"model": model,
"data": data,
"optional_params": optional_params,
+ "litellm_params": litellm_params,
"credentials": credentials,
"aws_region_name": aws_region_name,
"messages": messages,
@@ -496,6 +501,7 @@ class SagemakerLLM(BaseAWSLLM):
"model": model,
"data": data,
"optional_params": optional_params,
+ "litellm_params": litellm_params,
"credentials": credentials,
"aws_region_name": aws_region_name,
"messages": messages,

View file

@@ -263,6 +263,7 @@ class SagemakerConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -92,6 +92,7 @@ class SnowflakeConfig(OpenAIGPTConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -37,6 +37,7 @@ class TopazImageVariationConfig(BaseImageVariationConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -48,6 +48,7 @@ class TritonConfig(BaseConfig):
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> Dict:

View file

@@ -42,6 +42,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig):
model: str,
messages: List[AllMessageValues],
optional_params: dict,
+ litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:

View file

@@ -1,3 +1,4 @@
+ import asyncio
from typing import Any, Coroutine, Optional, Union
import httpx
@@ -11,9 +12,9 @@ from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.types.llms.openai import CreateFileRequest, OpenAIFileObject
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
- from .transformation import VertexAIFilesTransformation
+ from .transformation import VertexAIJsonlFilesTransformation
- vertex_ai_files_transformation = VertexAIFilesTransformation()
+ vertex_ai_files_transformation = VertexAIJsonlFilesTransformation()
class VertexAIFilesHandler(GCSBucketBase):
@@ -92,5 +93,15 @@ class VertexAIFilesHandler(GCSBucketBase):
timeout=timeout,
max_retries=max_retries,
)
- else:
-     return None  # type: ignore
+ return asyncio.run(
+     self.async_create_file(
+         create_file_data=create_file_data,
+         api_base=api_base,
+         vertex_credentials=vertex_credentials,
+         vertex_project=vertex_project,
+         vertex_location=vertex_location,
+         timeout=timeout,
+         max_retries=max_retries,
+     )
+ )

View file

@@ -1,7 +1,17 @@
import json
+ import os
+ import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union
+ from httpx import Headers, Response
+ from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
+ from litellm.llms.base_llm.chat.transformation import BaseLLMException
+ from litellm.llms.base_llm.files.transformation import (
+     BaseFilesConfig,
+     LiteLLMLoggingObj,
+ )
from litellm.llms.vertex_ai.common_utils import (
_convert_vertex_datetime_to_openai_datetime,
)
@@ -10,14 +20,317 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig,
)
from litellm.types.llms.openai import (
+ AllMessageValues,
CreateFileRequest,
FileTypes,
+ OpenAICreateFileRequestOptionalParams,
OpenAIFileObject,
PathLike,
)
+ from litellm.types.llms.vertex_ai import GcsBucketResponse
+ from litellm.types.utils import ExtractedFileData, LlmProviders
+ from ..common_utils import VertexAIError
+ from ..vertex_llm_base import VertexBase
- class VertexAIFilesTransformation(VertexGeminiConfig):
+ class VertexAIFilesConfig(VertexBase, BaseFilesConfig):
"""
Config for VertexAI Files
"""
def __init__(self):
self.jsonl_transformation = VertexAIJsonlFilesTransformation()
super().__init__()
@property
def custom_llm_provider(self) -> LlmProviders:
return LlmProviders.VERTEX_AI
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
if not api_key:
api_key, _ = self.get_access_token(
credentials=litellm_params.get("vertex_credentials"),
project_id=litellm_params.get("vertex_project"),
)
if not api_key:
raise ValueError("api_key is required")
headers["Authorization"] = f"Bearer {api_key}"
return headers
def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
"""
Helper to extract content from various OpenAI file types and return as string.
Handles:
- Direct content (str, bytes, IO[bytes])
- Tuple formats: (filename, content, [content_type], [headers])
- PathLike objects
"""
content: Union[str, bytes] = b""
# Extract file content from tuple if necessary
if isinstance(openai_file_content, tuple):
# Take the second element which is always the file content
file_content = openai_file_content[1]
else:
file_content = openai_file_content
# Handle different file content types
if isinstance(file_content, str):
# String content can be used directly
content = file_content
elif isinstance(file_content, bytes):
# Bytes content can be decoded
content = file_content
elif isinstance(file_content, PathLike): # PathLike
with open(str(file_content), "rb") as f:
content = f.read()
elif hasattr(file_content, "read"): # IO[bytes]
# File-like objects need to be read
content = file_content.read()
# Ensure content is string
if isinstance(content, bytes):
content = content.decode("utf-8")
return content
def _get_gcs_object_name_from_batch_jsonl(
self,
openai_jsonl_content: List[Dict[str, Any]],
) -> str:
"""
Gets a unique GCS object name for the VertexAI batch prediction job
named as: litellm-vertex-{model}-{uuid}
"""
_model = openai_jsonl_content[0].get("body", {}).get("model", "")
if "publishers/google/models" not in _model:
_model = f"publishers/google/models/{_model}"
object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}"
return object_name
def get_object_name(
self, extracted_file_data: ExtractedFileData, purpose: str
) -> str:
"""
Get the object name for the request
"""
extracted_file_data_content = extracted_file_data.get("content")
if extracted_file_data_content is None:
raise ValueError("file content is required")
if purpose == "batch":
## 1. If jsonl, check if there's a model name
file_content = self._get_content_from_openai_file(
extracted_file_data_content
)
# Split into lines and parse each line as JSON
openai_jsonl_content = [
json.loads(line) for line in file_content.splitlines() if line.strip()
]
if len(openai_jsonl_content) > 0:
return self._get_gcs_object_name_from_batch_jsonl(openai_jsonl_content)
## 2. If not jsonl, return the filename
filename = extracted_file_data.get("filename")
if filename:
return filename
## 3. If no file name, return timestamp
return str(int(time.time()))
def get_complete_file_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: Dict,
litellm_params: Dict,
data: CreateFileRequest,
) -> str:
"""
Get the complete url for the request
"""
bucket_name = litellm_params.get("bucket_name") or os.getenv("GCS_BUCKET_NAME")
if not bucket_name:
raise ValueError("GCS bucket_name is required")
file_data = data.get("file")
purpose = data.get("purpose")
if file_data is None:
raise ValueError("file is required")
if purpose is None:
raise ValueError("purpose is required")
extracted_file_data = extract_file_data(file_data)
object_name = self.get_object_name(extracted_file_data, purpose)
endpoint = (
f"upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
)
api_base = api_base or "https://storage.googleapis.com"
if not api_base:
raise ValueError("api_base is required")
return f"{api_base}/{endpoint}"
def get_supported_openai_params(
self, model: str
) -> List[OpenAICreateFileRequestOptionalParams]:
return []
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
return optional_params
def _map_openai_to_vertex_params(
self,
openai_request_body: Dict[str, Any],
) -> Dict[str, Any]:
"""
wrapper to call VertexGeminiConfig.map_openai_params
"""
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig,
)
config = VertexGeminiConfig()
_model = openai_request_body.get("model", "")
vertex_params = config.map_openai_params(
model=_model,
non_default_params=openai_request_body,
optional_params={},
drop_params=False,
)
return vertex_params
def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
self, openai_jsonl_content: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""
Transforms OpenAI JSONL content to VertexAI JSONL content
jsonl body for vertex is {"request": <request_body>}
Example Vertex jsonl
{"request":{"contents": [{"role": "user", "parts": [{"text": "What is the relation between the following video and image samples?"}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/animals.mp4", "mimeType": "video/mp4"}}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/image/cricket.jpeg", "mimeType": "image/jpeg"}}]}]}}
{"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}}
"""
vertex_jsonl_content = []
for _openai_jsonl_content in openai_jsonl_content:
openai_request_body = _openai_jsonl_content.get("body") or {}
vertex_request_body = _transform_request_body(
messages=openai_request_body.get("messages", []),
model=openai_request_body.get("model", ""),
optional_params=self._map_openai_to_vertex_params(openai_request_body),
custom_llm_provider="vertex_ai",
litellm_params={},
cached_content=None,
)
vertex_jsonl_content.append({"request": vertex_request_body})
return vertex_jsonl_content
def transform_create_file_request(
self,
model: str,
create_file_data: CreateFileRequest,
optional_params: dict,
litellm_params: dict,
) -> Union[bytes, str, dict]:
"""
2 Cases:
1. Handle basic file upload
2. Handle batch file upload (.jsonl)
"""
file_data = create_file_data.get("file")
if file_data is None:
raise ValueError("file is required")
extracted_file_data = extract_file_data(file_data)
extracted_file_data_content = extracted_file_data.get("content")
if (
create_file_data.get("purpose") == "batch"
and extracted_file_data.get("content_type") == "application/jsonl"
and extracted_file_data_content is not None
):
## 1. If jsonl, check if there's a model name
file_content = self._get_content_from_openai_file(
extracted_file_data_content
)
# Split into lines and parse each line as JSON
openai_jsonl_content = [
json.loads(line) for line in file_content.splitlines() if line.strip()
]
vertex_jsonl_content = (
self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
openai_jsonl_content
)
)
return json.dumps(vertex_jsonl_content)
elif isinstance(extracted_file_data_content, bytes):
return extracted_file_data_content
else:
raise ValueError("Unsupported file content type")
def transform_create_file_response(
self,
model: Optional[str],
raw_response: Response,
logging_obj: LiteLLMLoggingObj,
litellm_params: dict,
) -> OpenAIFileObject:
"""
Transform VertexAI File upload response into OpenAI-style FileObject
"""
response_json = raw_response.json()
try:
response_object = GcsBucketResponse(**response_json) # type: ignore
except Exception as e:
raise VertexAIError(
status_code=raw_response.status_code,
message=f"Error reading GCS bucket response: {e}",
headers=raw_response.headers,
)
gcs_id = response_object.get("id", "")
# Remove the last numeric ID from the path
gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else ""
return OpenAIFileObject(
purpose=response_object.get("purpose", "batch"),
id=f"gs://{gcs_id}",
filename=response_object.get("name", ""),
created_at=_convert_vertex_datetime_to_openai_datetime(
vertex_datetime=response_object.get("timeCreated", "")
),
status="uploaded",
bytes=int(response_object.get("size", 0)),
object="file",
)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[Dict, Headers]
) -> BaseLLMException:
return VertexAIError(
status_code=status_code, message=error_message, headers=headers
)
class VertexAIJsonlFilesTransformation(VertexGeminiConfig):
    """
    Transforms OpenAI /v1/files/* requests to VertexAI /v1/files/* requests
    """

View file

@ -240,6 +240,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
gtool_func_declarations = [] gtool_func_declarations = []
googleSearch: Optional[dict] = None googleSearch: Optional[dict] = None
googleSearchRetrieval: Optional[dict] = None googleSearchRetrieval: Optional[dict] = None
enterpriseWebSearch: Optional[dict] = None
code_execution: Optional[dict] = None code_execution: Optional[dict] = None
# remove 'additionalProperties' from tools # remove 'additionalProperties' from tools
value = _remove_additional_properties(value) value = _remove_additional_properties(value)
@ -273,6 +274,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
googleSearch = tool["googleSearch"] googleSearch = tool["googleSearch"]
elif tool.get("googleSearchRetrieval", None) is not None: elif tool.get("googleSearchRetrieval", None) is not None:
googleSearchRetrieval = tool["googleSearchRetrieval"] googleSearchRetrieval = tool["googleSearchRetrieval"]
elif tool.get("enterpriseWebSearch", None) is not None:
enterpriseWebSearch = tool["enterpriseWebSearch"]
elif tool.get("code_execution", None) is not None: elif tool.get("code_execution", None) is not None:
code_execution = tool["code_execution"] code_execution = tool["code_execution"]
elif openai_function_object is not None: elif openai_function_object is not None:
@ -299,6 +302,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
_tools["googleSearch"] = googleSearch _tools["googleSearch"] = googleSearch
if googleSearchRetrieval is not None: if googleSearchRetrieval is not None:
_tools["googleSearchRetrieval"] = googleSearchRetrieval _tools["googleSearchRetrieval"] = googleSearchRetrieval
if enterpriseWebSearch is not None:
_tools["enterpriseWebSearch"] = enterpriseWebSearch
if code_execution is not None: if code_execution is not None:
_tools["code_execution"] = code_execution _tools["code_execution"] = code_execution
return [_tools] return [_tools]
@ -900,6 +905,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
model: str, model: str,
messages: List[AllMessageValues], messages: List[AllMessageValues],
optional_params: Dict, optional_params: Dict,
litellm_params: Dict,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
) -> Dict: ) -> Dict:
@ -1017,7 +1023,7 @@ class VertexLLM(VertexBase):
logging_obj, logging_obj,
stream, stream,
optional_params: dict, optional_params: dict,
litellm_params=None, litellm_params: dict,
logger_fn=None, logger_fn=None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
client: Optional[AsyncHTTPHandler] = None, client: Optional[AsyncHTTPHandler] = None,
@ -1058,6 +1064,7 @@ class VertexLLM(VertexBase):
model=model, model=model,
messages=messages, messages=messages,
optional_params=optional_params, optional_params=optional_params,
litellm_params=litellm_params,
) )
## LOGGING ## LOGGING
@ -1144,6 +1151,7 @@ class VertexLLM(VertexBase):
model=model, model=model,
messages=messages, messages=messages,
optional_params=optional_params, optional_params=optional_params,
litellm_params=litellm_params,
) )
request_body = await async_transform_request_body(**data) # type: ignore request_body = await async_transform_request_body(**data) # type: ignore
@ -1317,6 +1325,7 @@ class VertexLLM(VertexBase):
model=model, model=model,
messages=messages, messages=messages,
optional_params=optional_params, optional_params=optional_params,
litellm_params=litellm_params,
) )
## TRANSFORMATION ## ## TRANSFORMATION ##
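
The hunks above wire a new `enterpriseWebSearch` built-in tool through the same path as `googleSearch` and `googleSearchRetrieval`. A hedged usage sketch (model name and credentials are illustrative; this mirrors the existing built-in tool pattern rather than documenting a confirmed API surface):

```python
from litellm import completion

# Assumes Vertex AI credentials are already configured in the environment.
tools = [{"enterpriseWebSearch": {}}]  # new built-in tool added in this PR

response = completion(
    model="vertex_ai/gemini-2.0-flash",  # illustrative model name
    messages=[{"role": "user", "content": "Find recent public updates about our company."}],
    tools=tools,
)
print(response)
```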

View file

@ -94,6 +94,7 @@ class VertexMultimodalEmbedding(VertexLLM):
optional_params=optional_params, optional_params=optional_params,
api_key=auth_header, api_key=auth_header,
api_base=api_base, api_base=api_base,
litellm_params=litellm_params,
) )
## LOGGING ## LOGGING

View file

@ -47,6 +47,7 @@ class VertexAIMultimodalEmbeddingConfig(BaseEmbeddingConfig):
model: str, model: str,
messages: List[AllMessageValues], messages: List[AllMessageValues],
optional_params: dict, optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
) -> dict: ) -> dict:

View file

@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.litellm_core_utils.asyncify import asyncify from litellm.litellm_core_utils.asyncify import asyncify
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
@ -22,7 +21,7 @@ else:
GoogleCredentialsObject = Any GoogleCredentialsObject = Any
class VertexBase(BaseLLM): class VertexBase:
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self.access_token: Optional[str] = None self.access_token: Optional[str] = None

View file

@ -83,6 +83,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
model: str, model: str,
messages: List[AllMessageValues], messages: List[AllMessageValues],
optional_params: dict, optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
) -> dict: ) -> dict:

View file

@ -49,6 +49,7 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
messages=messages, messages=messages,
optional_params=optional_params, optional_params=optional_params,
api_key=api_key, api_key=api_key,
litellm_params=litellm_params,
) )
## UPDATE PAYLOAD (optional params) ## UPDATE PAYLOAD (optional params)

View file

@ -165,6 +165,7 @@ class IBMWatsonXMixin:
model: str, model: str,
messages: List[AllMessageValues], messages: List[AllMessageValues],
optional_params: Dict, optional_params: Dict,
litellm_params: dict,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
) -> Dict: ) -> Dict:

View file

@ -3616,6 +3616,7 @@ def embedding( # noqa: PLR0915
optional_params=optional_params, optional_params=optional_params,
client=client, client=client,
aembedding=aembedding, aembedding=aembedding,
litellm_params=litellm_params_dict,
) )
elif custom_llm_provider == "bedrock": elif custom_llm_provider == "bedrock":
if isinstance(input, str): if isinstance(input, str):

View file

@ -2409,25 +2409,26 @@
"max_tokens": 4096, "max_tokens": 4096,
"max_input_tokens": 131072, "max_input_tokens": 131072,
"max_output_tokens": 4096, "max_output_tokens": 4096,
"input_cost_per_token": 0, "input_cost_per_token": 0.000000075,
"output_cost_per_token": 0, "output_cost_per_token": 0.0000003,
"litellm_provider": "azure_ai", "litellm_provider": "azure_ai",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft" "source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
}, },
"azure_ai/Phi-4-multimodal-instruct": { "azure_ai/Phi-4-multimodal-instruct": {
"max_tokens": 4096, "max_tokens": 4096,
"max_input_tokens": 131072, "max_input_tokens": 131072,
"max_output_tokens": 4096, "max_output_tokens": 4096,
"input_cost_per_token": 0, "input_cost_per_token": 0.00000008,
"output_cost_per_token": 0, "input_cost_per_audio_token": 0.000004,
"output_cost_per_token": 0.00032,
"litellm_provider": "azure_ai", "litellm_provider": "azure_ai",
"mode": "chat", "mode": "chat",
"supports_audio_input": true, "supports_audio_input": true,
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true, "supports_vision": true,
"source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft" "source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
}, },
"azure_ai/Phi-4": { "azure_ai/Phi-4": {
"max_tokens": 16384, "max_tokens": 16384,
@ -3467,7 +3468,7 @@
"input_cost_per_token": 0.0000008, "input_cost_per_token": 0.0000008,
"output_cost_per_token": 0.000004, "output_cost_per_token": 0.000004,
"cache_creation_input_token_cost": 0.000001, "cache_creation_input_token_cost": 0.000001,
"cache_read_input_token_cost": 0.0000008, "cache_read_input_token_cost": 0.00000008,
"litellm_provider": "anthropic", "litellm_provider": "anthropic",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,

View file

@ -1625,6 +1625,7 @@ class LiteLLM_UserTable(LiteLLMPydanticObjectBase):
model_max_budget: Optional[Dict] = {} model_max_budget: Optional[Dict] = {}
model_spend: Optional[Dict] = {} model_spend: Optional[Dict] = {}
user_email: Optional[str] = None user_email: Optional[str] = None
user_alias: Optional[str] = None
models: list = [] models: list = []
tpm_limit: Optional[int] = None tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None rpm_limit: Optional[int] = None

View file

@ -4,16 +4,26 @@ import json
import uuid import uuid
from base64 import b64encode from base64 import b64encode
from datetime import datetime from datetime import datetime
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Tuple, Union
from urllib.parse import parse_qs, urlencode, urlparse from urllib.parse import parse_qs, urlencode, urlparse
import httpx import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
Response,
UploadFile,
status,
)
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from starlette.datastructures import UploadFile as StarletteUploadFile
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.proxy._types import ( from litellm.proxy._types import (
ConfigFieldInfo, ConfigFieldInfo,
@ -358,6 +368,92 @@ class HttpPassThroughEndpointHelpers:
) )
return response return response
@staticmethod
async def non_streaming_http_request_handler(
request: Request,
async_client: httpx.AsyncClient,
url: httpx.URL,
headers: dict,
requested_query_params: Optional[dict] = None,
_parsed_body: Optional[dict] = None,
) -> httpx.Response:
"""
Handle non-streaming HTTP requests
Handles the special cases of GET requests and multipart/form-data requests; all other requests fall through to a generic httpx request
"""
if request.method == "GET":
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
)
elif HttpPassThroughEndpointHelpers.is_multipart(request) is True:
return await HttpPassThroughEndpointHelpers.make_multipart_http_request(
request=request,
async_client=async_client,
url=url,
headers=headers,
requested_query_params=requested_query_params,
)
else:
# Generic httpx method
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
json=_parsed_body,
)
return response
@staticmethod
def is_multipart(request: Request) -> bool:
"""Check if the request is a multipart/form-data request"""
return "multipart/form-data" in request.headers.get("content-type", "")
@staticmethod
async def _build_request_files_from_upload_file(
upload_file: Union[UploadFile, StarletteUploadFile],
) -> Tuple[Optional[str], bytes, Optional[str]]:
"""Build a request files dict from an UploadFile object"""
file_content = await upload_file.read()
return (upload_file.filename, file_content, upload_file.content_type)
@staticmethod
async def make_multipart_http_request(
request: Request,
async_client: httpx.AsyncClient,
url: httpx.URL,
headers: dict,
requested_query_params: Optional[dict] = None,
) -> httpx.Response:
"""Process multipart/form-data requests, handling both files and form fields"""
form_data = await request.form()
files = {}
form_data_dict = {}
for field_name, field_value in form_data.items():
if isinstance(field_value, (StarletteUploadFile, UploadFile)):
files[field_name] = (
await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
upload_file=field_value
)
)
else:
form_data_dict[field_name] = field_value
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
files=files,
data=form_data_dict,
)
return response
async def pass_through_request( # noqa: PLR0915 async def pass_through_request( # noqa: PLR0915
request: Request, request: Request,
@ -424,7 +520,7 @@ async def pass_through_request( # noqa: PLR0915
start_time = datetime.now() start_time = datetime.now()
logging_obj = Logging( logging_obj = Logging(
model="unknown", model="unknown",
messages=[{"role": "user", "content": json.dumps(_parsed_body)}], messages=[{"role": "user", "content": safe_dumps(_parsed_body)}],
stream=False, stream=False,
call_type="pass_through_endpoint", call_type="pass_through_endpoint",
start_time=start_time, start_time=start_time,
@ -453,7 +549,6 @@ async def pass_through_request( # noqa: PLR0915
logging_obj.model_call_details["litellm_call_id"] = litellm_call_id logging_obj.model_call_details["litellm_call_id"] = litellm_call_id
# combine url with query params for logging # combine url with query params for logging
requested_query_params: Optional[dict] = ( requested_query_params: Optional[dict] = (
query_params or request.query_params.__dict__ query_params or request.query_params.__dict__
) )
@ -474,7 +569,7 @@ async def pass_through_request( # noqa: PLR0915
logging_url = str(url) + "?" + requested_query_params_str logging_url = str(url) + "?" + requested_query_params_str
logging_obj.pre_call( logging_obj.pre_call(
input=[{"role": "user", "content": json.dumps(_parsed_body)}], input=[{"role": "user", "content": safe_dumps(_parsed_body)}],
api_key="", api_key="",
additional_args={ additional_args={
"complete_input_dict": _parsed_body, "complete_input_dict": _parsed_body,
@ -525,22 +620,16 @@ async def pass_through_request( # noqa: PLR0915
) )
verbose_proxy_logger.debug("request body: {}".format(_parsed_body)) verbose_proxy_logger.debug("request body: {}".format(_parsed_body))
if request.method == "GET": response = (
response = await async_client.request( await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler(
method=request.method, request=request,
async_client=async_client,
url=url, url=url,
headers=headers, headers=headers,
params=requested_query_params, requested_query_params=requested_query_params,
_parsed_body=_parsed_body,
) )
else: )
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
params=requested_query_params,
json=_parsed_body,
)
verbose_proxy_logger.debug("response.headers= %s", response.headers) verbose_proxy_logger.debug("response.headers= %s", response.headers)
if _is_streaming_response(response) is True: if _is_streaming_response(response) is True:
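
The new `make_multipart_http_request` helper forwards uploaded files and plain form fields separately. A small, network-free sketch of the httpx call shape it builds (URL and field names are illustrative and match the unit test added later in this diff):

```python
import httpx

# (filename, content, content_type) tuples, the same shape
# _build_request_files_from_upload_file returns for each uploaded file.
files = {"file": ("test.txt", b"test file content", "text/plain")}
data = {"text_field": "test value"}  # non-file form fields are passed as data

# Building (not sending) the request shows httpx emits multipart/form-data
# with the file part and the form field combined.
request = httpx.Request("POST", "http://test.com", files=files, data=data)
print(request.headers["content-type"])  # multipart/form-data; boundary=...
```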

View file

@ -187,6 +187,7 @@ class Tools(TypedDict, total=False):
function_declarations: List[FunctionDeclaration] function_declarations: List[FunctionDeclaration]
googleSearch: dict googleSearch: dict
googleSearchRetrieval: dict googleSearchRetrieval: dict
enterpriseWebSearch: dict
code_execution: dict code_execution: dict
retrieval: Retrieval retrieval: Retrieval
@ -497,6 +498,51 @@ class OutputConfig(TypedDict, total=False):
gcsDestination: GcsDestination gcsDestination: GcsDestination
class GcsBucketResponse(TypedDict):
"""
TypedDict for GCS bucket upload response
Attributes:
kind: The kind of item this is. For objects, this is always storage#object
id: The ID of the object
selfLink: The link to this object
mediaLink: The link to download the object
name: The name of the object
bucket: The name of the bucket containing this object
generation: The content generation of this object
metageneration: The metadata generation of this object
contentType: The content type of the object
storageClass: The storage class of the object
size: The size of the object in bytes
md5Hash: The MD5 hash of the object
crc32c: The CRC32c checksum of the object
etag: The ETag of the object
timeCreated: The creation time of the object
updated: The last update time of the object
timeStorageClassUpdated: The time the storage class was last updated
timeFinalized: The time the object was finalized
"""
kind: Literal["storage#object"]
id: str
selfLink: str
mediaLink: str
name: str
bucket: str
generation: str
metageneration: str
contentType: str
storageClass: str
size: str
md5Hash: str
crc32c: str
etag: str
timeCreated: str
updated: str
timeStorageClassUpdated: str
timeFinalized: str
class VertexAIBatchPredictionJob(TypedDict): class VertexAIBatchPredictionJob(TypedDict):
displayName: str displayName: str
model: str model: str
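
`GcsBucketResponse` (above) is what `transform_create_file_response` parses; its `id` field carries a trailing generation number that gets stripped before the path is surfaced as an OpenAI-style file id. A tiny sketch of that normalization (the object path is illustrative; only the bucket name comes from this diff's tests):

```python
# GCS object ids look like "<bucket>/<object path>/<generation number>".
gcs_object_id = "litellm-local/batches/example.jsonl/1715000000000000"  # illustrative

trimmed = "/".join(gcs_object_id.split("/")[:-1])  # drop the numeric generation suffix
openai_file_id = f"gs://{trimmed}"
print(openai_file_id)  # gs://litellm-local/batches/example.jsonl
```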

View file

@ -2,7 +2,7 @@ import json
import time import time
import uuid import uuid
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Tuple, Union from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union
from aiohttp import FormData from aiohttp import FormData
from openai._models import BaseModel as OpenAIObject from openai._models import BaseModel as OpenAIObject
@ -2170,3 +2170,20 @@ class CreateCredentialItem(CredentialBase):
if not values.get("credential_values") and not values.get("model_id"): if not values.get("credential_values") and not values.get("model_id"):
raise ValueError("Either credential_values or model_id must be set") raise ValueError("Either credential_values or model_id must be set")
return values return values
class ExtractedFileData(TypedDict):
"""
TypedDict for storing processed file data
Attributes:
filename: Name of the file if provided
content: The file content in bytes
content_type: MIME type of the file
headers: Any additional headers for the file
"""
filename: Optional[str]
content: bytes
content_type: Optional[str]
headers: Mapping[str, str]
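
The new `ExtractedFileData` TypedDict above pins down the shape `extract_file_data` is expected to return. A minimal, self-contained sketch (local TypedDict and illustrative helper, not the litellm implementation) of normalizing an OpenAI-style `(filename, content, content_type)` file tuple into that shape:

```python
from typing import Mapping, Optional, TypedDict


class ExtractedFileData(TypedDict):
    filename: Optional[str]
    content: bytes
    content_type: Optional[str]
    headers: Mapping[str, str]


def extract_from_tuple(file: tuple) -> ExtractedFileData:
    # OpenAI-style tuples are (filename, content, content_type[, headers]).
    filename, content, content_type = file[0], file[1], file[2]
    headers = file[3] if len(file) > 3 else {}
    return ExtractedFileData(
        filename=filename,
        content=content if isinstance(content, bytes) else content.read(),
        content_type=content_type,
        headers=headers,
    )


print(extract_from_tuple(("test.wav", b"test audio content", "audio/wav")))
```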

View file

@ -6517,6 +6517,10 @@ class ProviderConfigManager:
) )
return GoogleAIStudioFilesHandler() return GoogleAIStudioFilesHandler()
elif LlmProviders.VERTEX_AI == provider:
from litellm.llms.vertex_ai.files.transformation import VertexAIFilesConfig
return VertexAIFilesConfig()
return None return None

View file

@ -2409,25 +2409,26 @@
"max_tokens": 4096, "max_tokens": 4096,
"max_input_tokens": 131072, "max_input_tokens": 131072,
"max_output_tokens": 4096, "max_output_tokens": 4096,
"input_cost_per_token": 0, "input_cost_per_token": 0.000000075,
"output_cost_per_token": 0, "output_cost_per_token": 0.0000003,
"litellm_provider": "azure_ai", "litellm_provider": "azure_ai",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft" "source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
}, },
"azure_ai/Phi-4-multimodal-instruct": { "azure_ai/Phi-4-multimodal-instruct": {
"max_tokens": 4096, "max_tokens": 4096,
"max_input_tokens": 131072, "max_input_tokens": 131072,
"max_output_tokens": 4096, "max_output_tokens": 4096,
"input_cost_per_token": 0, "input_cost_per_token": 0.00000008,
"output_cost_per_token": 0, "input_cost_per_audio_token": 0.000004,
"output_cost_per_token": 0.00032,
"litellm_provider": "azure_ai", "litellm_provider": "azure_ai",
"mode": "chat", "mode": "chat",
"supports_audio_input": true, "supports_audio_input": true,
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true, "supports_vision": true,
"source": "https://learn.microsoft.com/en-us/azure/ai-foundry/concepts/models-featured#microsoft" "source": "https://techcommunity.microsoft.com/blog/Azure-AI-Services-blog/announcing-new-phi-pricing-empowering-your-business-with-small-language-models/4395112"
}, },
"azure_ai/Phi-4": { "azure_ai/Phi-4": {
"max_tokens": 16384, "max_tokens": 16384,
@ -3467,7 +3468,7 @@
"input_cost_per_token": 0.0000008, "input_cost_per_token": 0.0000008,
"output_cost_per_token": 0.000004, "output_cost_per_token": 0.000004,
"cache_creation_input_token_cost": 0.000001, "cache_creation_input_token_cost": 0.000001,
"cache_read_input_token_cost": 0.0000008, "cache_read_input_token_cost": 0.00000008,
"litellm_provider": "anthropic", "litellm_provider": "anthropic",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,

View file

@ -423,25 +423,35 @@ mock_vertex_batch_response = {
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_avertex_batch_prediction(): async def test_avertex_batch_prediction(monkeypatch):
with patch( monkeypatch.setenv("GCS_BUCKET_NAME", "litellm-local")
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
client = AsyncHTTPHandler()
async def mock_side_effect(*args, **kwargs):
print("args", args, "kwargs", kwargs)
url = kwargs.get("url", "")
if "files" in url:
mock_response.json.return_value = mock_file_response
elif "batch" in url:
mock_response.json.return_value = mock_vertex_batch_response
mock_response.status_code = 200
return mock_response
with patch.object(
client, "post", side_effect=mock_side_effect
) as mock_post, patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post" "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post"
) as mock_post: ) as mock_global_post:
# Configure mock responses # Configure mock responses
mock_response = MagicMock() mock_response = MagicMock()
mock_response.raise_for_status.return_value = None mock_response.raise_for_status.return_value = None
# Set up different responses for different API calls # Set up different responses for different API calls
async def mock_side_effect(*args, **kwargs):
url = kwargs.get("url", "")
if "files" in url:
mock_response.json.return_value = mock_file_response
elif "batch" in url:
mock_response.json.return_value = mock_vertex_batch_response
mock_response.status_code = 200
return mock_response
mock_post.side_effect = mock_side_effect mock_post.side_effect = mock_side_effect
mock_global_post.side_effect = mock_side_effect
# load_vertex_ai_credentials() # load_vertex_ai_credentials()
litellm.set_verbose = True litellm.set_verbose = True
@ -455,6 +465,7 @@ async def test_avertex_batch_prediction():
file=open(file_path, "rb"), file=open(file_path, "rb"),
purpose="batch", purpose="batch",
custom_llm_provider="vertex_ai", custom_llm_provider="vertex_ai",
client=client
) )
print("Response from creating file=", file_obj) print("Response from creating file=", file_obj)

View file

@ -0,0 +1,116 @@
import json
import os
import sys
from io import BytesIO
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from fastapi import Request, UploadFile
from fastapi.testclient import TestClient
from starlette.datastructures import Headers
from starlette.datastructures import UploadFile as StarletteUploadFile
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
HttpPassThroughEndpointHelpers,
)
# Test is_multipart
def test_is_multipart():
# Test with multipart content type
request = MagicMock(spec=Request)
request.headers = Headers({"content-type": "multipart/form-data; boundary=123"})
assert HttpPassThroughEndpointHelpers.is_multipart(request) is True
# Test with non-multipart content type
request.headers = Headers({"content-type": "application/json"})
assert HttpPassThroughEndpointHelpers.is_multipart(request) is False
# Test with no content type
request.headers = Headers({})
assert HttpPassThroughEndpointHelpers.is_multipart(request) is False
# Test _build_request_files_from_upload_file
@pytest.mark.asyncio
async def test_build_request_files_from_upload_file():
# Test with FastAPI UploadFile
file_content = b"test content"
file = BytesIO(file_content)
# Create SpooledTemporaryFile with content type headers
headers = {"content-type": "text/plain"}
upload_file = UploadFile(file=file, filename="test.txt", headers=headers)
upload_file.read = AsyncMock(return_value=file_content)
result = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
upload_file
)
assert result == ("test.txt", file_content, "text/plain")
# Test with Starlette UploadFile
file2 = BytesIO(file_content)
starlette_file = StarletteUploadFile(
file=file2,
filename="test2.txt",
headers=Headers({"content-type": "text/plain"}),
)
starlette_file.read = AsyncMock(return_value=file_content)
result = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
starlette_file
)
assert result == ("test2.txt", file_content, "text/plain")
# Test make_multipart_http_request
@pytest.mark.asyncio
async def test_make_multipart_http_request():
# Mock request with file and form field
request = MagicMock(spec=Request)
request.method = "POST"
# Mock form data
file_content = b"test file content"
file = BytesIO(file_content)
# Create SpooledTemporaryFile with content type headers
headers = {"content-type": "text/plain"}
upload_file = UploadFile(file=file, filename="test.txt", headers=headers)
upload_file.read = AsyncMock(return_value=file_content)
form_data = {"file": upload_file, "text_field": "test value"}
request.form = AsyncMock(return_value=form_data)
# Mock httpx client
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {}
async_client = MagicMock()
async_client.request = AsyncMock(return_value=mock_response)
# Test the function
response = await HttpPassThroughEndpointHelpers.make_multipart_http_request(
request=request,
async_client=async_client,
url=httpx.URL("http://test.com"),
headers={},
requested_query_params=None,
)
# Verify the response
assert response == mock_response
# Verify the client call
async_client.request.assert_called_once()
call_args = async_client.request.call_args[1]
assert call_args["method"] == "POST"
assert str(call_args["url"]) == "http://test.com"
assert isinstance(call_args["files"], dict)
assert isinstance(call_args["data"], dict)
assert call_args["data"]["text_field"] == "test value"

View file

@ -323,7 +323,8 @@ class TestHuggingFace(BaseLLMChatTest):
model="huggingface/fireworks-ai/meta-llama/Meta-Llama-3-8B-Instruct", model="huggingface/fireworks-ai/meta-llama/Meta-Llama-3-8B-Instruct",
messages=[{"role": "user", "content": "Hello"}], messages=[{"role": "user", "content": "Hello"}],
optional_params={}, optional_params={},
api_key="test_api_key" api_key="test_api_key",
litellm_params={}
) )
assert headers["Authorization"] == "Bearer test_api_key" assert headers["Authorization"] == "Bearer test_api_key"

View file

@ -141,6 +141,7 @@ def test_build_vertex_schema():
[ [
([{"googleSearch": {}}], "googleSearch"), ([{"googleSearch": {}}], "googleSearch"),
([{"googleSearchRetrieval": {}}], "googleSearchRetrieval"), ([{"googleSearchRetrieval": {}}], "googleSearchRetrieval"),
([{"enterpriseWebSearch": {}}], "enterpriseWebSearch"),
([{"code_execution": {}}], "code_execution"), ([{"code_execution": {}}], "code_execution"),
], ],
) )

View file

@ -0,0 +1,2 @@
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello world!"}], "max_tokens": 10}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."}, {"role": "user", "content": "Hello world!"}], "max_tokens": 10}}

View file

@ -21,7 +21,7 @@ from litellm.integrations.gcs_bucket.gcs_bucket import (
StandardLoggingPayload, StandardLoggingPayload,
) )
from litellm.types.utils import StandardCallbackDynamicParams from litellm.types.utils import StandardCallbackDynamicParams
from unittest.mock import patch
verbose_logger.setLevel(logging.DEBUG) verbose_logger.setLevel(logging.DEBUG)
@ -687,3 +687,63 @@ async def test_basic_gcs_logger_with_folder_in_bucket_name():
# clean up # clean up
if old_bucket_name is not None: if old_bucket_name is not None:
os.environ["GCS_BUCKET_NAME"] = old_bucket_name os.environ["GCS_BUCKET_NAME"] = old_bucket_name
@pytest.mark.skip(reason="This test is flaky on ci/cd")
def test_create_file_e2e():
"""
End-to-end check that 'create_file' uploads a basic file to Vertex AI
"""
load_vertex_ai_credentials()
test_file_content = b"test audio content"
test_file = ("test.wav", test_file_content, "audio/wav")
from litellm import create_file
response = create_file(
file=test_file,
purpose="user_data",
custom_llm_provider="vertex_ai",
)
print("response", response)
assert response is not None
@pytest.mark.skip(reason="This test is flaky on ci/cd")
def test_create_file_e2e_jsonl():
"""
Asserts 'create_file' is called with the correct arguments
"""
load_vertex_ai_credentials()
from litellm.llms.custom_httpx.http_handler import HTTPHandler
client = HTTPHandler()
example_jsonl = [{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}},{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}]
# Create and write to the file
file_path = "example.jsonl"
with open(file_path, "w") as f:
for item in example_jsonl:
f.write(json.dumps(item) + "\n")
# Verify file content
with open(file_path, "r") as f:
content = f.read()
print("File content:", content)
assert len(content) > 0, "File is empty"
from litellm import create_file
with patch.object(client, "post") as mock_create_file:
try:
response = create_file(
file=open(file_path, "rb"),
purpose="user_data",
custom_llm_provider="vertex_ai",
client=client,
)
except Exception as e:
print("error", e)
mock_create_file.assert_called_once()
print(f"kwargs: {mock_create_file.call_args.kwargs}")
assert mock_create_file.call_args.kwargs["data"] is not None and len(mock_create_file.call_args.kwargs["data"]) > 0

View file

@ -2,14 +2,31 @@ import pytest
import openai import openai
import aiohttp import aiohttp
import asyncio import asyncio
import tempfile
from typing_extensions import override from typing_extensions import override
from openai import AssistantEventHandler from openai import AssistantEventHandler
client = openai.OpenAI(base_url="http://0.0.0.0:4000/openai", api_key="sk-1234") client = openai.OpenAI(base_url="http://0.0.0.0:4000/openai", api_key="sk-1234")
def test_pass_through_file_operations():
# Create a temporary file
with tempfile.NamedTemporaryFile(mode='w+', suffix='.txt', delete=False) as temp_file:
temp_file.write("This is a test file for the OpenAI Assistants API.")
temp_file.flush()
# create a file
file = client.files.create(
file=open(temp_file.name, "rb"),
purpose="assistants",
)
print("file created", file)
# delete the file
delete_file = client.files.delete(file.id)
print("file deleted", delete_file)
def test_openai_assistants_e2e_operations(): def test_openai_assistants_e2e_operations():
assistant = client.beta.assistants.create( assistant = client.beta.assistants.create(
name="Math Tutor", name="Math Tutor",
instructions="You are a personal math tutor. Write and run code to answer math questions.", instructions="You are a personal math tutor. Write and run code to answer math questions.",

View file

@ -13,9 +13,12 @@ import { Organization, userListCall } from "./networking";
import { createTeamSearchFunction } from "./key_team_helpers/team_search_fn"; import { createTeamSearchFunction } from "./key_team_helpers/team_search_fn";
import { createOrgSearchFunction } from "./key_team_helpers/organization_search_fn"; import { createOrgSearchFunction } from "./key_team_helpers/organization_search_fn";
import { useFilterLogic } from "./key_team_helpers/filter_logic"; import { useFilterLogic } from "./key_team_helpers/filter_logic";
import { Setter } from "@/types";
import { updateExistingKeys } from "@/utils/dataUtils";
interface AllKeysTableProps { interface AllKeysTableProps {
keys: KeyResponse[]; keys: KeyResponse[];
setKeys: Setter<KeyResponse[]>;
isLoading?: boolean; isLoading?: boolean;
pagination: { pagination: {
currentPage: number; currentPage: number;
@ -87,6 +90,7 @@ const TeamFilter = ({
*/ */
export function AllKeysTable({ export function AllKeysTable({
keys, keys,
setKeys,
isLoading = false, isLoading = false,
pagination, pagination,
onPageChange, onPageChange,
@ -364,6 +368,23 @@ export function AllKeysTable({
keyId={selectedKeyId} keyId={selectedKeyId}
onClose={() => setSelectedKeyId(null)} onClose={() => setSelectedKeyId(null)}
keyData={keys.find(k => k.token === selectedKeyId)} keyData={keys.find(k => k.token === selectedKeyId)}
onKeyDataUpdate={(updatedKeyData) => {
setKeys(keys => keys.map(key => {
if (key.token === updatedKeyData.token) {
// The shape of key differs from that of updatedKeyData
// (received from keyUpdateCall in networking.tsx), so we can't
// simply replace key with updatedKeyData without risking
// unintended bugs/behaviors. Instead, we only update the
// fields that are present in both.
return updateExistingKeys(key, updatedKeyData)
}
return key
}))
}}
onDelete={() => {
setKeys(keys => keys.filter(key => key.token !== selectedKeyId))
}}
accessToken={accessToken} accessToken={accessToken}
userID={userID} userID={userID}
userRole={userRole} userRole={userRole}

View file

@ -27,13 +27,15 @@ interface KeyInfoViewProps {
keyId: string; keyId: string;
onClose: () => void; onClose: () => void;
keyData: KeyResponse | undefined; keyData: KeyResponse | undefined;
onKeyDataUpdate?: (data: Partial<KeyResponse>) => void;
onDelete?: () => void;
accessToken: string | null; accessToken: string | null;
userID: string | null; userID: string | null;
userRole: string | null; userRole: string | null;
teams: any[] | null; teams: any[] | null;
} }
export default function KeyInfoView({ keyId, onClose, keyData, accessToken, userID, userRole, teams }: KeyInfoViewProps) { export default function KeyInfoView({ keyId, onClose, keyData, accessToken, userID, userRole, teams, onKeyDataUpdate, onDelete }: KeyInfoViewProps) {
const [isEditing, setIsEditing] = useState(false); const [isEditing, setIsEditing] = useState(false);
const [form] = Form.useForm(); const [form] = Form.useForm();
const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false); const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false);
@ -93,6 +95,9 @@ export default function KeyInfoView({ keyId, onClose, keyData, accessToken, user
} }
const newKeyValues = await keyUpdateCall(accessToken, formValues); const newKeyValues = await keyUpdateCall(accessToken, formValues);
if (onKeyDataUpdate) {
onKeyDataUpdate(newKeyValues)
}
message.success("Key updated successfully"); message.success("Key updated successfully");
setIsEditing(false); setIsEditing(false);
// Refresh key data here if needed // Refresh key data here if needed
@ -107,6 +112,9 @@ export default function KeyInfoView({ keyId, onClose, keyData, accessToken, user
if (!accessToken) return; if (!accessToken) return;
await keyDeleteCall(accessToken as string, keyData.token); await keyDeleteCall(accessToken as string, keyData.token);
message.success("Key deleted successfully"); message.success("Key deleted successfully");
if (onDelete) {
onDelete()
}
onClose(); onClose();
} catch (error) { } catch (error) {
console.error("Error deleting the key:", error); console.error("Error deleting the key:", error);

View file

@ -1,5 +1,6 @@
import { useState, useEffect } from 'react'; import { useState, useEffect } from 'react';
import { keyListCall, Organization } from '../networking'; import { keyListCall, Organization } from '../networking';
import { Setter } from '@/types';
export interface Team { export interface Team {
team_id: string; team_id: string;
@ -94,13 +95,14 @@ totalPages: number;
totalCount: number; totalCount: number;
} }
interface UseKeyListReturn { interface UseKeyListReturn {
keys: KeyResponse[]; keys: KeyResponse[];
isLoading: boolean; isLoading: boolean;
error: Error | null; error: Error | null;
pagination: PaginationData; pagination: PaginationData;
refresh: (params?: Record<string, unknown>) => Promise<void>; refresh: (params?: Record<string, unknown>) => Promise<void>;
setKeys: (newKeysOrUpdater: KeyResponse[] | ((prevKeys: KeyResponse[]) => KeyResponse[])) => void; setKeys: Setter<KeyResponse[]>;
} }
const useKeyList = ({ const useKeyList = ({

View file

@ -4,6 +4,7 @@
import { all_admin_roles } from "@/utils/roles"; import { all_admin_roles } from "@/utils/roles";
import { message } from "antd"; import { message } from "antd";
import { TagNewRequest, TagUpdateRequest, TagDeleteRequest, TagInfoRequest, TagListResponse, TagInfoResponse } from "./tag_management/types"; import { TagNewRequest, TagUpdateRequest, TagDeleteRequest, TagInfoRequest, TagListResponse, TagInfoResponse } from "./tag_management/types";
import { Team } from "./key_team_helpers/key_list";
const isLocal = process.env.NODE_ENV === "development"; const isLocal = process.env.NODE_ENV === "development";
export const proxyBaseUrl = isLocal ? "http://localhost:4000" : null; export const proxyBaseUrl = isLocal ? "http://localhost:4000" : null;
@ -2983,7 +2984,7 @@ export const teamUpdateCall = async (
console.error("Error response from the server:", errorData); console.error("Error response from the server:", errorData);
throw new Error("Network response was not ok"); throw new Error("Network response was not ok");
} }
const data = await response.json(); const data = await response.json() as { data: Team, team_id: string };
console.log("Update Team Response:", data); console.log("Update Team Response:", data);
return data; return data;
// Handle success - you might want to update some state or UI based on the created key // Handle success - you might want to update some state or UI based on the created key

View file

@ -30,6 +30,7 @@ import { PencilAltIcon, PlusIcon, TrashIcon } from "@heroicons/react/outline";
import MemberModal from "./edit_membership"; import MemberModal from "./edit_membership";
import UserSearchModal from "@/components/common_components/user_search_modal"; import UserSearchModal from "@/components/common_components/user_search_modal";
import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key"; import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key";
import { Team } from "../key_team_helpers/key_list";
interface TeamData { interface TeamData {
@ -69,6 +70,7 @@ interface TeamInfoProps {
is_proxy_admin: boolean; is_proxy_admin: boolean;
userModels: string[]; userModels: string[];
editTeam: boolean; editTeam: boolean;
onUpdate?: (team: Team) => void
} }
const TeamInfoView: React.FC<TeamInfoProps> = ({ const TeamInfoView: React.FC<TeamInfoProps> = ({
@ -78,7 +80,8 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
is_team_admin, is_team_admin,
is_proxy_admin, is_proxy_admin,
userModels, userModels,
editTeam editTeam,
onUpdate
}) => { }) => {
const [teamData, setTeamData] = useState<TeamData | null>(null); const [teamData, setTeamData] = useState<TeamData | null>(null);
const [loading, setLoading] = useState(true); const [loading, setLoading] = useState(true);
@ -199,7 +202,10 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
}; };
const response = await teamUpdateCall(accessToken, updateData); const response = await teamUpdateCall(accessToken, updateData);
if (onUpdate) {
onUpdate(response.data)
}
message.success("Team settings updated successfully"); message.success("Team settings updated successfully");
setIsEditing(false); setIsEditing(false);
fetchTeamInfo(); fetchTeamInfo();

View file

@ -84,6 +84,7 @@ import {
modelAvailableCall, modelAvailableCall,
teamListCall teamListCall
} from "./networking"; } from "./networking";
import { updateExistingKeys } from "@/utils/dataUtils";
const getOrganizationModels = (organization: Organization | null, userModels: string[]) => { const getOrganizationModels = (organization: Organization | null, userModels: string[]) => {
let tempModelsToPick = []; let tempModelsToPick = [];
@ -321,6 +322,22 @@ const Teams: React.FC<TeamProps> = ({
{selectedTeamId ? ( {selectedTeamId ? (
<TeamInfoView <TeamInfoView
teamId={selectedTeamId} teamId={selectedTeamId}
onUpdate={data => {
setTeams(teams => {
if (teams == null) {
return teams;
}
return teams.map(team => {
if (data.team_id === team.team_id) {
return updateExistingKeys(team, data)
}
return team
})
})
}}
onClose={() => { onClose={() => {
setSelectedTeamId(null); setSelectedTeamId(null);
setEditTeam(false); setEditTeam(false);

View file

@ -418,6 +418,7 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
<div> <div>
<AllKeysTable <AllKeysTable
keys={keys} keys={keys}
setKeys={setKeys}
isLoading={isLoading} isLoading={isLoading}
pagination={pagination} pagination={pagination}
onPageChange={handlePageChange} onPageChange={handlePageChange}

View file

@ -0,0 +1 @@
export type Setter<T> = (newValueOrUpdater: T | ((previousValue: T) => T)) => void

View file

@ -0,0 +1,14 @@
export function updateExistingKeys<Source extends Object>(
target: Source,
source: Object
): Source {
const clonedTarget = structuredClone(target);
for (const [key, value] of Object.entries(source)) {
if (key in clonedTarget) {
(clonedTarget as any)[key] = value;
}
}
return clonedTarget;
}