Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transpose convolution layer to manim_ml [Do Not Merge Yet] #36

Open
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

Lawreros
Copy link
Contributor

Hi Alec,
My first error fix to this project a few weeks ago got me into using manim to illustrate ML concepts, and I started working on adding transpose convolution as a fun side project. This PR is not ready to merge, as I still have a coloring bug to fix, but I wanted to make you aware of my work so I could potentially get some feedback before a serious PR. I tried to keep the TransposeConvolution2DLayer in line with how the rest of your filters are formatted, and it can be used in a similar way to the Convolutional2DLayer. There is documentation in the __init__ function, but a simple example for it's use would be:

NeuralNetwork([
                Convolutional2DLayer(1, 7, filter_spacing=0.32),
                Convolutional2DLayer(2, 5, 3, filter_spacing=0.32),
                TransposeConvolution2DLayer(2, 3, 1, filter_spacing=0.32),
                FeedForwardLayer(3),
            ],
            layer_spacing=0.25,
        )

For a quick version of the code above (to keep file size small) , watch this:
https://user-images.githubusercontent.com/52179159/231038729-c8f12df2-1355-4329-b9c7-5761828e1f24.mp4

I'm going to keep working on this in my free time, but it should be done within the week.

Thanks,
Ross

@helblazer811
Copy link
Owner

helblazer811 commented Apr 11, 2023

That's awesome! Thanks for contributing. I'll look over the code tomorrow hopefully if I have the chance. I'll try to look out for any potential failure cases or idiosyncracies with the library. It is also highly possible that something I have done in the underlying library could be at fault if you have an issue.

@helblazer811
Copy link
Owner

I looked at your video a bit. It looks like you are showing some sort of upsampling followed by a convolution. Wouldn't a basic transposed convolution operation just be an individual cell in an input feature map is mapped to a larger output kernel. Are you visualizing a different kind of transposed convolution than I am thinking?

@Lawreros
Copy link
Contributor Author

Lawreros commented Apr 11, 2023

The transposed convolution I was trying to represent contains the following steps:

  1. The original n-by-n feature map has each 1-by-1 pixel padded by p zeros. So a 5-by-5 feature map with p=1 would result in an 11-by-11 feature map.
  2. A k-by-k kernel is used to convolve the padded image, resulting in an m-by-m feature map if m > n
    This is based off of https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
    because the kernel values are trainable, it's worth displaying that a convolution is happening.

(Though I've been wrong about this kind of thing in the past, lol)

@helblazer811
Copy link
Owner

Reading more about the transposed convolution operation your visualization makes total sense! I think I had an incorrect model of that operation in my head.

@helblazer811
Copy link
Owner

I'm curious, does the design of the code make sense to you? I saw that you added a ConnectiveLayer for Conv2D to TransConv2d. Was that a straightforward process? Were there any parts of the code that didn't make a ton of sense to you?

@Lawreros
Copy link
Contributor Author

Being new to Manim in general I think added a lot of difficulty to understanding how this code worked. Once I walked through a sample scene with a debugger, it was fairly easy to understand. I think more documentation would be a big help in understanding how the connective layers worked. I am a little worried with how each new layer type requires a set of connective layers, but I don't have a great way to fix that. Maybe create a standardized dict that each class has that denotes it as a "category" of label (for example, Conv2d and TransConv2d have similar feature maps), and then some unified connective layer? 🤷

@helblazer811
Copy link
Owner

Yeah, maybe I could make some template connective layers. I think decoupling each of the different connective animations from the layer class is the right move in general though.

@helblazer811
Copy link
Owner

I've started putting together a documentation website https://alechelbling.com/ManimML/.
I haven't added much to it yet though, but eventually I will.

@helblazer811
Copy link
Owner

Any updates on this?

@Lawreros
Copy link
Contributor Author

Lawreros commented Jun 7, 2023

Sorry, yes I haven't forgotten about this, just been distracted by work and papers! Haven't really had time to work on it 😢 , but I should have free time again soon

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants