Skip to content

Commit

Permalink
[tests] reset known modules that are patched on each test function end (#2147)
Browse files Browse the repository at this point in the history

* reset known modules that are patched on each test function end

* fix the llama model module name

* prevent unsloth patching multiple times

* pop classes out of the globals after reset

* fix tuple indexing

* manual workaround for llama fa2
  • Loading branch information
winglian authored Dec 7, 2024
1 parent 743ba62 commit 5bef190
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/axolotl/monkeypatch/trainer_grad_accum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
see https://github.com/huggingface/transformers/pull/35128
"""
import inspect
import logging

from accelerate.logging import get_logger
from transformers import LlamaForCausalLM
from transformers.trainer import Trainer

from axolotl.monkeypatch.unsloth_ import detab_code

LOG = get_logger("axolotl.monkeypatch.trainer_grad_accum")
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")

ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager():
Expand Down Expand Up @@ -145,7 +145,7 @@ def patch_training_step_for_ga():
globals(),
)
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching training_step", main_process_only=True)
LOG.info("patching training_step")
Trainer.training_step = ( # pylint: disable=protected-access
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
)
Expand Down Expand Up @@ -201,7 +201,7 @@ def patch_forward_for_ga():
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching forward", main_process_only=True)
LOG.info("patching forward")
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
)
8 changes: 8 additions & 0 deletions src/axolotl/monkeypatch/unsloth_.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,14 @@ def detab_code(code: str) -> Tuple[str, str]:
return code, spaces


self_attn_lora_patched = False # pylint: disable=invalid-name


def patch_self_attn_lora():
global self_attn_lora_patched # pylint: disable=global-statement
if self_attn_lora_patched:
# prevent patching multiple times
return
self_attn_forward = get_self_attn_code()
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
self_attn_forward
Expand Down Expand Up @@ -134,6 +141,7 @@ def patch_self_attn_lora():
globals(),
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
self_attn_lora_patched = True
LOG.info("patching unsloth attn lora", main_process_only=True)
LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
Expand Down
29 changes: 29 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
shared pytest fixtures
"""
import functools
import importlib
import shutil
import sys
import tempfile
import time

Expand Down Expand Up @@ -113,3 +115,30 @@ def temp_dir():
yield _temp_dir
# Clean up the directory after the test
shutil.rmtree(_temp_dir)


@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
    """
    Undo known monkeypatches after every test function.

    Remembers ``LlamaFlashAttention2.forward`` before the test body runs,
    restores it afterwards, and then reloads the transformers modules listed
    in ``modules_to_reset``.
    """
    from transformers.models.llama.modeling_llama import LlamaFlashAttention2

    original_fa2_forward = LlamaFlashAttention2.forward
    # monkey patches can happen inside the tests
    yield
    # Reset LlamaFlashAttention2 forward
    LlamaFlashAttention2.forward = original_fa2_forward

    # Reset other known monkeypatches. Each entry is
    # (module to reload, names to drop from this module's globals afterwards);
    # every entry carries an explicit list so it matches the annotation.
    modules_to_reset: list[tuple[str, list[str]]] = [
        ("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
        ("transformers.trainer", []),
        ("transformers.loss.loss_utils", []),
    ]
    for module_name, stale_globals in modules_to_reset:
        # make sure the module is registered in sys.modules before reloading it
        importlib.import_module(module_name)
        importlib.reload(sys.modules[module_name])
        for stale_global in stale_globals:
            # NOTE(review): this pops from this file's own globals, not from
            # the reloaded module - confirm that is the intended target
            globals().pop(stale_global, None)

0 comments on commit 5bef190

Please sign in to comment.