More fixes, and record from the server, since the server enables more tests

Ashwin Bharambe 2025-07-29 12:14:55 -07:00
parent 6ebc93de81
commit 9b3a860beb
29 changed files with 2584 additions and 762 deletions


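The substance of this commit is threading *args through every patched wrapper. A wrapper that accepts only **kwargs raises TypeError the moment a caller passes a positional argument; forwarding *args makes the wrappers transparent to both call styles, which matters once server-driven tests exercise more call paths. A minimal sketch of the failure and the fix (the names here are illustrative, not the real client methods):

    import asyncio

    async def original(self, model, **kwargs):
        return f"called with {model}"

    # **kwargs-only wrapper: breaks on positional calls
    async def patched_kwargs_only(self, **kwargs):
        return await original(self, **kwargs)

    # *args-forwarding wrapper: handles both call styles
    async def patched_forwarding(self, *args, **kwargs):
        return await original(self, *args, **kwargs)

    class Client:
        pass

    async def main():
        c = Client()
        print(await patched_forwarding(c, "llama3"))        # positional: ok
        print(await patched_forwarding(c, model="llama3"))  # keyword: ok
        try:
            await patched_kwargs_only(c, "llama3")  # TypeError at call time
        except TypeError as e:
            print(f"TypeError: {e}")

    asyncio.run(main())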
@@ -200,12 +200,12 @@ class ResponseStorage:
         return cast(dict[str, Any], data)
 
 
-async def _patched_inference_method(original_method, self, client_type, method_name=None, **kwargs):
+async def _patched_inference_method(original_method, self, client_type, method_name=None, *args, **kwargs):
     global _current_mode, _current_storage
 
     if _current_mode == "live" or _current_storage is None:
         # Normal operation
-        return await original_method(self, **kwargs)
+        return await original_method(self, *args, **kwargs)
 
     # Get base URL and endpoint based on client type
     if client_type == "openai":
@@ -284,7 +284,7 @@ async def _patched_inference_method(original_method, self, client_type, method_n
         )
     elif _current_mode == "record":
-        response = await original_method(self, **kwargs)
+        response = await original_method(self, *args, **kwargs)
 
         request_data = {
             "method": method,
@@ -321,7 +321,7 @@ async def _patched_inference_method(original_method, self, client_type, method_n
         return response
     else:
-        return await original_method(self, **kwargs)
+        raise AssertionError(f"Invalid mode: {_current_mode}")
 
 
 def patch_inference_clients():
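A behavioral fix rides along here: the final branch used to fall through to a live original_method call, so a typo'd or unset mode could silently hit the network; it now fails loudly. The dispatch reduced to its skeleton (the mode names appear in the diff, the rest is illustrative):

    async def dispatch(original_method, self, *args, **kwargs):
        if _current_mode == "live" or _current_storage is None:
            return await original_method(self, *args, **kwargs)
        elif _current_mode == "replay":
            ...  # look up the recorded response and return it
        elif _current_mode == "record":
            response = await original_method(self, *args, **kwargs)
            ...  # persist the request/response pair
            return response
        raise AssertionError(f"Invalid mode: {_current_mode}")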
@@ -347,14 +347,16 @@ def patch_inference_clients():
     }
 
     # Create patched methods for OpenAI client
-    async def patched_chat_completions_create(self, **kwargs):
-        return await _patched_inference_method(_original_methods["chat_completions_create"], self, "openai", **kwargs)
+    async def patched_chat_completions_create(self, *args, **kwargs):
+        return await _patched_inference_method(
+            _original_methods["chat_completions_create"], self, "openai", *args, **kwargs
+        )
 
-    async def patched_completions_create(self, **kwargs):
-        return await _patched_inference_method(_original_methods["completions_create"], self, "openai", **kwargs)
+    async def patched_completions_create(self, *args, **kwargs):
+        return await _patched_inference_method(_original_methods["completions_create"], self, "openai", *args, **kwargs)
 
-    async def patched_embeddings_create(self, **kwargs):
-        return await _patched_inference_method(_original_methods["embeddings_create"], self, "openai", **kwargs)
+    async def patched_embeddings_create(self, *args, **kwargs):
+        return await _patched_inference_method(_original_methods["embeddings_create"], self, "openai", *args, **kwargs)
 
     # Apply OpenAI patches
     AsyncChatCompletions.create = patched_chat_completions_create
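The patch itself is plain class-attribute monkey-patching: reassigning AsyncChatCompletions.create swaps the method for every instance, and because the wrapper is an ordinary function assigned on the class, it receives the instance as self. In miniature:

    class Greeter:
        def greet(self):
            return "hello"

    _original_greet = Greeter.greet

    def patched_greet(self):
        return _original_greet(self).upper()

    Greeter.greet = patched_greet
    print(Greeter().greet())  # HELLO

    Greeter.greet = _original_greet  # unpatching restores the original
    print(Greeter().greet())  # hello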
@@ -362,27 +364,33 @@ def patch_inference_clients():
     AsyncEmbeddings.create = patched_embeddings_create
 
     # Create patched methods for Ollama client
-    async def patched_ollama_generate(self, **kwargs):
+    async def patched_ollama_generate(self, *args, **kwargs):
         return await _patched_inference_method(
-            _original_methods["ollama_generate"], self, "ollama", "generate", **kwargs
+            _original_methods["ollama_generate"], self, "ollama", "generate", *args, **kwargs
         )
 
-    async def patched_ollama_chat(self, **kwargs):
-        return await _patched_inference_method(_original_methods["ollama_chat"], self, "ollama", "chat", **kwargs)
+    async def patched_ollama_chat(self, *args, **kwargs):
+        return await _patched_inference_method(
+            _original_methods["ollama_chat"], self, "ollama", "chat", *args, **kwargs
+        )
 
-    async def patched_ollama_embed(self, **kwargs):
-        return await _patched_inference_method(_original_methods["ollama_embed"], self, "ollama", "embed", **kwargs)
+    async def patched_ollama_embed(self, *args, **kwargs):
+        return await _patched_inference_method(
+            _original_methods["ollama_embed"], self, "ollama", "embed", *args, **kwargs
+        )
 
-    async def patched_ollama_ps(self, **kwargs):
-        logger.info("replay mode: ollama.ps() reporting success")
-        return []
+    async def patched_ollama_ps(self, *args, **kwargs):
+        return await _patched_inference_method(_original_methods["ollama_ps"], self, "ollama", "ps", *args, **kwargs)
 
     async def patched_ollama_pull(self, *args, **kwargs):
-        logger.info("replay mode: ollama.pull() not actually pulling the model")
-        return None
+        return await _patched_inference_method(
+            _original_methods["ollama_pull"], self, "ollama", "pull", *args, **kwargs
+        )
 
-    async def patched_ollama_list(self, **kwargs):
-        return await _patched_inference_method(_original_methods["ollama_list"], self, "ollama", "list", **kwargs)
+    async def patched_ollama_list(self, *args, **kwargs):
+        return await _patched_inference_method(
+            _original_methods["ollama_list"], self, "ollama", "list", *args, **kwargs
+        )
 
     # Apply Ollama patches
     OllamaAsyncClient.generate = patched_ollama_generate
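ps() and pull() were previously hardcoded replay stubs (returning [] and None), so their real responses were never captured; routing them through _patched_inference_method means server-driven runs record them like any other endpoint. A toy in-memory version of the record/replay round trip they now participate in (the real ResponseStorage is more involved; this dict-backed store is illustration only):

    _recordings: dict[str, object] = {}

    async def record_or_replay(key, call, mode):
        if mode == "record":
            response = await call()       # hit the real backend
            _recordings[key] = response   # capture for later replay
            return response
        if mode == "replay":
            return _recordings[key]       # KeyError -> nothing recorded
        return await call()               # live: pass through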
@@ -416,6 +424,7 @@ def unpatch_inference_clients():
     OllamaAsyncClient.chat = _original_methods["ollama_chat"]
     OllamaAsyncClient.embed = _original_methods["ollama_embed"]
     OllamaAsyncClient.ps = _original_methods["ollama_ps"]
+    OllamaAsyncClient.pull = _original_methods["ollama_pull"]
     OllamaAsyncClient.list = _original_methods["ollama_list"]
 
     _original_methods.clear()
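With pull now patched, teardown has to restore it as well or the wrapper would leak into later tests; the added line keeps unpatch_inference_clients() symmetric with patch_inference_clients(). A typical guard around the pair (a sketch; the fixture name and import path are hypothetical):

    import pytest

    # patch_inference_clients / unpatch_inference_clients come from the
    # recording module in this diff; the import path varies by layout.

    @pytest.fixture
    def recorded_inference():
        patch_inference_clients()
        try:
            yield
        finally:
            # restores create/generate/chat/embed/ps/pull/list
            unpatch_inference_clients()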