forked from phoenix/litellm-mirror
style(test_completion.py): fix merge conflict
This commit is contained in:
parent 396d9d8e38
commit dd7e397650
22 changed files with 1535 additions and 250 deletions
@@ -1,9 +1,10 @@
-import os
+import os, types
 import json
 import requests
 import time
-from typing import Callable
+from typing import Callable, Optional
 from litellm.utils import ModelResponse
 import litellm

 class ReplicateError(Exception):
     def __init__(self, status_code, message):
@@ -13,6 +14,65 @@ class ReplicateError(Exception):
             self.message
         ) # Call the base class constructor with the parameters it needs

+class ReplicateConfig():
+    """
+    Reference: https://replicate.com/meta/llama-2-70b-chat/api
+
+    - `prompt` (string): The prompt to send to the model.
+
+    - `system_prompt` (string): The system prompt to send to the model. This is prepended to the prompt and helps guide system behavior. Default value: `You are a helpful assistant`.
+
+    - `max_new_tokens` (integer): Maximum number of tokens to generate. Typically, a word is made up of 2-3 tokens. Default value: `128`.
+
+    - `min_new_tokens` (integer): Minimum number of tokens to generate. To disable, set to `-1`. A word is usually 2-3 tokens. Default value: `-1`.
+
+    - `temperature` (number): Adjusts the randomness of outputs. Values greater than 1 increase randomness, 0 is deterministic, and 0.75 is a reasonable starting value. Default value: `0.75`.
+
+    - `top_p` (number): During text decoding, it samples from the top `p` percentage of most likely tokens. Reduce this to ignore less probable tokens. Default value: `0.9`.
+
+    - `top_k` (integer): During text decoding, samples from the top `k` most likely tokens. Reduce this to ignore less probable tokens. Default value: `50`.
+
+    - `stop_sequences` (string): A comma-separated list of sequences to stop generation at. For example, inputting '<end>,<stop>' will cease generation at the first occurrence of either '<end>' or '<stop>'.
+
+    - `seed` (integer): This is the seed for the random generator. Leave it blank to randomize the seed.
+
+    - `debug` (boolean): If set to `True`, it provides debugging output in logs.
+
+    Please note that Replicate's mapping of these parameters can be inconsistent across different models, indicating that not all of these parameters may be available for use with all models.
+    """
+    system_prompt: Optional[str]=None
+    max_new_tokens: Optional[int]=None
+    min_new_tokens: Optional[int]=None
+    temperature: Optional[int]=None
+    top_p: Optional[int]=None
+    top_k: Optional[int]=None
+    stop_sequences: Optional[str]=None
+    seed: Optional[int]=None
+    debug: Optional[bool]=None
+
+    def __init__(self,
+                 system_prompt: Optional[str]=None,
+                 max_new_tokens: Optional[int]=None,
+                 min_new_tokens: Optional[int]=None,
+                 temperature: Optional[int]=None,
+                 top_p: Optional[int]=None,
+                 top_k: Optional[int]=None,
+                 stop_sequences: Optional[str]=None,
+                 seed: Optional[int]=None,
+                 debug: Optional[bool]=None) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != 'self' and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {k: v for k, v in cls.__dict__.items()
+                if not k.startswith('__')
+                and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
+                and v is not None}
+
+
 # Function to start a prediction and get the prediction URL
 def start_prediction(version_id, input_data, api_token, logging_obj):
     base_url = "https://api.replicate.com/v1"
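Because `__init__` stores every non-`None` argument on the class itself via `setattr(self.__class__, ...)`, constructing a `ReplicateConfig` once sets process-wide defaults that `get_config()` later reads back. A minimal sketch of that flow (not part of this commit; it assumes the class is exposed as `litellm.ReplicateConfig`, which the `completion()` change below relies on):

import litellm

# Values set here become class attributes, so they persist as defaults
# for later Replicate calls in the same process.
litellm.ReplicateConfig(max_new_tokens=256, top_k=50)

# get_config() returns only the non-None, non-dunder, non-callable attributes.
print(litellm.ReplicateConfig.get_config())
# expected: {'max_new_tokens': 256, 'top_k': 50}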
@@ -110,6 +170,13 @@ def completion(
 ):
     # Start a prediction and get the prediction URL
     version_id = model_to_version_id(model)
+
+    ## Load Config
+    config = litellm.ReplicateConfig.get_config()
+    for k, v in config.items():
+        if k not in optional_params: # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
+            optional_params[k] = v
+
     if "meta/llama-2-13b-chat" in model:
         system_prompt = ""
         prompt = ""
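Since the loop above only copies config keys that are not already present in `optional_params`, anything passed directly to `completion()` takes precedence over the `ReplicateConfig` defaults, which is what the `completion(top_k=3) > replicate_config(top_k=3)` comment means. An illustrative call; the model id shown here is made up:

import litellm

litellm.ReplicateConfig(top_k=50)  # library-level default

response = litellm.completion(
    model="replicate/meta/llama-2-70b-chat",  # hypothetical model id
    messages=[{"role": "user", "content": "Hello!"}],
    top_k=3,  # per-call value; overrides the config's top_k=50 for this request
)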