SE3 Transformer - Pytorch

Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. May be needed for replicating Alphafold2 results and other drug discovery applications.

Install

$ pip install se3-transformer-pytorch

Usage

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 512,
    heads = 8,
    depth = 6,
    dim_head = 64,
    num_degrees = 4,
    valid_radius = 10
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
mask  = torch.ones(1, 1024).bool()

out = model(feats, coors, mask) # (1, 1024, 512)

Potential example usage in Alphafold2, as outlined here

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True
)

atom_feats = torch.randn(2, 32, 64)
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refinement = model(atom_feats, coors, mask, return_type = 1) # (2, 32, 3)

You can also let the base transformer class embed the type 0 features for you. For example, assuming they are atom types, pass num_tokens and feed integer token indices:

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,       # 28 unique atoms
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refinement = model(atoms, coors, mask, return_type = 1) # (2, 32, 3)

If you think the net could further benefit from positional encoding, you can featurize your positions in space and pass them in as follows.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 2,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True  # reduce out the final dimension
)

atom_feats  = torch.randn(2, 32, 64, 1) # b x n x d x type0
coors_feats = torch.randn(2, 32, 64, 3) # b x n x d x type1

# atom features are type 0, predicted coordinates are type 1
features = {'0': atom_feats, '1': coors_feats}
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refinement = model(features, coors, mask, return_type = 1) # (2, 32, 3) - equivariant to input type 1 features and coordinates
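The shapes in this example follow the usual convention for SE(3)-equivariant features: a degree-l (type-l) feature has 2l + 1 components, which is why type 0 ends in 1 and type 1 ends in 3. Below is a small shape checker that can catch mismatches before the model is called. degree_dim and check_feature_shapes are hypothetical names, not library functions. They work on plain shape tuples such as tensor.shape.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]


def degree_dim(degree: int) -> int:
    """Number of components of a degree-l (type-l) feature: 2l + 1."""
    if degree < 0:
        raise ValueError("degree must be non-negative")
    return 2 * degree + 1


def check_feature_shapes(features: Dict[str, Shape]) -> Tuple[int, ...]:
    """Validate a {'degree': (b, n, d, 2l+1)} dict of shapes and return the shared (b, n)."""
    batch_nodes: Optional[Tuple[int, ...]] = None
    for key, shape in features.items():
        if len(shape) != 4:
            raise ValueError(f"degree {key}: expected (b, n, d, m), got {shape}")
        expected = degree_dim(int(key))
        if shape[-1] != expected:
            raise ValueError(f"degree {key}: last dim should be {expected}, got {shape[-1]}")
        if batch_nodes is None:
            batch_nodes = tuple(shape[:2])
        elif tuple(shape[:2]) != batch_nodes:
            raise ValueError("all degrees must share the batch and node dimensions")
    if batch_nodes is None:
        raise ValueError("no features given")
    return batch_nodes
```

For the example above, check_feature_shapes({'0': atom_feats.shape, '1': coors_feats.shape}) accepts the layout. A type 1 tensor whose last dimension is not 3 is rejected.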

Edges

To give SE3 Transformers edge information (say, bond types between atoms), pass two more keyword arguments on initialization, num_edge_tokens and edge_dim. Then supply the edge tensor through the edges keyword in the forward call.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,
    dim = 64,
    num_edge_tokens = 4,       # number of edge types, say 4 bond types
    edge_dim = 16,             # dimension of edge embedding
    depth = 2,
    input_degrees = 1,
    num_degrees = 3,
    output_degrees = 1,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
bonds = torch.randint(0, 4, (2, 32, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

pred = model(atoms, coors, mask, edges = bonds, return_type = 0) # (2, 32, 1)
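The example above feeds random bond indices. In practice you would build the (b, n, n) integer tensor from a bond list. Below is a minimal sketch for a single molecule. bond_matrix is a hypothetical helper. It assumes a convention in which index 0 means "no bond", so num_edge_tokens must also count that extra token.

```python
from typing import Iterable, List, Tuple


def bond_matrix(num_atoms: int, bonds: Iterable[Tuple[int, int, int]]) -> List[List[int]]:
    """Build a symmetric (n, n) matrix of bond-type indices for one molecule.

    bonds holds (i, j, bond_type) triples. Under the assumed convention,
    0 is reserved for "no bond", so real bond types must be >= 1.
    """
    m = [[0] * num_atoms for _ in range(num_atoms)]
    for i, j, bond_type in bonds:
        if bond_type <= 0:
            raise ValueError("bond types must be >= 1 (0 means no bond)")
        # Bonds are undirected, so fill both (i, j) and (j, i).
        m[i][j] = bond_type
        m[j][i] = bond_type
    return m
```

Stacking one such matrix per molecule and wrapping the result in torch.tensor gives the (b, n, n) edges argument used above.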

Caching

By default, the basis vectors are cached. If you ever need to clear the cache, set the environment variable CLEAR_CACHE to some value when launching the script:

$ CLEAR_CACHE=1 python train.py

Or you can try deleting the cache directory, which should exist at

$ rm -rf ~/.cache.equivariant_attention
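Both options above can be wrapped in a small script-side helper. This is a hypothetical sketch, not code from the library. It assumes the cache directory shown above and treats any non-empty CLEAR_CACHE value as a request to clear it. The library's own check may differ.

```python
import os
import shutil
from pathlib import Path

# Assumed location, taken from the rm -rf command above.
DEFAULT_CACHE_DIR = Path.home() / ".cache.equivariant_attention"


def maybe_clear_cache(cache_dir: Path = DEFAULT_CACHE_DIR, env_var: str = "CLEAR_CACHE") -> bool:
    """Delete cache_dir if env_var is set to any non-empty value.

    Returns True if a directory was removed. Hypothetical helper, not part of
    se3_transformer_pytorch.
    """
    if os.environ.get(env_var) and cache_dir.is_dir():
        shutil.rmtree(cache_dir)
        return True
    return False
```

Calling maybe_clear_cache() at the top of train.py, before the model is built, reproduces the CLEAR_CACHE=1 workflow from inside the script.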

Testing

$ python setup.py pytest

Credit

This library is largely a port of Fabian's official repository, but without the DGL library.

Citations

@misc{fuchs2020se3transformers,
    title   = {SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks}, 
    author  = {Fabian B. Fuchs and Daniel E. Worrall and Volker Fischer and Max Welling},
    year    = {2020},
    eprint  = {2006.10503},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Breaking equivariance

    Hi, Thanks a lot for your work!!

    I was running some equivariance tests on your implementation of the SE3-Transformer and found that equivariance is not always preserved. It does not break every time, and unfortunately I do not know where the bug is.

    I have appended an image with an example of equivariance not being preserved.

    To generate this image I used the following code:

        import numpy as np
        import matplotlib.pyplot as plt
        import torch
        
        from se3_transformer_pytorch import SE3Transformer
    

    Make some data

        zline = np.arange(0, 2, 0.05)
        xline = np.sin(zline * 2 * np.pi) 
        yline = np.cos(zline * 2 * np.pi)
        points = np.array([xline, yline, zline])
        geom = torch.tensor(points.transpose())[None,:].float()
        feat = torch.randint(0, 20, (1, geom.shape[1],1)).float()
        
        def rot_matrix(x):
            # Rotation matrix
            a ,b ,c = 2*np.pi*x
            return np.array([
                [np.cos(a)*np.cos(b), np.cos(a)*np.sin(b)*np.sin(c)- np.sin(a)*np.cos(c), np.cos(a)*np.sin(b)*np.cos(c)+ np.sin(a)*np.sin(c)],
                [np.sin(a)*np.cos(b), np.sin(a)*np.sin(b)*np.sin(c)+ np.cos(a)*np.cos(c), np.sin(a)*np.sin(b)*np.cos(c)- np.cos(a)*np.sin(c)],
                [-np.sin(b)         , np.cos(b)*np.sin(c)                               , np.cos(b)*np.cos(c)                               ]
            ])
    

    Initialize model

        mdl = SE3Transformer(
            dim = 1,
            depth = 3,
            input_degrees = 1,
            num_degrees = 2,
            output_degrees = 2,
            reduce_dim_out = True,
        )
        
        def model(geom,feat):
            return geom + mdl(feat,geom, return_type = 1)
    

    Check Rotation Equivariance:

        with torch.no_grad():
            
            Q = torch.tensor(rot_matrix(np.random.random(3))).float()
            prerotated = model(geom @ Q, feat).squeeze().detach().numpy().transpose()
            posrotated = (model(geom, feat) @ Q).squeeze().detach().numpy().transpose()
        
            fig = plt.figure(dpi = 200)
            ax = plt.axes(projection="3d")
        
            ax.plot3D(prerotated[0], prerotated[1], prerotated[2], "r", linewidth=1.1)
            ax.plot3D(posrotated[0], posrotated[1], posrotated[2], "b", linewidth=0.5)
        
            plt.legend(["Pre-Rotated", "Post-Rotated"])
            plt.show()
    

    Check Translation Equivariance:

        with torch.no_grad():
            
            x0 = 1*torch.rand(3)
            pretranslated = model(geom + x0, feat).squeeze().detach().numpy().transpose()
            posttranslated = (model(geom, feat) + x0).squeeze().detach().numpy().transpose()
        
            fig = plt.figure(dpi = 200)
            ax = plt.axes(projection="3d")
        
            ax.plot3D(pretranslated[0], pretranslated[1], pretranslated[2], "r", linewidth=1.1)
            ax.plot3D(posttranslated[0], posttranslated[1], posttranslated[2], "b", linewidth=0.5)
        
            plt.legend(["Pre-Translated", "Post-Translated"])
            plt.show()
    

    Hope this helps!!
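Instead of comparing the two plots by eye, the same checks can be made numeric by reporting the largest coordinate difference. Below is a minimal pure-Python sketch. toy_update is a hypothetical stand-in for the model wrapper defined above. In practice you would call model(...) on tensors and convert the outputs to lists, or compute the same maximum with torch. Expect small nonzero differences from floating-point rounding, so compare against a tolerance rather than zero.

```python
import math
from typing import Callable, List, Sequence

Matrix = List[List[float]]


def matmul(a: Sequence[Sequence[float]], b: Sequence[Sequence[float]]) -> Matrix:
    """Plain-Python matrix product a @ b."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def rotation(a: float, b: float, c: float) -> Matrix:
    """Rz(a) @ Ry(b) @ Rx(c): the same product as rot_matrix in the issue, angles in radians."""
    rz = [[math.cos(a), -math.sin(a), 0.0], [math.sin(a), math.cos(a), 0.0], [0.0, 0.0, 1.0]]
    ry = [[math.cos(b), 0.0, math.sin(b)], [0.0, 1.0, 0.0], [-math.sin(b), 0.0, math.cos(b)]]
    rx = [[1.0, 0.0, 0.0], [0.0, math.cos(c), -math.sin(c)], [0.0, math.sin(c), math.cos(c)]]
    return matmul(matmul(rz, ry), rx)


def max_abs_diff(x: Matrix, y: Matrix) -> float:
    return max(abs(p - q) for row_x, row_y in zip(x, y) for p, q in zip(row_x, row_y))


def rotation_error(f: Callable[[Matrix], Matrix], points: Matrix, r: Matrix) -> float:
    """max |f(points @ R) - f(points) @ R|, with row-vector points as in geom @ Q."""
    return max_abs_diff(f(matmul(points, r)), matmul(f(points), r))


def translation_error(f: Callable[[Matrix], Matrix], points: Matrix, t: Sequence[float]) -> float:
    """max |f(points + t) - (f(points) + t)|."""
    def shift(pts: Matrix) -> Matrix:
        return [[p + ti for p, ti in zip(row, t)] for row in pts]
    return max_abs_diff(f(shift(points)), shift(f(points)))


def toy_update(points: Matrix) -> Matrix:
    """Hypothetical stand-in for the model wrapper: move every point halfway to the centroid."""
    n = len(points)
    centroid = [sum(p[k] for p in points) / n for k in range(3)]
    return [[p[k] + 0.5 * (centroid[k] - p[k]) for k in range(3)] for p in points]


# Helix from the issue: x = sin(2*pi*z), y = cos(2*pi*z), z in [0, 2) with step 0.05.
helix = [[math.sin(2 * math.pi * z), math.cos(2 * math.pi * z), z] for z in (0.05 * i for i in range(40))]
print("rotation error:", rotation_error(toy_update, helix, rotation(0.3, 1.1, -2.0)))
print("translation error:", translation_error(toy_update, helix, [0.5, -1.0, 2.0]))
```

Reporting these two numbers for several random rotations and translations makes intermittent failures easier to catch than a visual overlay.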

    opened by brennanaba 18
  • How to populate input variable length data

    Hello, thank you for your work. I am using your SE3-Transformer implementation as part of my project, but I have run into problems feeding variable-length data into the model. I do not know how to pad the features, coordinates and masks to meet the model's requirements.

    My DataSet processes the input data as follows. I pad the coordinates and features with zeros so that their dimensions are (246, 3) and (246, 20) respectively, and fill the mask with True and False values. In this example there are 49 valid entries, so the valid features and coordinates have shapes (49, 20) and (49, 3), and the rest is zero padding. The first 49 entries of the mask are True and the rest are False.

    I then checked the input dimensions coming out of torch.utils.data.DataLoader, and they look fine. My network is a typical binary classifier, and I take the output of SE3 as its input.

    However, the SE3 output contains NaN: the first 48 positions have results and the rest are NaN. It is not clear to me what the problem is.

    In the next batch the output may be entirely NaN, presumably because the model weights have changed. The SE3 output then makes all my losses NaN.

    I think these NaNs are probably caused by my incorrect padding. I have also tried padding with 1, but that did not help either. Maybe the mask I pass in can't contain False values; I'm confused about that. Could you please tell me how to correctly pad coordinates, features and masks, or give me some advice on using your SE3 model to handle variable-length data? Thank you!
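For reference, the README examples pass feats as (b, n, d), coors as (b, n, 3) and a boolean mask of shape (b, n). Below is a minimal sketch of assembling such a padded batch from per-molecule lists. pad_batch is a hypothetical helper, not part of the library. It only shows the layout and does not establish what causes the NaNs reported above.

```python
from typing import List, Sequence, Tuple


def pad_batch(
    feats: Sequence[Sequence[Sequence[float]]],
    coors: Sequence[Sequence[Sequence[float]]],
    pad_value: float = 0.0,
) -> Tuple[List[List[List[float]]], List[List[List[float]]], List[List[bool]]]:
    """Pad per-molecule features (n_i, d) and coordinates (n_i, 3) to a common length.

    Returns feats (b, n_max, d), coors (b, n_max, 3) and mask (b, n_max),
    where mask is True for real nodes and False for padding.
    """
    if not feats or len(feats) != len(coors):
        raise ValueError("feats and coors must be non-empty and have the same batch size")
    n_max = max(len(f) for f in feats)
    dim = len(feats[0][0])
    out_feats, out_coors, out_mask = [], [], []
    for f, c in zip(feats, coors):
        n = len(f)
        if len(c) != n:
            raise ValueError("each molecule needs exactly one coordinate per node")
        pad = n_max - n
        out_feats.append([list(row) for row in f] + [[pad_value] * dim for _ in range(pad)])
        out_coors.append([list(xyz) for xyz in c] + [[pad_value] * 3 for _ in range(pad)])
        out_mask.append([True] * n + [False] * pad)
    return out_feats, out_coors, out_mask
```

Wrapping the three results in torch.tensor(...) gives float tensors for features and coordinates and a bool tensor for the mask, matching the model(feats, coors, mask) call in the README.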

    opened by zyk19981118 8
  • CPU/CUDA masking error

    Hi - nice work - I was just testing out your code and ran into the following error, but only in the backward pass:

    RuntimeError: Expected object of device type cuda but got device type cpu for argument #2 'mask' in call to th_masked_scatter_bool

    When using the nightly Pytorch, the error message is:

    RuntimeError: Tensor for argument #2 'mask' is on CPU, but expected it to be on GPU (while checking arguments for masked_scatter_)

    I'm pretty sure I don't have any tensors in CPU memory, but I'm not sure whether this is a bug in your SE3 code or a PyTorch issue. My gut feeling is that it's a PyTorch/autograd issue, but I don't know these particular PyTorch ops well enough to be sure. I tried both the PyTorch 1.7.1 release and the latest nightly. According to the PyTorch issue tracker, there seems to have been recent work on masked_scatter.

    opened by denjots 8
  • small bug

    https://github.com/lucidrains/se3-transformer-pytorch/blob/7c79998e4d84ec6bd6b6d4b916c6bf30b870b75b/se3_transformer_pytorch/se3_transformer_pytorch.py#L301

    It should be if isinstance(m, nn.Linear)

    Not a big deal, but I thought you might want to know.

    opened by MattMcPartlon 5
  • INTERNAL ASSERT FAILED

    Hi - I'm not sure if this is actually a bug or if I'm just expecting too much, but the following code snippet bombs with a PyTorch internal assert failure:

    RuntimeError: sub_iter.strides(0)[0] == 0INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1618643394934/work/aten/src/ATen/native/cuda/Reduce.cuh":929, please report a bug to PyTorch.

    That's using the latest PyTorch nightly. Tried with both CUDA 10.2 and 11.1.

    From the PyTorch bug reports, this seems to be caused by a massive internal memory allocation that is triggered when summing over a large tensor. That report seems to have been sitting open for a year, and I'm not sure they even see it as an actual bug. It certainly doesn't seem to be a priority.

    The problem is that this is neither a particularly large input nor a large model (to say the least!), and my GPU has 40 GB of memory, so the scalability of this transformer seems very limited as it currently stands. Changing dim from 128 to 64 does at least allow it to run.

    import torch
    from se3_transformer_pytorch import SE3Transformer
    
    model = SE3Transformer(dim = 128, heads = 1, depth = 1, dim_head = 1, num_degrees = 1, input_degrees=1, output_degrees=1).cuda()
    
    feats = torch.randn(1, 200, 128).cuda()
    coords = torch.randn(1, 200, 3).cuda()
    
    out = model(feats, coords)
    
    
    opened by denjots 4
  • question about non scalar output

    Hello,

    I am interested in using your implementation on my dataset. There are two things I want to check with you:

    1. The property I want to predict is a 3x3 symmetric PSD matrix.
    2. There are some edge features (one categorical feature, one continuous feature) besides the coordinate difference.

    I was wondering whether the current SE3 Transformer can work with such a scenario. Thanks!
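One possible construction for the first point: if v_1, ..., v_k are type-1 (vector) outputs for a node, then M = sum_k v_k v_k^T is symmetric and positive semidefinite by construction. Under a rotation v -> R v it transforms as M -> R M R^T, which is how a rank-2 tensor should transform. The sketch below assumes you can obtain several type-1 vectors per node, for example by not reducing the output dimension. That is an assumption about the API, not something this README shows. psd_from_vectors and the other helpers are hypothetical.

```python
from typing import List, Sequence

Mat3 = List[List[float]]


def psd_from_vectors(vectors: Sequence[Sequence[float]]) -> Mat3:
    """Sum of outer products v v^T: symmetric and positive semidefinite."""
    m = [[0.0] * 3 for _ in range(3)]
    for v in vectors:
        for i in range(3):
            for j in range(3):
                m[i][j] += v[i] * v[j]
    return m


def matmul3(a: Sequence[Sequence[float]], b: Sequence[Sequence[float]]) -> Mat3:
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def transpose3(a: Sequence[Sequence[float]]) -> Mat3:
    return [[a[j][i] for j in range(3)] for i in range(3)]


def rotate_vec(r: Sequence[Sequence[float]], v: Sequence[float]) -> List[float]:
    """Column-vector convention: returns R @ v."""
    return [sum(r[i][k] * v[k] for k in range(3)) for i in range(3)]
```

Because sum_k (R v_k)(R v_k)^T = R M R^T, building M from the network's type-1 outputs keeps the prediction rotation-equivariant. An alternative is to predict degree 0 and degree 2 outputs, since a symmetric 3x3 matrix decomposes into a trace part (degree 0) and a traceless part (degree 2). With that route, however, positive semidefiniteness is not guaranteed.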

    opened by Chen-Cai-OSU 3
  • Reversible flag odd results

    Sorry if it's expected behaviour again, but with a slight tweak of your new example code I get unexpected results when your new reversible option is used to return type 1 data. Here's the code I'm running...

    import torch
    from se3_transformer_pytorch import SE3Transformer
    
    model = SE3Transformer(
        num_tokens = 20,
        dim = 32,
        dim_head = 32,
        heads = 4,
        depth = 12,             # 12 layers
        input_degrees = 1,
        num_degrees = 2,
        output_degrees = 2,
        reduce_dim_out = True,
        reversible = True       # set reversible to True
    ).cuda()
    
    atoms = torch.randint(0, 4, (1, 50)).cuda()
    coors = torch.randn(1, 50, 3).cuda()
    mask  = torch.ones(1, 50).bool().cuda()
    
    pred = model(atoms, coors, mask = mask, return_type = 1)
    
    loss = pred.mean()
    print(loss)
    loss.backward()
    
    

    Without the reversible flag, the loss is close to zero, as might be expected since the input coordinates are centered on the origin. However, when reversible is set, a very large value is produced, which would blow up training in a real run. Is this an expected side effect of reversible nets? Is some kind of normalization essential in that case? Or am I just doing something wrong here?

    opened by denjots 3
  • faster loop

    My small grain of sand for this project ;) : at least avoid Python appends, which are 2x slower than list comprehensions.

    If this is of any help, here are some considerations (there might be misunderstandings on my side due to the decorators and so on):

    • I have my reservations about the usefulness of line 62 in utils.py.
    • Wouldn't it make more sense to run the (modified) for i loop at line 122 of utils.py in reverse order, and then reuse the cached calculations for lpmv()?
    • The same applies (loop in reverse order) to the for loop at line 148 of basis.py.
    • If scipy.special.poch (which can handle NumPy arrays) were used instead of the custom pochhammer implementation, every operation inside get_spherical_harmonics_element could be vectorized except the lpmv call.
      • My sense is that lpmv, get_spherical_harmonics_element and get_spherical_harmonics could all be wrapped in a single function (at the cost of reusability and extensibility, so maybe reversing the loop order and caching is enough).
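Here is a sketch of the "compute in order and reuse earlier results" idea. For a fixed order m, associated Legendre functions satisfy a three-term recurrence in the degree l, so every degree up to l_max can be produced in one pass, each from the two previous values. The function below is a standalone illustration, not code from utils.py, and associated_legendre_column is a hypothetical name. It includes the Condon-Shortley phase (-1)^m, the convention scipy.special.lpmv documents. The repository's own normalization may differ.

```python
import math
from typing import List


def associated_legendre_column(l_max: int, m: int, x: float) -> List[float]:
    """Return [P_m^m(x), P_{m+1}^m(x), ..., P_{l_max}^m(x)] for a fixed order m >= 0 and |x| <= 1."""
    if not 0 <= m <= l_max:
        raise ValueError("need 0 <= m <= l_max")
    # Seed: P_m^m(x) = (-1)^m (2m - 1)!! (1 - x^2)^(m / 2)
    somx2 = math.sqrt((1.0 - x) * (1.0 + x))
    pmm = 1.0
    odd = 1.0
    for _ in range(m):
        pmm *= -odd * somx2
        odd += 2.0
    values = [pmm]
    if l_max == m:
        return values
    # P_{m+1}^m(x) = x (2m + 1) P_m^m(x)
    values.append(x * (2 * m + 1) * pmm)
    # (l - m) P_l^m = x (2l - 1) P_{l-1}^m - (l + m - 1) P_{l-2}^m, reusing the two previous degrees
    for l in range(m + 2, l_max + 1):
        values.append((x * (2 * l - 1) * values[-1] - (l + m - 1) * values[-2]) / (l - m))
    return values
```

Each degree costs a constant amount of work, so a column of degrees costs O(l_max) instead of recomputing every P_l^m from scratch.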
    opened by hypnopump 1
  • Question about continuous edge features

    Thanks for all of the work!

    I am working on a simple proof of concept with your model. Ideally, I would like to perform multidimensional scaling, i.e. given a distance matrix, recover the corresponding coordinates.

    I was wondering if distance information could be passed as an edge feature (continuous rather than categorical information). Is it possible to do this with the current implementation?

    Thanks again and I appreciate the help!
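Distances are invariant to rotations and translations, so feeding them to the network as scalar (type-0) edge inputs does not by itself break equivariance. A common way to turn a continuous distance into a fixed-length edge feature vector is a Gaussian radial basis expansion. The sketch below only builds the (n, n, k) feature array. rbf_expand, pairwise_rbf and linspace are hypothetical helpers, and the centers and gamma are arbitrary choices. The README above documents only categorical edges via num_edge_tokens and edge_dim, so how such continuous features would be passed to SE3Transformer is not covered here.

```python
import math
from typing import List, Sequence


def linspace(start: float, stop: float, num: int) -> List[float]:
    """num evenly spaced values from start to stop inclusive (num >= 2)."""
    if num < 2:
        raise ValueError("num must be at least 2")
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def rbf_expand(distance: float, centers: Sequence[float], gamma: float = 10.0) -> List[float]:
    """Gaussian radial basis features exp(-gamma * (d - c)^2), one per center."""
    return [math.exp(-gamma * (distance - c) ** 2) for c in centers]


def pairwise_rbf(dist: Sequence[Sequence[float]], centers: Sequence[float], gamma: float = 10.0) -> List[List[List[float]]]:
    """Turn an (n, n) distance matrix into an (n, n, len(centers)) edge feature array."""
    return [[rbf_expand(d, centers, gamma) for d in row] for row in dist]
```

Each distance activates the basis functions whose centers lie near it, which gives the network a smooth, fixed-size encoding instead of a single raw scalar.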

    opened by MattMcPartlon 1
  • denoise.py bugfix

    Fixes an error raised by the SE3Transformer constructor when running denoise.py:

    Traceback (most recent call last):
      File "/home/jcastellanos/projects/se3-transformer-pytorch/denoise.py", line 22, in <module>
        transformer = SE3Transformer(
      File "/home/jcastellanos/projects/se3-transformer-pytorch/se3_transformer_pytorch/se3_transformer_pytorch.py", line 1072, in __init__
        self.num_degrees = num_degrees if exists(num_degrees) else (max(hidden_fiber_dict.keys()) + 1)
    AttributeError: 'NoneType' object has no attribute 'keys'
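The traceback shows max(hidden_fiber_dict.keys()) being evaluated on None because num_degrees was not given either. The sketch below adds a guard that fails with a clearer message when neither value is provided. resolve_num_degrees is a hypothetical function, not the PR's actual diff, and exists is assumed to mean "is not None".

```python
from typing import Any, Dict, Optional


def exists(val: Any) -> bool:
    """Assumed meaning of the exists helper seen in the traceback: val is not None."""
    return val is not None


def resolve_num_degrees(num_degrees: Optional[int], hidden_fiber_dict: Optional[Dict[int, int]]) -> int:
    """Use num_degrees if given, otherwise infer it from the highest degree in hidden_fiber_dict."""
    if exists(num_degrees):
        return num_degrees
    if not exists(hidden_fiber_dict) or not hidden_fiber_dict:
        raise ValueError("pass either num_degrees or a non-empty hidden_fiber_dict")
    return max(hidden_fiber_dict.keys()) + 1
```

With this guard, a call that omits both arguments fails with a message naming the missing arguments instead of an AttributeError on NoneType.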
    
    opened by javierbq 0
  • CUDA out of memory

    Thanks for your great work!

    The SE3 Transformer is powerful, but it seems to be very memory-hungry.

    I built a model with the following parameters and got a "CUDA out of memory" error when running it on the GPU (NVIDIA V100, 32 GB).

    import torch
    from se3_transformer_pytorch import SE3Transformer

    model = SE3Transformer(dim = 20, heads = 4, depth = 2, dim_head = 5, num_degrees = 2, valid_radius = 5)

    num_points = 512
    feats = torch.randn(1, num_points, 20)
    coors = torch.randn(1, num_points, 3)
    mask = torch.ones(1, num_points).bool()
    

    Is this error related to the PyTorch version, and how can I fix it?
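As a rough back-of-the-envelope aid: attention over all point pairs creates tensors whose size grows with the square of the number of points. The hypothetical helper below counts the bytes of a single dense float32 tensor of shape (batch, n, n, channels). It does not model the library's actual intermediates. Their number and shapes depend on num_degrees, heads, dim_head and depth, and on what autograd keeps for the backward pass. Treat it only as an illustration of quadratic scaling, not as a prediction of when a given GPU runs out of memory.

```python
def dense_pair_tensor_bytes(batch: int, num_points: int, channels: int, bytes_per_element: int = 4) -> int:
    """Bytes for one dense (batch, n, n, channels) tensor; float32 uses 4 bytes per element."""
    return batch * num_points * num_points * channels * bytes_per_element


def to_mib(num_bytes: int) -> float:
    """Convert bytes to mebibytes (2**20 bytes)."""
    return num_bytes / 2 ** 20
```

For the settings above (batch 1, 512 points, 20 channels), one such tensor is 20 MiB. Doubling num_points multiplies it by four, and every additional pairwise tensor kept alive adds another term of the same order.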

    opened by PengCheng-NUDT 0
  • SE3Transformer constructor hangs

    I am trying to run an example from the README. The code is:

    import torch
    from se3_transformer_pytorch import SE3Transformer
    
    print('Initialising model...')
    model = SE3Transformer(
        dim = 512,
        heads = 8,
        depth = 6,
        dim_head = 64,
        num_degrees = 4,
        valid_radius = 10
    )
    
    print('Running model...')
    feats = torch.randn(1, 1024, 512)
    coors = torch.randn(1, 1024, 3)
    mask  = torch.ones(1, 1024).bool()
    
    out = model(feats, coors, mask) # (1, 1024, 512)
    

    The output hangs on 'Initialising model...' and eventually the kernel dies.

    Any ideas why this would be happening?

    Here is my pip freeze:

    anyio==3.2.1
    argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1613036642480/work
    astunparse==1.6.3
    async-generator==1.10
    attrs @ file:///tmp/build/80754af9/attrs_1620827162558/work
    axial-positional-embedding==0.2.1
    Babel==2.9.1
    backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
    biopython==1.79
    bleach @ file:///tmp/build/80754af9/bleach_1612211392645/work
    cached-property @ file:///tmp/build/80754af9/cached-property_1600785575025/work
    certifi==2021.5.30
    cffi @ file:///tmp/build/80754af9/cffi_1613246939562/work
    chardet==4.0.0
    click==8.0.1
    configparser==5.0.2
    decorator==4.4.2
    defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
    dgl-cu101==0.4.3.post2
    dgl-cu110==0.6.1
    docker-pycreds==0.4.0
    egnn-pytorch==0.2.6
    einops==0.3.0
    En-transformer==0.3.8
    entrypoints==0.3
    equivariant-attention @ file:///workspace/projects/se3-transformer-public
    filelock==3.0.12
    gitdb==4.0.7
    GitPython==3.1.18
    graph-transformer-pytorch==0.0.1
    h5py @ file:///tmp/build/80754af9/h5py_1622088444809/work
    huggingface-hub==0.0.12
    idna==2.10
    importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1617877314848/work
    ipykernel @ file:///tmp/build/80754af9/ipykernel_1596206598566/work/dist/ipykernel-5.3.4-py3-none-any.whl
    ipython @ file:///tmp/build/80754af9/ipython_1617118429768/work
    ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
    ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1610481889018/work
    jedi==0.17.0
    Jinja2 @ file:///tmp/build/80754af9/jinja2_1621238361758/work
    joblib==1.0.1
    json5==0.9.6
    jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
    jupyter==1.0.0
    jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1616770841739/work
    jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1616615302928/work
    jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1612213308260/work
    jupyter-server==1.8.0
    jupyter-tensorboard==0.2.0
    jupyterlab==3.0.16
    jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
    jupyterlab-server==2.6.0
    jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work
    jupytext==1.11.3
    lie-learn @ git+https://github.com/AMLab-Amsterdam/[email protected]
    llvmlite==0.36.0
    local-attention==1.4.1
    markdown-it-py==1.1.0
    MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1621528142364/work
    mdit-py-plugins==0.2.8
    mdtraj==1.9.6
    mistune @ file:///tmp/build/80754af9/mistune_1594373098390/work
    mkl-fft==1.3.0
    mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853974840/work
    mkl-service==2.3.0
    mp-nerf==0.1.11
    nbclassic==0.3.1
    nbclient @ file:///tmp/build/80754af9/nbclient_1614364831625/work
    nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914821128/work
    nbformat @ file:///tmp/build/80754af9/nbformat_1617383369282/work
    nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1613680548246/work
    networkx==2.5.1
    notebook @ file:///tmp/build/80754af9/notebook_1621523661196/work
    numba==0.53.1
    numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1620831194891/work
    packaging @ file:///tmp/build/80754af9/packaging_1611952188834/work
    pandas==1.2.4
    pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120451932/work
    parso @ file:///tmp/build/80754af9/parso_1617223946239/work
    pathtools==0.1.2
    performer-pytorch==1.0.11
    pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
    pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
    ProDy==2.0
    prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1623189609245/work
    promise==2.3
    prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1616415428029/work
    protobuf==3.17.3
    psutil==5.8.0
    ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
    py3Dmol==0.9.1
    pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
    Pygments @ file:///tmp/build/80754af9/pygments_1621606182707/work
    pyparsing @ file:///home/linux1/recipes/ci/pyparsing_1610983426697/work
    pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141707582/work
    python-dateutil @ file:///home/ktietz/src/ci/python-dateutil_1611928101742/work
    pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work
    PyYAML==5.4.1
    pyzmq==20.0.0
    qtconsole @ file:///tmp/build/80754af9/qtconsole_1623278325812/work
    QtPy==1.9.0
    regex==2021.4.4
    requests==2.25.1
    sacremoses==0.0.45
    scipy @ file:///tmp/build/80754af9/scipy_1618852618548/work
    se3-transformer-pytorch==0.8.10
    Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work
    sentry-sdk==1.1.0
    shortuuid==1.0.1
    sidechainnet==0.6.0
    six @ file:///tmp/build/80754af9/six_1623709665295/work
    smmap==4.0.0
    sniffio==1.2.0
    subprocess32==3.5.4
    terminado==0.9.4
    testpath @ file:///home/ktietz/src/ci/testpath_1611930608132/work
    tokenizers==0.10.3
    toml==0.10.2
    torch==1.9.0
    tornado @ file:///tmp/build/80754af9/tornado_1606942283357/work
    tqdm==4.61.1
    traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work
    transformers==4.8.0
    typing-extensions @ file:///home/ktietz/src/ci_mi/typing_extensions_1612808209620/work
    urllib3==1.26.5
    wandb==0.10.32
    wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
    webencodings==0.5.1
    websocket-client==1.1.0
    widgetsnbextension==3.5.1
    zipp @ file:///tmp/build/80754af9/zipp_1615904174917/work
    

    Here is a summary of my system info (lshw -short):

    H/W path    Device  Class      Description
    ==========================================
                        system     Computer
    /0                  bus        Motherboard
    /0/0                memory     59GiB System memory
    /0/1                processor  Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz
    /0/100              bridge     440FX - 82441FX PMC [Natoma]
    /0/100/1            bridge     82371SB PIIX3 ISA [Natoma/Triton II]
    /0/100/1.1          storage    82371SB PIIX3 IDE [Natoma/Triton II]
    /0/100/1.3          bridge     82371AB/EB/MB PIIX4 ACPI
    /0/100/2            display    GD 5446
    /0/100/3            network    Elastic Network Adapter (ENA)
    /0/100/1e           display    GK210GL [Tesla K80]
    /0/100/1f           generic    Xen Platform Device
    /1          eth0    network    Ethernet interface
    
    opened by mpdprot 1
  • Whether SE3 needs pre-training

    Thank you for your work. I used your reproduction of SE3 as part of my model, but the results on my test set are currently not very good. I guess it may be because I do not understand your model well. Here are my questions:

    1. Does your model need pre-training?
    2. Can I train the SE3 Transformer end to end together with the fully connected layer that comes after it? Any other advice is also welcome.
    opened by zyk19981118 2
  • multiple molecules cases

    Hi,

    I normally use the DataLoader from PyG to handle my molecule dataset.

    Could you please provide an example of a real model trained over multiple steps and epochs?

    I ran your sample code and it works, but I need to better understand the input format, as well as the output format, which for me would be x, y, z.

    thanks
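PyG batches concatenate the nodes of all graphs and record, for each node, which graph it belongs to in a batch vector. The README examples instead expect dense padded tensors: feats (b, n, d), coors (b, n, 3) and a bool mask (b, n). The same README examples show that return_type = 1 with reduce_dim_out = True yields one 3-vector (x, y, z) per node, shape (b, n, 3). PyG itself provides torch_geometric.utils.to_dense_batch for this conversion on tensors. Below is a dependency-free sketch of the same idea. flat_to_dense is a hypothetical name.

```python
from typing import List, Sequence, Tuple


def flat_to_dense(
    x: Sequence[Sequence[float]],
    batch: Sequence[int],
    fill_value: float = 0.0,
) -> Tuple[List[List[List[float]]], List[List[bool]]]:
    """Group node rows by graph index (PyG-style batch vector) into a padded (b, n_max, d) list
    plus a (b, n_max) mask that is True for real nodes and False for padding."""
    num_graphs = max(batch) + 1
    groups: List[List[List[float]]] = [[] for _ in range(num_graphs)]
    for row, graph_idx in zip(x, batch):
        groups[graph_idx].append(list(row))
    n_max = max(len(g) for g in groups)
    dim = len(x[0])
    dense, mask = [], []
    for g in groups:
        pad = n_max - len(g)
        dense.append(g + [[fill_value] * dim for _ in range(pad)])
        mask.append([True] * len(g) + [False] * pad)
    return dense, mask
```

Applying it once to the node features and once to the positions (d = 3) gives the three inputs the model expects. Both calls share the same batch vector, so they produce the same mask.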

    opened by thegodone 2
Releases: 0.9.0
Owner: Phil Wang (Working with Attention. It's all we need.)