Support QuestionAnswering Module for ModernBert based models. #35566

Open: wants to merge 17 commits into base: main

@bakrianoo bakrianoo commented Jan 8, 2025

What does this PR do?

This PR introduces a new feature: ModernBertForQuestionAnswering, which extends the ModernBERT model to support Question-Answering tasks.

This addresses a gap in the current implementation of ModernBERT and enables its use for tasks requiring answer extraction from long texts.
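To illustrate what "answer extraction" means here, the following is a minimal sketch of the usual post-processing step for an extractive QA head. It assumes the new class follows the common pattern of returning one start logit and one end logit per token; the function name `best_answer_span`, the toy tokens, and the logit values are hypothetical and are not taken from this PR.

```python
from typing import List, Tuple


def best_answer_span(
    start_logits: List[float],
    end_logits: List[float],
    max_answer_len: int = 30,
) -> Tuple[int, int, float]:
    """Return (start, end, score) of the highest-scoring valid span.

    A span is valid when start <= end and it covers at most
    `max_answer_len` tokens. The score is start_logit + end_logit.
    """
    best_start, best_end = 0, 0
    best_score = float("-inf")
    for start, start_score in enumerate(start_logits):
        # Only consider end positions at or after the start position.
        last_end = min(start + max_answer_len, len(end_logits))
        for end in range(start, last_end):
            score = start_score + end_logits[end]
            if score > best_score:
                best_start, best_end, best_score = start, end, score
    return best_start, best_end, best_score


# Toy example with made-up logits (not real model output).
tokens = ["[CLS]", "who", "wrote", "it", "[SEP]",
          "jane", "austen", "wrote", "it", "[SEP]"]
start_logits = [0.1, 0.0, 0.0, 0.0, 0.0, 4.0, 1.0, 0.2, 0.1, 0.0]
end_logits = [0.1, 0.0, 0.0, 0.0, 0.0, 1.0, 5.0, 0.3, 0.1, 0.0]

start, end, score = best_answer_span(start_logits, end_logits)
print(" ".join(tokens[start:end + 1]))  # prints: jane austen
```

The exhaustive search is quadratic in sequence length, which is fine for a sketch; real pipelines typically restrict the search to the top-k start and end candidates.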

Evaluating

I ran an experiment fine-tuning a ModernBERT-based QnA model: I used this script to produce this model.

Dependencies

  • No new dependencies were introduced.
  • The model was tested using the existing Hugging Face transformers library. A fine-tuned QnA model is available here: rankyx/ModernBERT-QnA-base-squad

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Current Status

Although I was able to fine-tune a QnA ModernBERT model, I am still having an issue completing the PR.

1) Running make fixup produces this error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 430, in discover_pjrt_plugins
    plugin_module.initialize()
  File "/usr/local/lib/python3.10/dist-packages/jax_plugins/xla_cuda12/__init__.py", line 85, in initialize
    options = xla_client.generate_pjrt_gpu_plugin_options()
AttributeError: module 'jaxlib.xla_client' has no attribute 'generate_pjrt_gpu_plugin_options'
Traceback (most recent call last):
  File "/content/transformers/utils/check_docstrings.py", line 1060, in <module>
    check_docstrings(overwrite=args.fix_and_overwrite, check_all=args.check_all)
  File "/content/transformers/utils/check_docstrings.py", line 978, in check_docstrings
    for modified_file_diff in repo.index.diff(repo.refs.main.commit):
  File "/usr/local/lib/python3.10/dist-packages/git/util.py", line 972, in __getattr__
    return list.__getattribute__(self, attr)
AttributeError: 'IterableList' object has no attribute 'main'

I tried reinstalling jaxlib==0.4.13, but it did not help.

I would appreciate your help here.
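The second traceback above fails at repo.refs.main, which GitPython resolves by looking for a reference literally named main. A likely cause (an assumption, not confirmed in this thread) is that the checkout has no local main branch, for example after cloning only the PR branch. The sketch below checks for that and suggests a fetch command; the helper names and the default remote name origin are hypothetical, and the jax plugin error in the first traceback is a separate matter this does not address.

```python
import subprocess
from typing import List, Optional


def list_local_branches(repo_path: str = ".") -> List[str]:
    """Return the short names of all local branches in the repository."""
    result = subprocess.run(
        ["git", "for-each-ref", "--format=%(refname:short)", "refs/heads"],
        cwd=repo_path,
        capture_output=True,
        text=True,
        check=True,
    )
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


def suggest_fix(
    branches: List[str],
    required: str = "main",
    remote: str = "origin",  # assumption: use "upstream" if that is your remote
) -> Optional[str]:
    """Return a git command that creates the missing branch, or None if present."""
    if required in branches:
        return None
    # `git fetch <remote> <src>:<dst>` creates or updates the local branch <dst>.
    return f"git fetch {remote} {required}:{required}"
```

Calling suggest_fix(list_local_branches("/content/transformers")) returns None when a local main exists, and otherwise returns the fetch command to run before retrying make fixup.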

@Rocketknight1 (Member) commented:

Hi @bakrianoo, thanks for the PR!

  1. I think the reason you can't load the model with AutoModelForQuestionAnswering is that you've added an extra line in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES. Can you make sure the model is only included in the question answering mapping? Once it's in there, I believe it should be loadable with the autoclass.

  2. The repo-consistency error is because the new class isn't added to docs/source/en/model_doc/modernbert.md. If you add an autodoc section there, the error should go away!
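Point 1 can be pictured with a simplified stand-in for the name mappings. In transformers these mappings pair a model type with a class name; the toy dictionaries and the helpers `resolve_class_name` and `mappings_containing` below are hypothetical and only mimic that lookup, they are not the library's actual resolution code.

```python
from collections import OrderedDict
from typing import Dict, List

# Toy versions of the per-task name mappings (model_type -> class name).
QUESTION_ANSWERING_NAMES = OrderedDict(
    [
        ("bert", "BertForQuestionAnswering"),
        ("modernbert", "ModernBertForQuestionAnswering"),
    ]
)
TOKEN_CLASSIFICATION_NAMES = OrderedDict(
    [
        ("bert", "BertForTokenClassification"),
        ("modernbert", "ModernBertForTokenClassification"),
    ]
)


def resolve_class_name(model_type: str, mapping: Dict[str, str]) -> str:
    """Look up the class registered for `model_type` in one task mapping."""
    if model_type not in mapping:
        raise ValueError(f"{model_type!r} is not registered for this task")
    return mapping[model_type]


def mappings_containing(
    class_name: str, mappings: Dict[str, Dict[str, str]]
) -> List[str]:
    """Return the names of every task mapping that lists `class_name`."""
    return [
        task for task, mapping in mappings.items()
        if class_name in mapping.values()
    ]


all_mappings = {
    "question-answering": QUESTION_ANSWERING_NAMES,
    "token-classification": TOKEN_CLASSIFICATION_NAMES,
}
# The QA class should appear in exactly one mapping.
print(mappings_containing("ModernBertForQuestionAnswering", all_mappings))
# prints: ['question-answering']
```

A check like `mappings_containing` makes a stray registration visible: if the QA class also showed up under token classification, the returned list would have two entries.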

@bakrianoo bakrianoo requested a review from stevhliu as a code owner January 9, 2025 09:33
@bakrianoo (Author) commented Jan 9, 2025

Hi @Rocketknight1
Your advice solved my issues. Thank you.
I now have a new issue, can you check the updated PR description?
