Add support for module input of type list #56

Open
jacksonsc007 opened this issue Dec 25, 2024 · 2 comments

jacksonsc007 commented Dec 25, 2024

Wonderful project! I tried some baseline models and they worked well.
However, it doesn't seem to support modules that take a single input argument of type list. For instance:

    def forward(self, feats: list[Tensor]):
        assert len(feats) == len(self.in_channels)
        proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
        ...

If I pass this model to torchexplorer, I get the following error:

      File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/hook/hook.py", line 329, in process_tensor
        return tensor + dummy_tensor if torch.is_floating_point(tensor) else tensor
    TypeError: is_floating_point(): argument 'input' (position 1) must be Tensor, not list

I tried to modify the hooks, but it did not work out.

Any suggestions? @spfrommer
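
For reference, the failing line in the traceback applies `torch.is_floating_point` directly to whatever argument the hook receives, so a `list` reaches it unchanged. Below is a minimal sketch of a recursive variant, assuming the hook only needs to offset floating-point tensors and can pass everything else through. `process_nested` is a hypothetical name, not torchexplorer's API, and the library's graph tracking may also assume a flat set of tensor inputs, so patching this one line is probably not sufficient on its own (consistent with the hook modifications not working out).

```python
import torch
from torch import Tensor


def process_nested(obj, dummy_tensor: Tensor):
    """Hypothetical helper: add `dummy_tensor` to every floating-point tensor
    inside `obj`, recursing into lists and tuples; other values pass through."""
    if isinstance(obj, Tensor):
        # Same rule as the failing line: only floating-point tensors are offset.
        return obj + dummy_tensor if torch.is_floating_point(obj) else obj
    if isinstance(obj, list):
        return [process_nested(item, dummy_tensor) for item in obj]
    if isinstance(obj, tuple):
        return tuple(process_nested(item, dummy_tensor) for item in obj)
    # Non-tensor, non-container values (ints, None, ...) are returned as-is.
    return obj


feats = [torch.ones(2), torch.tensor([1, 2])]
out = process_nested(feats, torch.tensor(0.5))
print(out)  # [tensor([1.5000, 1.5000]), tensor([1, 2])]
```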

jacksonsc007 (Author) commented Dec 25, 2024

By the way, if I split the list into separate arguments:

    def forward(self, feat1, feat2, feat3):
        feats = [feat1, feat2, feat3]
        proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
        ...

I get the following assertion error:

      File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 201, in <listcomp>
        self._inner_recurse(current_struct, n[0], n[1]) for n in next_functions
      File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 200, in _inner_recurse
        upstreams = _flatten([
      File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 201, in <listcomp>
        self._inner_recurse(current_struct, n[0], n[1]) for n in next_functions
      File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 200, in _inner_recurse
        upstreams = _flatten([
      File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 201, in <listcomp>
        self._inner_recurse(current_struct, n[0], n[1]) for n in next_functions
      File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 200, in _inner_recurse
        upstreams = _flatten([
      File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 201, in <listcomp>
        self._inner_recurse(current_struct, n[0], n[1]) for n in next_functions
      File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 166, in _inner_recurse
        assert metadata['module'] == current_module
    AssertionError
    ^CException ignored in: <module 'threading' from '/root/miniconda3/envs/rtdetr/lib/python3.10/threading.py'>
    Traceback (most recent call last):
      File "/root/miniconda3/envs/rtdetr/lib/python3.10/threading.py", line 1567, in _shutdown
        lock.acquire()
    KeyboardInterrupt:

P.S. This error disappears when I remove the transformer encoder.

spfrommer (Owner) commented:

Modules that accept a list as an argument aren't currently supported; right now, the input and output tensors are expected to be consistent in count and shape.
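
Given that constraint, one possible workaround (an untested assumption, not something suggested in this thread) is a variadic signature: the module still accepts any number of feature maps, but forward hooks receive a flat tuple of tensors instead of a single `list`. The class name `VariadicProj`, the channel sizes, and the 1x1 convolutions are illustrative only. It returns a tuple, and whether torchexplorer handles multiple outputs here is not confirmed.

```python
import torch
from torch import Tensor, nn


class VariadicProj(nn.Module):
    """Hypothetical module: projects each input feature map to a common width."""

    def __init__(self, in_channels: list[int], hidden_dim: int = 32):
        super().__init__()
        self.in_channels = in_channels
        # One 1x1 convolution per input level (illustrative choice).
        self.input_proj = nn.ModuleList(
            [nn.Conv2d(c, hidden_dim, kernel_size=1) for c in in_channels]
        )

    def forward(self, *feats: Tensor) -> tuple[Tensor, ...]:
        # Each feature map arrives as its own positional argument, so the
        # module's inputs form a tuple of tensors rather than one list.
        assert len(feats) == len(self.in_channels)
        return tuple(proj(feat) for proj, feat in zip(self.input_proj, feats))


model = VariadicProj([8, 16, 32])
# Spatial sizes 16, 8, 4 for the three levels.
feats = [torch.randn(1, c, 16 // 2**i, 16 // 2**i) for i, c in enumerate([8, 16, 32])]
outs = model(*feats)  # unpack the list at the call site
print([tuple(o.shape) for o in outs])  # [(1, 32, 16, 16), (1, 32, 8, 8), (1, 32, 4, 4)]
```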

The second example you gave should work, though. Would you mind making an MWE (minimal working example)?
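
A sketch of what such an MWE might look like, combining the split-input projections from the second example with a transformer encoder (the component whose removal made the assertion disappear). Everything here is hypothetical: the class name `SplitInputEncoder`, channel sizes, a single `nn.TransformerEncoderLayer` on the coarsest level, and the scalar-per-sample output are stand-ins, not the original model. The torchexplorer call is left out on purpose; it would be attached to `model` the same way as in the original run, before the forward/backward pass.

```python
import torch
from torch import Tensor, nn


class SplitInputEncoder(nn.Module):
    """Hypothetical reproduction: three separately passed feature maps, a 1x1
    projection for each, and a transformer encoder layer on the last level."""

    def __init__(self, in_channels=(8, 16, 32), hidden_dim=32):
        super().__init__()
        self.input_proj = nn.ModuleList(
            [nn.Conv2d(c, hidden_dim, kernel_size=1) for c in in_channels]
        )
        self.encoder = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=4, batch_first=True)

    def forward(self, feat1: Tensor, feat2: Tensor, feat3: Tensor) -> Tensor:
        feats = [feat1, feat2, feat3]
        proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
        # Flatten the last level to (batch, tokens, channels) for the encoder.
        b, c, h, w = proj_feats[-1].shape
        tokens = proj_feats[-1].flatten(2).permute(0, 2, 1)
        encoded = self.encoder(tokens)
        proj_feats[-1] = encoded.permute(0, 2, 1).reshape(b, c, h, w)
        # Reduce every level to one value per sample so backward() can run.
        return sum(f.mean(dim=(1, 2, 3)) for f in proj_feats)


model = SplitInputEncoder()
feat1, feat2, feat3 = (torch.randn(2, c, s, s) for c, s in [(8, 16), (16, 8), (32, 4)])
out = model(feat1, feat2, feat3)
out.sum().backward()
print(out.shape)  # torch.Size([2])
```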
