[tests] reset known modules that are patched on each test function end #2147

Merged · 6 commits · Dec 7, 2024
8 changes: 4 additions & 4 deletions src/axolotl/monkeypatch/trainer_grad_accum.py
@@ -3,14 +3,14 @@
 see https://github.com/huggingface/transformers/pull/35128
 """
 import inspect
+import logging

-from accelerate.logging import get_logger
 from transformers import LlamaForCausalLM
 from transformers.trainer import Trainer

 from axolotl.monkeypatch.unsloth_ import detab_code

-LOG = get_logger("axolotl.monkeypatch.trainer_grad_accum")
+LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager():
@@ -145,7 +145,7 @@ def patch_training_step_for_ga():
         globals(),
     )
     exec(training_step, globals())  # pylint: disable=exec-used  # nosec B102
-    LOG.info("patching training_step", main_process_only=True)
+    LOG.info("patching training_step")
     Trainer.training_step = (  # pylint: disable=protected-access
         _fixed_training_step  # pylint: disable=undefined-variable  # noqa: F821
     )
@@ -201,7 +201,7 @@ def patch_forward_for_ga():
         globals(),
     )
     exec(forward, globals())  # pylint: disable=exec-used  # nosec B102
-    LOG.info("patching forward", main_process_only=True)
+    LOG.info("patching forward")
     LlamaForCausalLM.forward = (  # pylint: disable=protected-access
         _fixed_forward  # pylint: disable=undefined-variable  # noqa: F821
     )
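The first file swaps accelerate's logger for the standard library one, so the main_process_only keyword (understood by accelerate's logger adapter but not by a plain logging.Logger) is dropped from the LOG.info calls as well. A minimal standalone sketch of the stdlib side, using only the logging module, illustrates why the kwarg cannot stay:

# Minimal sketch: a plain logging.Logger rejects the main_process_only
# keyword that accelerate's adapter supports, so the kwarg goes away
# together with the logger swap.
import logging

logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")

LOG.info("patching training_step")  # fine with the stdlib logger
try:
    LOG.info("patching training_step", main_process_only=True)
except TypeError as err:
    print(f"stdlib logger rejects the kwarg: {err}")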
8 changes: 8 additions & 0 deletions src/axolotl/monkeypatch/unsloth_.py
@@ -102,7 +102,14 @@ def detab_code(code: str) -> Tuple[str, str]:
     return code, spaces


+self_attn_lora_patched = False  # pylint: disable=invalid-name
+
+
 def patch_self_attn_lora():
+    global self_attn_lora_patched  # pylint: disable=global-statement
+    if self_attn_lora_patched:
+        # prevent patching multiple times
+        return
     self_attn_forward = get_self_attn_code()
     LlamaFlashAttention2._original_forward = (  # pylint: disable=protected-access
         self_attn_forward
@@ -134,6 +141,7 @@ def patch_self_attn_lora():
         globals(),
     )
     exec(self_attn_forward, globals())  # pylint: disable=exec-used  # nosec B102
+    self_attn_lora_patched = True
     LOG.info("patching unsloth attn lora", main_process_only=True)
     LlamaFlashAttention2.forward = (
         unsloth_attn_forward  # pylint: disable=undefined-variable  # noqa: F821
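The unsloth_.py change adds a module-level flag so patch_self_attn_lora runs at most once per process. A standalone sketch of the same guard pattern, with made-up names (TargetClass, patch_forward) standing in for the real axolotl/transformers classes:

# Illustrative guard pattern: a module-level flag makes the patch function
# a no-op after the first call, mirroring self_attn_lora_patched.
class TargetClass:
    def forward(self):
        return "original"


_patched = False  # module-level guard


def patch_forward():
    global _patched
    if _patched:
        # prevent patching multiple times
        return
    TargetClass._original_forward = TargetClass.forward

    def _patched_forward(self):
        return "patched"

    TargetClass.forward = _patched_forward
    _patched = True


patch_forward()
patch_forward()  # second call returns early, no double wrapping
assert TargetClass().forward() == "patched"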
29 changes: 29 additions & 0 deletions tests/conftest.py
@@ -2,7 +2,9 @@
 shared pytest fixtures
 """
 import functools
+import importlib
 import shutil
+import sys
 import tempfile
 import time

@@ -113,3 +115,30 @@ def temp_dir():
     yield _temp_dir
     # Clean up the directory after the test
     shutil.rmtree(_temp_dir)
+
+
+@pytest.fixture(scope="function", autouse=True)
+def cleanup_monkeypatches():
+    from transformers.models.llama.modeling_llama import LlamaFlashAttention2
+
+    original_fa2_forward = LlamaFlashAttention2.forward
+    # monkey patches can happen inside the tests
+    yield
+    # Reset LlamaFlashAttention2 forward
+    LlamaFlashAttention2.forward = original_fa2_forward
+
+    # Reset other known monkeypatches
+    modules_to_reset: list[tuple[str, list[str]]] = [
+        ("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
+        ("transformers.trainer",),
+        ("transformers.loss.loss_utils",),
+    ]
+    for module_name_tuple in modules_to_reset:
+        module_name = module_name_tuple[0]
+        module = importlib.import_module(module_name)
+        sys.modules[module_name] = module
+        importlib.reload(sys.modules[module_name])
+        if len(module_name_tuple) > 1:
+            module_globals = module_name_tuple[1]
+            for module_global in module_globals:
+                globals().pop(module_global, None)
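The new autouse fixture restores LlamaFlashAttention2.forward directly and reloads the other known-patched modules after every test. A hypothetical pair of tests sketches what the reload buys, assuming cleanup_monkeypatches runs between them; note that only references imported after the reload see the clean module, since objects captured before a patch keep pointing at the old class object:

# Hypothetical test file illustrating the fixture's effect; assumes the
# autouse cleanup_monkeypatches fixture above runs after each test.
def test_applies_patch():
    from transformers.trainer import Trainer

    def fake_training_step(self, *args, **kwargs):  # illustrative replacement
        return 0.0

    Trainer.training_step = fake_training_step  # in-place monkeypatch
    assert Trainer.training_step is fake_training_step


def test_next_test_gets_clean_module():
    # Imported after cleanup_monkeypatches reloaded transformers.trainer,
    # so this Trainer comes from the re-executed module and is unpatched.
    from transformers.trainer import Trainer

    assert Trainer.training_step.__name__ == "training_step"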