Add unit tests for model_utils
arnavgarg1 committed Sep 8, 2023
1 parent 7cacb77 commit 0158d03
Showing 2 changed files with 47 additions and 3 deletions.
12 changes: 10 additions & 2 deletions ludwig/utils/llm_utils.py
@@ -4,9 +4,17 @@
import torch
import torch.nn.functional as F
from bitsandbytes.nn.modules import Embedding
from transformers import GPT2Tokenizer, GPT2TokenizerFast, LlamaTokenizer, LlamaTokenizerFast, PreTrainedTokenizer
from transformers import (
    AutoModelForCausalLM,
    GPT2Tokenizer,
    GPT2TokenizerFast,
    LlamaTokenizer,
    LlamaTokenizerFast,
    PreTrainedTokenizer,
)

from ludwig.constants import IGNORE_INDEX_TOKEN_ID, LOGITS, PREDICTIONS, PROBABILITIES
from ludwig.schema.trainer import LLMTrainerConfig
from ludwig.utils.model_utils import find_embedding_layer_with_path

logger = logging.getLogger(__name__)
@@ -383,7 +391,7 @@ def realign_target_and_prediction_tensors_for_inference(
    return targets, predictions


def update_embedding_layer(model, config_obj):
def update_embedding_layer(model: AutoModelForCausalLM, config_obj: LLMTrainerConfig) -> AutoModelForCausalLM:
"""Updates the embedding layer of the model to use the 8-bit embedding layer from bitsandbytes.nn.modules.
This is necessary when using 8-bit optimizers from bitsandbytes.
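The body of update_embedding_layer is truncated in the diff above. As a rough illustration of what the docstring describes (swapping the model's embedding layer for the 8-bit-optimizer-friendly Embedding from bitsandbytes.nn.modules), a minimal sketch might look like the following. swap_in_bnb_embedding is a hypothetical name used only for this example; it is not the Ludwig implementation.

import torch
from bitsandbytes.nn.modules import Embedding

from ludwig.utils.model_utils import find_embedding_layer_with_path


def swap_in_bnb_embedding(model: torch.nn.Module) -> torch.nn.Module:
    """Hypothetical helper: replace the first torch.nn.Embedding with the bitsandbytes Embedding."""
    embedding_layer, path = find_embedding_layer_with_path(model)
    if embedding_layer is None:
        return model  # nothing to replace

    # Build a bitsandbytes Embedding with the same shape and copy the trained weights over.
    new_embedding = Embedding(
        num_embeddings=embedding_layer.num_embeddings,
        embedding_dim=embedding_layer.embedding_dim,
        padding_idx=embedding_layer.padding_idx,
    )
    new_embedding.weight.data = embedding_layer.weight.data.clone()

    # Walk the dotted path returned by find_embedding_layer_with_path (e.g. "model.embed_tokens")
    # to the parent module, then attach the new layer in place of the old one.
    *parent_names, attr_name = path.split(".")
    parent = model
    for name in parent_names:
        parent = getattr(parent, name)
    setattr(parent, attr_name, new_embedding)
    return model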
38 changes: 37 additions & 1 deletion tests/ludwig/utils/test_model_utils.py
@@ -1,6 +1,7 @@
import torch
from transformers import AutoModelForCausalLM

from ludwig.utils.model_utils import extract_tensors, replace_tensors
from ludwig.utils.model_utils import extract_tensors, find_embedding_layer_with_path, replace_tensors

# Define a sample model for testing

@@ -59,3 +60,38 @@ def test_replace_tensors():
    for name, array in tensor_dict["buffers"].items():
        assert name in module._buffers
        assert torch.allclose(module._buffers[name], torch.as_tensor(array, device=device))


# Define a sample module structure for testing
class SampleModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 20)
        self.rnn = torch.nn.LSTM(20, 30)


def test_find_embedding_layer_with_path_simple():
    # Test case 1: a simple module structure with a single top-level embedding layer
    module = SampleModule()
    embedding_layer, path = find_embedding_layer_with_path(module)
    assert embedding_layer is not None
    assert isinstance(embedding_layer, torch.nn.Embedding)
    assert path == "embedding"


def test_find_embedding_layer_with_path_complex():
    # Test case 2: a more complex, nested module structure loaded via AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained("HuggingFaceM4/tiny-random-LlamaForCausalLM")

    embedding_layer, path = find_embedding_layer_with_path(model)
    assert embedding_layer is not None
    assert isinstance(embedding_layer, torch.nn.Embedding)
    assert path == "model.embed_tokens"


def test_no_embedding_layer():
    # Test case 3: no embedding layer is present
    no_embedding_model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))
    embedding_layer, path = find_embedding_layer_with_path(no_embedding_model)
    assert embedding_layer is None
    assert path is None
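For reference, a minimal sketch of a helper with the behavior these tests exercise (return the first torch.nn.Embedding in a module together with its dotted attribute path, or (None, None) when there is none) could look like the code below. It is only an illustration consistent with the assertions above, not the actual find_embedding_layer_with_path from ludwig/utils/model_utils.py.

from typing import Optional, Tuple

import torch


def find_embedding_layer_with_path_sketch(
    module: torch.nn.Module,
) -> Tuple[Optional[torch.nn.Embedding], Optional[str]]:
    """Return the first torch.nn.Embedding found in `module` and its dotted path, else (None, None)."""
    # named_modules() walks the module tree depth-first and yields dotted names such as "model.embed_tokens".
    for name, submodule in module.named_modules():
        if isinstance(submodule, torch.nn.Embedding):
            return submodule, name
    return None, None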
