From ec7d50bf34bf622b29ecbe041cfe9d730c29f563 Mon Sep 17 00:00:00 2001 From: Leo Auri Date: Tue, 16 Sep 2025 11:46:29 +0200 Subject: [PATCH 1/4] Make pool and block layers importable --- example.py | 4 +- pyproject.toml | 35 ++++++ models/s4/lssl.md => s4/__init__.py | 0 {checkpoints => s4/checkpoints}/README.md | 0 .../checkpoints}/convert_pl_to_pt.py | 2 +- .../checkpoints}/convert_v3_to_v4.py | 4 +- {checkpoints => s4/checkpoints}/evaluate.py | 4 +- {configs => s4/configs}/README.md | 0 {configs => s4/configs}/callbacks/base.yaml | 0 .../configs}/callbacks/checkpoint.yaml | 0 .../callbacks/progressive_resizing.yaml | 0 {configs => s4/configs}/callbacks/rich.yaml | 0 {configs => s4/configs}/callbacks/swa.yaml | 0 {configs => s4/configs}/callbacks/wandb.yaml | 0 {configs => s4/configs}/config.yaml | 0 {configs => s4/configs}/dataset/aan.yaml | 0 {configs => s4/configs}/dataset/adding.yaml | 0 .../configs}/dataset/beethoven.yaml | 0 {configs => s4/configs}/dataset/bidmc.yaml | 0 .../configs}/dataset/celeba-all.yaml | 0 {configs => s4/configs}/dataset/cifar.yaml | 0 {configs => s4/configs}/dataset/copying.yaml | 0 {configs => s4/configs}/dataset/delay.yaml | 0 {configs => s4/configs}/dataset/ecl.yaml | 0 {configs => s4/configs}/dataset/etth.yaml | 0 {configs => s4/configs}/dataset/ettm.yaml | 0 {configs => s4/configs}/dataset/hmdb51.yaml | 0 {configs => s4/configs}/dataset/imagenet.yaml | 0 {configs => s4/configs}/dataset/imdb.yaml | 0 {configs => s4/configs}/dataset/listops.yaml | 0 {configs => s4/configs}/dataset/ljspeech.yaml | 0 {configs => s4/configs}/dataset/mnist.yaml | 0 {configs => s4/configs}/dataset/music.yaml | 0 .../configs}/dataset/pathfinder.yaml | 0 .../configs}/dataset/qautomusic.yaml | 0 .../configs}/dataset/reconstruct.yaml | 0 {configs => s4/configs}/dataset/sc.yaml | 0 {configs => s4/configs}/dataset/sc09.yaml | 0 {configs => s4/configs}/dataset/sc10.yaml | 0 {configs => s4/configs}/dataset/weather.yaml | 0 {configs => s4/configs}/dataset/wt103.yaml | 0 .../configs}/dataset/youtubemix.yaml | 0 {configs => s4/configs}/experiment/README.md | 0 .../experiment/audio/samplernn-beethoven.yaml | 0 .../audio/samplernn-qautomusic.yaml | 0 .../experiment/audio/samplernn-sc09.yaml | 0 .../experiment/audio/samplernn-scg.yaml | 0 .../audio/samplernn-youtubemix.yaml | 0 .../experiment/audio/sashimi-beethoven.yaml | 0 .../experiment/audio/sashimi-sc09-unet.yaml | 0 .../experiment/audio/sashimi-sc09.yaml | 0 .../experiment/audio/sashimi-standalone.yaml | 0 .../experiment/audio/sashimi-youtubemix.yaml | 0 .../experiment/audio/wavenet-beethoven.yaml | 0 .../experiment/audio/wavenet-qautomusic.yaml | 0 .../experiment/audio/wavenet-sc09.yaml | 0 .../experiment/audio/wavenet-youtubemix.yaml | 0 {configs => s4/configs}/experiment/base.yaml | 0 .../experiment/bidmc/ckconv-bidmc.yaml | 0 .../experiment/bidmc/resnet-bidmc.yaml | 0 .../experiment/bidmc/s4-bidmc-ablation.yaml | 0 .../configs}/experiment/bidmc/s4-bidmc.yaml | 0 .../experiment/cifar/cnn-cifar-2d.yaml | 0 .../experiment/cifar/resnet-cifar.yaml | 0 .../experiment/cifar/s4-cifar-ablation.yaml | 0 .../configs}/experiment/cifar/s4-cifar.yaml | 0 .../experiment/cifar/s4d-minimal-cifar.yaml | 0 .../forecasting/s4-informer-ecl.yaml | 0 .../forecasting/s4-informer-etth.yaml | 0 .../forecasting/s4-informer-ettm.yaml | 0 .../forecasting/s4-informer-weather.yaml | 0 .../configs}/experiment/lm/s4-wt103.yaml | 0 .../experiment/lm/transformer-wt103.yaml | 0 .../configs}/experiment/lra/lra-cifar.yaml | 0 .../configs}/experiment/lra/lra-listops.yaml | 0 .../experiment/lra/old/s4-lra-aan.yaml | 0 .../experiment/lra/old/s4-lra-cifar.yaml | 0 .../experiment/lra/old/s4-lra-imdb.yaml | 0 .../experiment/lra/old/s4-lra-listops.yaml | 0 .../experiment/lra/old/s4-lra-pathfinder.yaml | 0 .../experiment/lra/old/s4-lra-pathx.yaml | 0 .../experiment/lra/old/v3-s4-aan.yaml | 0 .../experiment/lra/old/v3-s4-cifar.yaml | 0 .../experiment/lra/old/v3-s4-imdb-small.yaml | 0 .../experiment/lra/old/v3-s4-imdb.yaml | 0 .../lra/old/v3-s4-listops-small.yaml | 0 .../experiment/lra/old/v3-s4-listops.yaml | 0 .../lra/old/v3-s4-pathfinder-small.yaml | 0 .../experiment/lra/old/v3-s4-pathfinder.yaml | 0 .../experiment/lra/old/v3-s4-pathx-small.yaml | 0 .../experiment/lra/old/v3-s4-pathx.yaml | 0 .../configs}/experiment/lra/resnet-pathx.yaml | 0 .../configs}/experiment/lra/s4-aan.yaml | 0 .../configs}/experiment/lra/s4-cifar.yaml | 0 .../configs}/experiment/lra/s4-imdb.yaml | 0 .../configs}/experiment/lra/s4-listops.yaml | 0 .../experiment/lra/s4-pathfinder.yaml | 0 .../configs}/experiment/lra/s4-pathx.yaml | 0 .../experiment/mega/lra-image/README.md | 0 .../mega/lra-image/large-ema-with-s4.yaml | 0 .../experiment/mega/lra-image/large-ema.yaml | 0 .../lra-image/large-mega-ema-with-s4.yaml | 0 .../mega/lra-image/large-mega-ema.yaml | 0 .../mega/lra-image/large-mega-s4d-real.yaml | 0 .../mega/lra-image/large-mega-s4d.yaml | 0 .../mega/lra-image/large-s4d-real.yaml | 0 .../experiment/mega/lra-image/large-s4d.yaml | 0 .../mega_ablations_10000_warmup_all.pdf | Bin .../mega_ablations_1000_warmup_all.pdf | Bin .../lra-image/mega_ablations_mega_repo.pdf | Bin .../mega/lra-image/small-ema-with-s4.yaml | 0 .../experiment/mega/lra-image/small-ema.yaml | 0 .../lra-image/small-mega-ema-with-s4.yaml | 0 .../mega/lra-image/small-mega-ema.yaml | 0 .../mega/lra-image/small-mega-s4d-real.yaml | 0 .../mega/lra-image/small-mega-s4d.yaml | 0 .../mega/lra-image/small-s4d-real.yaml | 0 .../experiment/mega/lra-image/small-s4d.yaml | 0 {configs => s4/configs}/experiment/rnn.yaml | 0 .../configs}/experiment/s4nd/README.md | 0 .../s4nd/celeba/convnext-celeba-all.yaml | 0 .../s4nd/celeba/convnext-s4nd-celeba-all.yaml | 0 .../experiment/s4nd/cifar/cnn-cifar-2d.yaml | 0 .../s4nd/cifar/s4-cifar-2d-16x16.yaml | 0 .../experiment/s4nd/cifar/s4-cifar-2d.yaml | 0 .../convnext/convnext_timm_tiny_imagenet.yaml | 0 .../convnext_timm_tiny_inflate3d_hmdb.yaml | 0 ...onvnext_timm_tiny_inflate3d_s4nd_hmdb.yaml | 0 .../convnext_timm_tiny_s4nd_imagenet.yaml | 0 .../experiment/s4nd/progres/cnn-cifar-2d.yaml | 0 .../experiment/s4nd/progres/s4-cifar-2d.yaml | 0 .../s4nd/vit/vit_b_16_imagenet.yaml | 0 .../s4nd/vit/vit_b_16_s4_imagenet_v2.yaml | 0 .../configs}/experiment/sc/convnet-sc.yaml | 0 .../configs}/experiment/sc/resnet-sc.yaml | 0 .../experiment/sc/s4-sc-ablation.yaml | 0 .../configs}/experiment/sc/s4-sc.yaml | 0 .../experiment/sc/transformer-sc.yaml | 0 .../experiment/synthetic/s4-copying.yaml | 0 .../experiment/synthetic/s4-delay.yaml | 0 .../experiment/synthetic/s4-reconstruct.yaml | 0 {configs => s4/configs}/generate.yaml | 0 {configs => s4/configs}/loader/default.yaml | 0 .../configs}/loader/imresolution.yaml | 0 {configs => s4/configs}/loader/lm.yaml | 0 .../configs}/loader/resolution.yaml | 0 {configs => s4/configs}/loader/tbptt.yaml | 0 {configs => s4/configs}/model/README.md | 0 {configs => s4/configs}/model/base.yaml | 0 .../configs}/model/baseline/ckconv.yaml | 0 .../configs}/model/baseline/lipschitzrnn.yaml | 0 .../configs}/model/baseline/lstm.yaml | 0 .../configs}/model/baseline/odelstm.yaml | 0 .../configs}/model/baseline/resnet2d.yaml | 0 .../configs}/model/baseline/samplernn.yaml | 0 .../model/baseline/stackedrnn_baseline.yaml | 0 .../configs}/model/baseline/unicornn.yaml | 0 .../configs}/model/baseline/wavenet.yaml | 0 {configs => s4/configs}/model/convnet1d.yaml | 0 {configs => s4/configs}/model/convnet2d.yaml | 0 .../configs}/model/layer/cell/exprnn.yaml | 0 .../configs}/model/layer/cell/goru.yaml | 0 .../configs}/model/layer/cell/gru.yaml | 0 .../model/layer/cell/hippo-glagt.yaml | 0 .../configs}/model/layer/cell/hippo-lagt.yaml | 0 .../configs}/model/layer/cell/hippo-legs.yaml | 0 .../configs}/model/layer/cell/hippo-legt.yaml | 0 .../model/layer/cell/hippo-timestamp.yaml | 0 .../configs}/model/layer/cell/lmu.yaml | 0 .../configs}/model/layer/cell/rnn.yaml | 0 .../configs}/model/layer/cell/sru.yaml | 0 .../configs}/model/layer/conv1d.yaml | 0 .../configs}/model/layer/conv2d.yaml | 0 {configs => s4/configs}/model/layer/ff.yaml | 0 {configs => s4/configs}/model/layer/id.yaml | 0 {configs => s4/configs}/model/layer/lssl.yaml | 0 {configs => s4/configs}/model/layer/lstm.yaml | 0 {configs => s4/configs}/model/layer/mega.yaml | 0 {configs => s4/configs}/model/layer/mha.yaml | 0 .../configs}/model/layer/performer.yaml | 0 {configs => s4/configs}/model/layer/rnn.yaml | 0 {configs => s4/configs}/model/layer/s4.yaml | 0 {configs => s4/configs}/model/layer/s4d.yaml | 0 .../configs}/model/layer/s4d_example.yaml | 0 {configs => s4/configs}/model/layer/s4ff.yaml | 0 {configs => s4/configs}/model/layer/s4nd.yaml | 0 .../configs}/model/layer/s4s4ff.yaml | 0 {configs => s4/configs}/model/layer/sru.yaml | 0 .../configs}/model/layer/standalone.yaml | 0 .../configs}/model/layer/transformer.yaml | 0 {configs => s4/configs}/model/layer/vit.yaml | 0 {configs => s4/configs}/model/mega.yaml | 0 .../configs}/model/nonaka/inception.yaml | 0 .../configs}/model/nonaka/resnet.yaml | 0 .../configs}/model/nonaka/xresnet.yaml | 0 {configs => s4/configs}/model/s4.yaml | 0 .../configs}/model/sashimi-standalone.yaml | 0 .../configs}/model/sashimi-transformer.yaml | 0 {configs => s4/configs}/model/sashimi.yaml | 0 .../configs}/model/transformer.yaml | 0 {configs => s4/configs}/model/unet.yaml | 0 {configs => s4/configs}/model/vit/vit.yaml | 0 .../configs}/model/vit/vit_b_16.yaml | 0 .../configs}/model/vit/vit_s_16.yaml | 0 {configs => s4/configs}/optimizer/adam.yaml | 0 {configs => s4/configs}/optimizer/adamw.yaml | 0 {configs => s4/configs}/optimizer/lamb.yaml | 0 {configs => s4/configs}/optimizer/sgd.yaml | 0 {configs => s4/configs}/pipeline/aan.yaml | 0 {configs => s4/configs}/pipeline/adding.yaml | 0 .../configs}/pipeline/celeba-all-2d.yaml | 0 .../configs}/pipeline/cifar-2d.yaml | 0 {configs => s4/configs}/pipeline/cifar.yaml | 0 {configs => s4/configs}/pipeline/copying.yaml | 0 {configs => s4/configs}/pipeline/delay.yaml | 0 {configs => s4/configs}/pipeline/ema.yaml | 0 .../configs}/pipeline/hmdb51_convnext.yaml | 0 .../configs}/pipeline/imagenet.yaml | 0 {configs => s4/configs}/pipeline/imdb.yaml | 0 .../configs}/pipeline/informer.yaml | 0 {configs => s4/configs}/pipeline/listops.yaml | 0 {configs => s4/configs}/pipeline/mnist.yaml | 0 .../configs}/pipeline/pathfinder.yaml | 0 {configs => s4/configs}/pipeline/pathx.yaml | 0 .../configs}/pipeline/reconstruct.yaml | 0 {configs => s4/configs}/pipeline/sc.yaml | 0 {configs => s4/configs}/pipeline/wt103.yaml | 0 .../configs}/scheduler/constant.yaml | 0 .../configs}/scheduler/constant_warmup.yaml | 0 {configs => s4/configs}/scheduler/cosine.yaml | 0 .../configs}/scheduler/cosine_warmup.yaml | 0 .../configs}/scheduler/linear_warmup.yaml | 0 .../configs}/scheduler/multistep.yaml | 0 .../configs}/scheduler/plateau.yaml | 0 {configs => s4/configs}/scheduler/step.yaml | 0 .../configs}/scheduler/timm_cosine.yaml | 0 {configs => s4/configs}/task/forecasting.yaml | 0 {configs => s4/configs}/task/lm.yaml | 0 .../task/multiclass_classification.yaml | 0 .../task/multilabel_classification.yaml | 0 {configs => s4/configs}/task/regression.yaml | 0 {configs => s4/configs}/task/video.yaml | 0 {configs => s4/configs}/trainer/debug.yaml | 0 {configs => s4/configs}/trainer/default.yaml | 0 {configs => s4/configs}/trainer/lm.yaml | 0 s4/extensions/__init__.py | 0 .../extensions}/kernels/README.md | 0 s4/extensions/kernels/__init__.py | 0 .../extensions}/kernels/benchmark_cauchy.py | 0 .../kernels/benchmark_cauchy_tune.py | 0 .../extensions}/kernels/cauchy.cpp | 0 .../extensions}/kernels/cauchy.py | 0 .../extensions}/kernels/cauchy_cuda.cu | 0 {extensions => s4/extensions}/kernels/map.h | 0 .../extensions}/kernels/setup.py | 0 .../extensions}/kernels/test_cauchy.py | 0 .../extensions}/kernels/test_vandermonde.py | 2 +- .../extensions}/kernels/tune_cauchy.py | 0 .../extensions}/kernels/tune_cauchy.sh | 0 .../extensions}/kernels/tuner.py | 0 .../extensions}/kernels/tuning_setup.py | 0 .../extensions}/kernels/vandermonde.py | 0 generate.py => s4/generate.py | 8 +- {models => s4/models}/README.md | 0 s4/models/__init__.py | 0 {models => s4/models}/dss/README.md | 0 s4/models/dss/__init__.py | 0 {models => s4/models}/hippo/README.md | 0 s4/models/hippo/__init__.py | 0 {models => s4/models}/related/README.md | 0 s4/models/related/__init__.py | 0 {models => s4/models}/s4/README.md | 0 s4/models/s4/__init__.py | 0 {models => s4/models}/s4/experiments.md | 0 s4/models/s4/lssl.md | 0 {models => s4/models}/s4/s4.py | 4 +- {models => s4/models}/s4/s4d.py | 2 +- {models => s4/models}/s4nd/README.md | 0 s4/models/s4nd/__init__.py | 0 {models => s4/models}/sashimi/README.md | 0 s4/models/sashimi/__init__.py | 0 {models => s4/models}/sashimi/metrics.py | 0 s4/models/sashimi/mturk/__init__.py | 0 .../sashimi/mturk/mos/MTurk SC09 MOS.ipynb | 0 .../mturk/mos/MTurk YouTubeMix MOS.ipynb | 0 s4/models/sashimi/mturk/mos/__init__.py | 0 .../models}/sashimi/mturk/prepare_sc09.py | 0 .../models}/sashimi/mturk/template_music.py | 0 .../models}/sashimi/mturk/template_speech.py | 0 .../sashimi/mturk/turk_create_batch.py | 0 {models => s4/models}/sashimi/sashimi.py | 2 +- .../models}/sashimi/sc09_classifier/README.md | 0 s4/models/sashimi/sc09_classifier/__init__.py | 0 .../sc09_classifier/datasets/__init__.py | 0 .../datasets/speech_commands/__init__.py | 0 .../datasets/speech_commands/split_dataset.py | 0 .../download_speech_commands_dataset.sh | 0 .../sc09_classifier/models/__init__.py | 0 .../sashimi/sc09_classifier/models/resnext.py | 0 .../sashimi/sc09_classifier/requirements.txt | 0 .../speech_commands_dataset.py | 0 .../sc09_classifier/test_speech_commands.py | 0 .../sc09_classifier/train_speech_commands.py | 2 +- .../sc09_classifier/transforms/__init__.py | 0 .../transforms/transforms_stft.py | 0 .../transforms/transforms_wav.py | 0 s4/src/__init__.py | 0 {src => s4/src}/callbacks/norms.py | 0 {src => s4/src}/callbacks/params.py | 0 .../src}/callbacks/progressive_resizing.py | 4 +- {src => s4/src}/callbacks/timer.py | 0 {src => s4/src}/callbacks/wandb.py | 0 {src => s4/src}/dataloaders/README.md | 0 {src => s4/src}/dataloaders/__init__.py | 0 {src => s4/src}/dataloaders/audio.py | 16 +-- {src => s4/src}/dataloaders/base.py | 2 +- {src => s4/src}/dataloaders/basic.py | 6 +- .../src}/dataloaders/datasets/adding.py | 0 .../src}/dataloaders/datasets/celeba.py | 0 .../src}/dataloaders/datasets/copying.py | 2 +- {src => s4/src}/dataloaders/datasets/delay.py | 2 +- {src => s4/src}/dataloaders/datasets/music.py | 0 .../src}/dataloaders/datasets/reconstruct.py | 2 +- {src => s4/src}/dataloaders/datasets/sc.py | 0 {src => s4/src}/dataloaders/et.py | 2 +- {src => s4/src}/dataloaders/lm.py | 10 +- {src => s4/src}/dataloaders/lra.py | 2 +- .../src}/dataloaders/prepare/bidmc/README.md | 0 .../src}/dataloaders/prepare/bidmc/data.ipynb | 0 .../dataloaders/prepare/bidmc/data_loader.py | 0 .../dataloaders/prepare/bidmc/process_data.py | 0 {src => s4/src}/dataloaders/synthetic.py | 4 +- {src => s4/src}/dataloaders/ts.py | 4 +- .../dataloaders/utils/cifar_augmentations.py | 0 {src => s4/src}/dataloaders/utils/signal.py | 0 .../src}/dataloaders/utils/timm_mixup.py | 0 .../src}/dataloaders/utils/video_loader.py | 0 .../src}/dataloaders/utils/vocabulary.py | 2 +- {src => s4/src}/dataloaders/vision.py | 6 +- {src => s4/src}/models/README.md | 0 s4/src/models/__init__.py | 0 {src => s4/src}/models/baselines/ckconv.py | 0 .../src}/models/baselines/convnext_timm.py | 6 +- {src => s4/src}/models/baselines/gru.py | 4 +- .../src}/models/baselines/lipschitzrnn.py | 2 +- {src => s4/src}/models/baselines/lstm.py | 4 +- .../src}/models/baselines/nonaka/LICENSE | 0 .../src}/models/baselines/nonaka/README.md | 0 .../models/baselines/nonaka/basic_conv1d.py | 0 .../src}/models/baselines/nonaka/inception.py | 2 +- .../src}/models/baselines/nonaka/resnet.py | 2 +- .../src}/models/baselines/nonaka/xresnet.py | 2 +- {src => s4/src}/models/baselines/nrde.py | 0 {src => s4/src}/models/baselines/odelstm.py | 0 {src => s4/src}/models/baselines/resnet.py | 0 .../src}/models/baselines/resnet_timm.py | 0 {src => s4/src}/models/baselines/samplernn.py | 10 +- .../src}/models/baselines/transformer.py | 0 {src => s4/src}/models/baselines/unicornn.py | 2 +- {src => s4/src}/models/baselines/vit.py | 2 +- {src => s4/src}/models/baselines/vit_all.py | 8 +- {src => s4/src}/models/baselines/wavenet.py | 2 +- {src => s4/src}/models/functional/cauchy.py | 0 {src => s4/src}/models/functional/krylov.py | 2 +- {src => s4/src}/models/functional/toeplitz.py | 0 {src => s4/src}/models/functional/unroll.py | 4 +- .../src}/models/functional/vandermonde.py | 0 {src => s4/src}/models/hippo/hippo.py | 0 {src => s4/src}/models/hippo/transition.py | 8 +- .../src}/models/hippo/visualizations.py | 2 +- {src => s4/src}/models/nn/__init__.py | 0 {src => s4/src}/models/nn/activation.py | 0 {src => s4/src}/models/nn/adaptive_softmax.py | 2 +- {src => s4/src}/models/nn/dropout.py | 0 {src => s4/src}/models/nn/dxt.py | 0 {src => s4/src}/models/nn/exprnn/README.md | 0 {src => s4/src}/models/nn/exprnn/expm32.py | 0 .../src}/models/nn/exprnn/initialization.py | 0 .../src}/models/nn/exprnn/orthogonal.py | 2 +- .../src}/models/nn/exprnn/parametrization.py | 0 .../src}/models/nn/exprnn/trivializations.py | 0 {src => s4/src}/models/nn/gate.py | 0 {src => s4/src}/models/nn/initialization.py | 0 {src => s4/src}/models/nn/linear.py | 2 +- {src => s4/src}/models/nn/normalization.py | 0 {src => s4/src}/models/nn/orthogonal.py | 0 {src => s4/src}/models/nn/residual.py | 0 {src => s4/src}/models/nn/utils.py | 0 {src => s4/src}/models/s4/README.md | 0 {src => s4/src}/models/sequence/README.md | 0 {src => s4/src}/models/sequence/__init__.py | 0 .../src}/models/sequence/attention/linear.py | 4 +- .../src}/models/sequence/attention/mha.py | 4 +- .../models/sequence/attention/performer.py | 0 s4/src/models/sequence/backbones/__init__.py | 0 .../src}/models/sequence/backbones/block.py | 12 +- .../src}/models/sequence/backbones/model.py | 8 +- .../src}/models/sequence/backbones/sashimi.py | 6 +- .../src}/models/sequence/backbones/unet.py | 8 +- {src => s4/src}/models/sequence/base.py | 0 .../src}/models/sequence/convs/conv1d.py | 6 +- .../src}/models/sequence/convs/conv2d.py | 4 +- .../src}/models/sequence/kernels/__init__.py | 0 .../src}/models/sequence/kernels/dplr.py | 4 +- .../src}/models/sequence/kernels/fftconv.py | 6 +- .../src}/models/sequence/kernels/kernel.py | 2 +- .../src}/models/sequence/kernels/ssm.py | 24 ++-- s4/src/models/sequence/modules/__init__.py | 0 .../src}/models/sequence/modules/ffn.py | 4 +- .../src}/models/sequence/modules/lssl.py | 12 +- .../src}/models/sequence/modules/megablock.py | 6 +- .../src}/models/sequence/modules/pool.py | 4 +- .../src}/models/sequence/modules/s4block.py | 12 +- .../src}/models/sequence/modules/s4nd.py | 12 +- .../src}/models/sequence/rnns/__init__.py | 0 .../models/sequence/rnns/cells/__init__.py | 0 .../src}/models/sequence/rnns/cells/basic.py | 8 +- .../src}/models/sequence/rnns/cells/hippo.py | 4 +- .../src}/models/sequence/rnns/cells/memory.py | 6 +- .../models/sequence/rnns/cells/minimalrnn.py | 6 +- .../models/sequence/rnns/cells/timestamp.py | 4 +- {src => s4/src}/models/sequence/rnns/qrnn.py | 10 +- {src => s4/src}/models/sequence/rnns/rnn.py | 6 +- {src => s4/src}/models/sequence/rnns/sru.py | 8 +- {src => s4/src}/tasks/decoders.py | 4 +- {src => s4/src}/tasks/encoders.py | 10 +- {src => s4/src}/tasks/metrics.py | 0 {src => s4/src}/tasks/tasks.py | 10 +- {src => s4/src}/utils/__init__.py | 0 {src => s4/src}/utils/config.py | 0 {src => s4/src}/utils/distributed.py | 0 {src => s4/src}/utils/optim/ema.py | 0 {src => s4/src}/utils/optim/lamb.py | 0 {src => s4/src}/utils/optim/schedulers.py | 0 {src => s4/src}/utils/optim_groups.py | 0 {src => s4/src}/utils/permutations.py | 0 s4/src/utils/registry.py | 106 ++++++++++++++++++ {src => s4/src}/utils/train.py | 2 +- train.py => s4/train.py | 16 +-- src/utils/registry.py | 106 ------------------ 440 files changed, 340 insertions(+), 305 deletions(-) create mode 100644 pyproject.toml rename models/s4/lssl.md => s4/__init__.py (100%) rename {checkpoints => s4/checkpoints}/README.md (100%) rename {checkpoints => s4/checkpoints}/convert_pl_to_pt.py (87%) rename {checkpoints => s4/checkpoints}/convert_v3_to_v4.py (98%) rename {checkpoints => s4/checkpoints}/evaluate.py (97%) rename {configs => s4/configs}/README.md (100%) rename {configs => s4/configs}/callbacks/base.yaml (100%) rename {configs => s4/configs}/callbacks/checkpoint.yaml (100%) rename {configs => s4/configs}/callbacks/progressive_resizing.yaml (100%) rename {configs => s4/configs}/callbacks/rich.yaml (100%) rename {configs => s4/configs}/callbacks/swa.yaml (100%) rename {configs => s4/configs}/callbacks/wandb.yaml (100%) rename {configs => s4/configs}/config.yaml (100%) rename {configs => s4/configs}/dataset/aan.yaml (100%) rename {configs => s4/configs}/dataset/adding.yaml (100%) rename {configs => s4/configs}/dataset/beethoven.yaml (100%) rename {configs => s4/configs}/dataset/bidmc.yaml (100%) rename {configs => s4/configs}/dataset/celeba-all.yaml (100%) rename {configs => s4/configs}/dataset/cifar.yaml (100%) rename {configs => s4/configs}/dataset/copying.yaml (100%) rename {configs => s4/configs}/dataset/delay.yaml (100%) rename {configs => s4/configs}/dataset/ecl.yaml (100%) rename {configs => s4/configs}/dataset/etth.yaml (100%) rename {configs => s4/configs}/dataset/ettm.yaml (100%) rename {configs => s4/configs}/dataset/hmdb51.yaml (100%) rename {configs => s4/configs}/dataset/imagenet.yaml (100%) rename {configs => s4/configs}/dataset/imdb.yaml (100%) rename {configs => s4/configs}/dataset/listops.yaml (100%) rename {configs => s4/configs}/dataset/ljspeech.yaml (100%) rename {configs => s4/configs}/dataset/mnist.yaml (100%) rename {configs => s4/configs}/dataset/music.yaml (100%) rename {configs => s4/configs}/dataset/pathfinder.yaml (100%) rename {configs => s4/configs}/dataset/qautomusic.yaml (100%) rename {configs => s4/configs}/dataset/reconstruct.yaml (100%) rename {configs => s4/configs}/dataset/sc.yaml (100%) rename {configs => s4/configs}/dataset/sc09.yaml (100%) rename {configs => s4/configs}/dataset/sc10.yaml (100%) rename {configs => s4/configs}/dataset/weather.yaml (100%) rename {configs => s4/configs}/dataset/wt103.yaml (100%) rename {configs => s4/configs}/dataset/youtubemix.yaml (100%) rename {configs => s4/configs}/experiment/README.md (100%) rename {configs => s4/configs}/experiment/audio/samplernn-beethoven.yaml (100%) rename {configs => s4/configs}/experiment/audio/samplernn-qautomusic.yaml (100%) rename {configs => s4/configs}/experiment/audio/samplernn-sc09.yaml (100%) rename {configs => s4/configs}/experiment/audio/samplernn-scg.yaml (100%) rename {configs => s4/configs}/experiment/audio/samplernn-youtubemix.yaml (100%) rename {configs => s4/configs}/experiment/audio/sashimi-beethoven.yaml (100%) rename {configs => s4/configs}/experiment/audio/sashimi-sc09-unet.yaml (100%) rename {configs => s4/configs}/experiment/audio/sashimi-sc09.yaml (100%) rename {configs => s4/configs}/experiment/audio/sashimi-standalone.yaml (100%) rename {configs => s4/configs}/experiment/audio/sashimi-youtubemix.yaml (100%) rename {configs => s4/configs}/experiment/audio/wavenet-beethoven.yaml (100%) rename {configs => s4/configs}/experiment/audio/wavenet-qautomusic.yaml (100%) rename {configs => s4/configs}/experiment/audio/wavenet-sc09.yaml (100%) rename {configs => s4/configs}/experiment/audio/wavenet-youtubemix.yaml (100%) rename {configs => s4/configs}/experiment/base.yaml (100%) rename {configs => s4/configs}/experiment/bidmc/ckconv-bidmc.yaml (100%) rename {configs => s4/configs}/experiment/bidmc/resnet-bidmc.yaml (100%) rename {configs => s4/configs}/experiment/bidmc/s4-bidmc-ablation.yaml (100%) rename {configs => s4/configs}/experiment/bidmc/s4-bidmc.yaml (100%) rename {configs => s4/configs}/experiment/cifar/cnn-cifar-2d.yaml (100%) rename {configs => s4/configs}/experiment/cifar/resnet-cifar.yaml (100%) rename {configs => s4/configs}/experiment/cifar/s4-cifar-ablation.yaml (100%) rename {configs => s4/configs}/experiment/cifar/s4-cifar.yaml (100%) rename {configs => s4/configs}/experiment/cifar/s4d-minimal-cifar.yaml (100%) rename {configs => s4/configs}/experiment/forecasting/s4-informer-ecl.yaml (100%) rename {configs => s4/configs}/experiment/forecasting/s4-informer-etth.yaml (100%) rename {configs => s4/configs}/experiment/forecasting/s4-informer-ettm.yaml (100%) rename {configs => s4/configs}/experiment/forecasting/s4-informer-weather.yaml (100%) rename {configs => s4/configs}/experiment/lm/s4-wt103.yaml (100%) rename {configs => s4/configs}/experiment/lm/transformer-wt103.yaml (100%) rename {configs => s4/configs}/experiment/lra/lra-cifar.yaml (100%) rename {configs => s4/configs}/experiment/lra/lra-listops.yaml (100%) rename {configs => s4/configs}/experiment/lra/old/s4-lra-aan.yaml (100%) rename {configs => s4/configs}/experiment/lra/old/s4-lra-cifar.yaml (100%) rename {configs => s4/configs}/experiment/lra/old/s4-lra-imdb.yaml (100%) rename {configs => s4/configs}/experiment/lra/old/s4-lra-listops.yaml (100%) rename {configs => s4/configs}/experiment/lra/old/s4-lra-pathfinder.yaml (100%) rename {configs => s4/configs}/experiment/lra/old/s4-lra-pathx.yaml (100%) rename {configs => s4/configs}/experiment/lra/old/v3-s4-aan.yaml (100%) rename {configs => s4/configs}/experiment/lra/old/v3-s4-cifar.yaml (100%) rename {configs => s4/configs}/experiment/lra/old/v3-s4-imdb-small.yaml (100%) rename {configs => s4/configs}/experiment/lra/old/v3-s4-imdb.yaml (100%) rename {configs => s4/configs}/experiment/lra/old/v3-s4-listops-small.yaml (100%) rename {configs => s4/configs}/experiment/lra/old/v3-s4-listops.yaml (100%) rename {configs => s4/configs}/experiment/lra/old/v3-s4-pathfinder-small.yaml (100%) rename {configs => s4/configs}/experiment/lra/old/v3-s4-pathfinder.yaml (100%) rename {configs => s4/configs}/experiment/lra/old/v3-s4-pathx-small.yaml (100%) rename {configs => s4/configs}/experiment/lra/old/v3-s4-pathx.yaml (100%) rename {configs => s4/configs}/experiment/lra/resnet-pathx.yaml (100%) rename {configs => s4/configs}/experiment/lra/s4-aan.yaml (100%) rename {configs => s4/configs}/experiment/lra/s4-cifar.yaml (100%) rename {configs => s4/configs}/experiment/lra/s4-imdb.yaml (100%) rename {configs => s4/configs}/experiment/lra/s4-listops.yaml (100%) rename {configs => s4/configs}/experiment/lra/s4-pathfinder.yaml (100%) rename {configs => s4/configs}/experiment/lra/s4-pathx.yaml (100%) rename {configs => s4/configs}/experiment/mega/lra-image/README.md (100%) rename {configs => s4/configs}/experiment/mega/lra-image/large-ema-with-s4.yaml (100%) rename {configs => s4/configs}/experiment/mega/lra-image/large-ema.yaml (100%) rename {configs => s4/configs}/experiment/mega/lra-image/large-mega-ema-with-s4.yaml (100%) rename {configs => s4/configs}/experiment/mega/lra-image/large-mega-ema.yaml (100%) rename {configs => s4/configs}/experiment/mega/lra-image/large-mega-s4d-real.yaml (100%) rename {configs => s4/configs}/experiment/mega/lra-image/large-mega-s4d.yaml (100%) rename {configs => s4/configs}/experiment/mega/lra-image/large-s4d-real.yaml (100%) rename {configs => s4/configs}/experiment/mega/lra-image/large-s4d.yaml (100%) rename {configs => s4/configs}/experiment/mega/lra-image/mega_ablations_10000_warmup_all.pdf (100%) rename {configs => s4/configs}/experiment/mega/lra-image/mega_ablations_1000_warmup_all.pdf (100%) rename {configs => s4/configs}/experiment/mega/lra-image/mega_ablations_mega_repo.pdf (100%) rename {configs => s4/configs}/experiment/mega/lra-image/small-ema-with-s4.yaml (100%) rename {configs => s4/configs}/experiment/mega/lra-image/small-ema.yaml (100%) rename {configs => s4/configs}/experiment/mega/lra-image/small-mega-ema-with-s4.yaml (100%) rename {configs => s4/configs}/experiment/mega/lra-image/small-mega-ema.yaml (100%) rename {configs => s4/configs}/experiment/mega/lra-image/small-mega-s4d-real.yaml (100%) rename {configs => s4/configs}/experiment/mega/lra-image/small-mega-s4d.yaml (100%) rename {configs => s4/configs}/experiment/mega/lra-image/small-s4d-real.yaml (100%) rename {configs => s4/configs}/experiment/mega/lra-image/small-s4d.yaml (100%) rename {configs => s4/configs}/experiment/rnn.yaml (100%) rename {configs => s4/configs}/experiment/s4nd/README.md (100%) rename {configs => s4/configs}/experiment/s4nd/celeba/convnext-celeba-all.yaml (100%) rename {configs => s4/configs}/experiment/s4nd/celeba/convnext-s4nd-celeba-all.yaml (100%) rename {configs => s4/configs}/experiment/s4nd/cifar/cnn-cifar-2d.yaml (100%) rename {configs => s4/configs}/experiment/s4nd/cifar/s4-cifar-2d-16x16.yaml (100%) rename {configs => s4/configs}/experiment/s4nd/cifar/s4-cifar-2d.yaml (100%) rename {configs => s4/configs}/experiment/s4nd/convnext/convnext_timm_tiny_imagenet.yaml (100%) rename {configs => s4/configs}/experiment/s4nd/convnext/convnext_timm_tiny_inflate3d_hmdb.yaml (100%) rename {configs => s4/configs}/experiment/s4nd/convnext/convnext_timm_tiny_inflate3d_s4nd_hmdb.yaml (100%) rename {configs => s4/configs}/experiment/s4nd/convnext/convnext_timm_tiny_s4nd_imagenet.yaml (100%) rename {configs => s4/configs}/experiment/s4nd/progres/cnn-cifar-2d.yaml (100%) rename {configs => s4/configs}/experiment/s4nd/progres/s4-cifar-2d.yaml (100%) rename {configs => s4/configs}/experiment/s4nd/vit/vit_b_16_imagenet.yaml (100%) rename {configs => s4/configs}/experiment/s4nd/vit/vit_b_16_s4_imagenet_v2.yaml (100%) rename {configs => s4/configs}/experiment/sc/convnet-sc.yaml (100%) rename {configs => s4/configs}/experiment/sc/resnet-sc.yaml (100%) rename {configs => s4/configs}/experiment/sc/s4-sc-ablation.yaml (100%) rename {configs => s4/configs}/experiment/sc/s4-sc.yaml (100%) rename {configs => s4/configs}/experiment/sc/transformer-sc.yaml (100%) rename {configs => s4/configs}/experiment/synthetic/s4-copying.yaml (100%) rename {configs => s4/configs}/experiment/synthetic/s4-delay.yaml (100%) rename {configs => s4/configs}/experiment/synthetic/s4-reconstruct.yaml (100%) rename {configs => s4/configs}/generate.yaml (100%) rename {configs => s4/configs}/loader/default.yaml (100%) rename {configs => s4/configs}/loader/imresolution.yaml (100%) rename {configs => s4/configs}/loader/lm.yaml (100%) rename {configs => s4/configs}/loader/resolution.yaml (100%) rename {configs => s4/configs}/loader/tbptt.yaml (100%) rename {configs => s4/configs}/model/README.md (100%) rename {configs => s4/configs}/model/base.yaml (100%) rename {configs => s4/configs}/model/baseline/ckconv.yaml (100%) rename {configs => s4/configs}/model/baseline/lipschitzrnn.yaml (100%) rename {configs => s4/configs}/model/baseline/lstm.yaml (100%) rename {configs => s4/configs}/model/baseline/odelstm.yaml (100%) rename {configs => s4/configs}/model/baseline/resnet2d.yaml (100%) rename {configs => s4/configs}/model/baseline/samplernn.yaml (100%) rename {configs => s4/configs}/model/baseline/stackedrnn_baseline.yaml (100%) rename {configs => s4/configs}/model/baseline/unicornn.yaml (100%) rename {configs => s4/configs}/model/baseline/wavenet.yaml (100%) rename {configs => s4/configs}/model/convnet1d.yaml (100%) rename {configs => s4/configs}/model/convnet2d.yaml (100%) rename {configs => s4/configs}/model/layer/cell/exprnn.yaml (100%) rename {configs => s4/configs}/model/layer/cell/goru.yaml (100%) rename {configs => s4/configs}/model/layer/cell/gru.yaml (100%) rename {configs => s4/configs}/model/layer/cell/hippo-glagt.yaml (100%) rename {configs => s4/configs}/model/layer/cell/hippo-lagt.yaml (100%) rename {configs => s4/configs}/model/layer/cell/hippo-legs.yaml (100%) rename {configs => s4/configs}/model/layer/cell/hippo-legt.yaml (100%) rename {configs => s4/configs}/model/layer/cell/hippo-timestamp.yaml (100%) rename {configs => s4/configs}/model/layer/cell/lmu.yaml (100%) rename {configs => s4/configs}/model/layer/cell/rnn.yaml (100%) rename {configs => s4/configs}/model/layer/cell/sru.yaml (100%) rename {configs => s4/configs}/model/layer/conv1d.yaml (100%) rename {configs => s4/configs}/model/layer/conv2d.yaml (100%) rename {configs => s4/configs}/model/layer/ff.yaml (100%) rename {configs => s4/configs}/model/layer/id.yaml (100%) rename {configs => s4/configs}/model/layer/lssl.yaml (100%) rename {configs => s4/configs}/model/layer/lstm.yaml (100%) rename {configs => s4/configs}/model/layer/mega.yaml (100%) rename {configs => s4/configs}/model/layer/mha.yaml (100%) rename {configs => s4/configs}/model/layer/performer.yaml (100%) rename {configs => s4/configs}/model/layer/rnn.yaml (100%) rename {configs => s4/configs}/model/layer/s4.yaml (100%) rename {configs => s4/configs}/model/layer/s4d.yaml (100%) rename {configs => s4/configs}/model/layer/s4d_example.yaml (100%) rename {configs => s4/configs}/model/layer/s4ff.yaml (100%) rename {configs => s4/configs}/model/layer/s4nd.yaml (100%) rename {configs => s4/configs}/model/layer/s4s4ff.yaml (100%) rename {configs => s4/configs}/model/layer/sru.yaml (100%) rename {configs => s4/configs}/model/layer/standalone.yaml (100%) rename {configs => s4/configs}/model/layer/transformer.yaml (100%) rename {configs => s4/configs}/model/layer/vit.yaml (100%) rename {configs => s4/configs}/model/mega.yaml (100%) rename {configs => s4/configs}/model/nonaka/inception.yaml (100%) rename {configs => s4/configs}/model/nonaka/resnet.yaml (100%) rename {configs => s4/configs}/model/nonaka/xresnet.yaml (100%) rename {configs => s4/configs}/model/s4.yaml (100%) rename {configs => s4/configs}/model/sashimi-standalone.yaml (100%) rename {configs => s4/configs}/model/sashimi-transformer.yaml (100%) rename {configs => s4/configs}/model/sashimi.yaml (100%) rename {configs => s4/configs}/model/transformer.yaml (100%) rename {configs => s4/configs}/model/unet.yaml (100%) rename {configs => s4/configs}/model/vit/vit.yaml (100%) rename {configs => s4/configs}/model/vit/vit_b_16.yaml (100%) rename {configs => s4/configs}/model/vit/vit_s_16.yaml (100%) rename {configs => s4/configs}/optimizer/adam.yaml (100%) rename {configs => s4/configs}/optimizer/adamw.yaml (100%) rename {configs => s4/configs}/optimizer/lamb.yaml (100%) rename {configs => s4/configs}/optimizer/sgd.yaml (100%) rename {configs => s4/configs}/pipeline/aan.yaml (100%) rename {configs => s4/configs}/pipeline/adding.yaml (100%) rename {configs => s4/configs}/pipeline/celeba-all-2d.yaml (100%) rename {configs => s4/configs}/pipeline/cifar-2d.yaml (100%) rename {configs => s4/configs}/pipeline/cifar.yaml (100%) rename {configs => s4/configs}/pipeline/copying.yaml (100%) rename {configs => s4/configs}/pipeline/delay.yaml (100%) rename {configs => s4/configs}/pipeline/ema.yaml (100%) rename {configs => s4/configs}/pipeline/hmdb51_convnext.yaml (100%) rename {configs => s4/configs}/pipeline/imagenet.yaml (100%) rename {configs => s4/configs}/pipeline/imdb.yaml (100%) rename {configs => s4/configs}/pipeline/informer.yaml (100%) rename {configs => s4/configs}/pipeline/listops.yaml (100%) rename {configs => s4/configs}/pipeline/mnist.yaml (100%) rename {configs => s4/configs}/pipeline/pathfinder.yaml (100%) rename {configs => s4/configs}/pipeline/pathx.yaml (100%) rename {configs => s4/configs}/pipeline/reconstruct.yaml (100%) rename {configs => s4/configs}/pipeline/sc.yaml (100%) rename {configs => s4/configs}/pipeline/wt103.yaml (100%) rename {configs => s4/configs}/scheduler/constant.yaml (100%) rename {configs => s4/configs}/scheduler/constant_warmup.yaml (100%) rename {configs => s4/configs}/scheduler/cosine.yaml (100%) rename {configs => s4/configs}/scheduler/cosine_warmup.yaml (100%) rename {configs => s4/configs}/scheduler/linear_warmup.yaml (100%) rename {configs => s4/configs}/scheduler/multistep.yaml (100%) rename {configs => s4/configs}/scheduler/plateau.yaml (100%) rename {configs => s4/configs}/scheduler/step.yaml (100%) rename {configs => s4/configs}/scheduler/timm_cosine.yaml (100%) rename {configs => s4/configs}/task/forecasting.yaml (100%) rename {configs => s4/configs}/task/lm.yaml (100%) rename {configs => s4/configs}/task/multiclass_classification.yaml (100%) rename {configs => s4/configs}/task/multilabel_classification.yaml (100%) rename {configs => s4/configs}/task/regression.yaml (100%) rename {configs => s4/configs}/task/video.yaml (100%) rename {configs => s4/configs}/trainer/debug.yaml (100%) rename {configs => s4/configs}/trainer/default.yaml (100%) rename {configs => s4/configs}/trainer/lm.yaml (100%) create mode 100644 s4/extensions/__init__.py rename {extensions => s4/extensions}/kernels/README.md (100%) create mode 100644 s4/extensions/kernels/__init__.py rename {extensions => s4/extensions}/kernels/benchmark_cauchy.py (100%) rename {extensions => s4/extensions}/kernels/benchmark_cauchy_tune.py (100%) rename {extensions => s4/extensions}/kernels/cauchy.cpp (100%) rename {extensions => s4/extensions}/kernels/cauchy.py (100%) rename {extensions => s4/extensions}/kernels/cauchy_cuda.cu (100%) rename {extensions => s4/extensions}/kernels/map.h (100%) rename {extensions => s4/extensions}/kernels/setup.py (100%) rename {extensions => s4/extensions}/kernels/test_cauchy.py (100%) rename {extensions => s4/extensions}/kernels/test_vandermonde.py (96%) rename {extensions => s4/extensions}/kernels/tune_cauchy.py (100%) rename {extensions => s4/extensions}/kernels/tune_cauchy.sh (100%) rename {extensions => s4/extensions}/kernels/tuner.py (100%) rename {extensions => s4/extensions}/kernels/tuning_setup.py (100%) rename {extensions => s4/extensions}/kernels/vandermonde.py (100%) rename generate.py => s4/generate.py (97%) rename {models => s4/models}/README.md (100%) create mode 100644 s4/models/__init__.py rename {models => s4/models}/dss/README.md (100%) create mode 100644 s4/models/dss/__init__.py rename {models => s4/models}/hippo/README.md (100%) create mode 100644 s4/models/hippo/__init__.py rename {models => s4/models}/related/README.md (100%) create mode 100644 s4/models/related/__init__.py rename {models => s4/models}/s4/README.md (100%) create mode 100644 s4/models/s4/__init__.py rename {models => s4/models}/s4/experiments.md (100%) create mode 100644 s4/models/s4/lssl.md rename {models => s4/models}/s4/s4.py (99%) rename {models => s4/models}/s4/s4d.py (98%) rename {models => s4/models}/s4nd/README.md (100%) create mode 100644 s4/models/s4nd/__init__.py rename {models => s4/models}/sashimi/README.md (100%) create mode 100644 s4/models/sashimi/__init__.py rename {models => s4/models}/sashimi/metrics.py (100%) create mode 100644 s4/models/sashimi/mturk/__init__.py rename {models => s4/models}/sashimi/mturk/mos/MTurk SC09 MOS.ipynb (100%) rename {models => s4/models}/sashimi/mturk/mos/MTurk YouTubeMix MOS.ipynb (100%) create mode 100644 s4/models/sashimi/mturk/mos/__init__.py rename {models => s4/models}/sashimi/mturk/prepare_sc09.py (100%) rename {models => s4/models}/sashimi/mturk/template_music.py (100%) rename {models => s4/models}/sashimi/mturk/template_speech.py (100%) rename {models => s4/models}/sashimi/mturk/turk_create_batch.py (100%) rename {models => s4/models}/sashimi/sashimi.py (99%) rename {models => s4/models}/sashimi/sc09_classifier/README.md (100%) create mode 100644 s4/models/sashimi/sc09_classifier/__init__.py create mode 100644 s4/models/sashimi/sc09_classifier/datasets/__init__.py create mode 100644 s4/models/sashimi/sc09_classifier/datasets/speech_commands/__init__.py rename {models => s4/models}/sashimi/sc09_classifier/datasets/speech_commands/split_dataset.py (100%) rename {models => s4/models}/sashimi/sc09_classifier/download_speech_commands_dataset.sh (100%) create mode 100644 s4/models/sashimi/sc09_classifier/models/__init__.py rename {models => s4/models}/sashimi/sc09_classifier/models/resnext.py (100%) rename {models => s4/models}/sashimi/sc09_classifier/requirements.txt (100%) rename {models => s4/models}/sashimi/sc09_classifier/speech_commands_dataset.py (100%) rename {models => s4/models}/sashimi/sc09_classifier/test_speech_commands.py (100%) rename {models => s4/models}/sashimi/sc09_classifier/train_speech_commands.py (99%) rename {models => s4/models}/sashimi/sc09_classifier/transforms/__init__.py (100%) rename {models => s4/models}/sashimi/sc09_classifier/transforms/transforms_stft.py (100%) rename {models => s4/models}/sashimi/sc09_classifier/transforms/transforms_wav.py (100%) create mode 100644 s4/src/__init__.py rename {src => s4/src}/callbacks/norms.py (100%) rename {src => s4/src}/callbacks/params.py (100%) rename {src => s4/src}/callbacks/progressive_resizing.py (98%) rename {src => s4/src}/callbacks/timer.py (100%) rename {src => s4/src}/callbacks/wandb.py (100%) rename {src => s4/src}/dataloaders/README.md (100%) rename {src => s4/src}/dataloaders/__init__.py (100%) rename {src => s4/src}/dataloaders/audio.py (98%) rename {src => s4/src}/dataloaders/base.py (99%) rename {src => s4/src}/dataloaders/basic.py (97%) rename {src => s4/src}/dataloaders/datasets/adding.py (100%) rename {src => s4/src}/dataloaders/datasets/celeba.py (100%) rename {src => s4/src}/dataloaders/datasets/copying.py (99%) rename {src => s4/src}/dataloaders/datasets/delay.py (96%) rename {src => s4/src}/dataloaders/datasets/music.py (100%) rename {src => s4/src}/dataloaders/datasets/reconstruct.py (95%) rename {src => s4/src}/dataloaders/datasets/sc.py (100%) rename {src => s4/src}/dataloaders/et.py (99%) rename {src => s4/src}/dataloaders/lm.py (98%) rename {src => s4/src}/dataloaders/lra.py (99%) rename {src => s4/src}/dataloaders/prepare/bidmc/README.md (100%) rename {src => s4/src}/dataloaders/prepare/bidmc/data.ipynb (100%) rename {src => s4/src}/dataloaders/prepare/bidmc/data_loader.py (100%) rename {src => s4/src}/dataloaders/prepare/bidmc/process_data.py (100%) rename {src => s4/src}/dataloaders/synthetic.py (98%) rename {src => s4/src}/dataloaders/ts.py (99%) rename {src => s4/src}/dataloaders/utils/cifar_augmentations.py (100%) rename {src => s4/src}/dataloaders/utils/signal.py (100%) rename {src => s4/src}/dataloaders/utils/timm_mixup.py (100%) rename {src => s4/src}/dataloaders/utils/video_loader.py (100%) rename {src => s4/src}/dataloaders/utils/vocabulary.py (99%) rename {src => s4/src}/dataloaders/vision.py (99%) rename {src => s4/src}/models/README.md (100%) create mode 100644 s4/src/models/__init__.py rename {src => s4/src}/models/baselines/ckconv.py (100%) rename {src => s4/src}/models/baselines/convnext_timm.py (99%) rename {src => s4/src}/models/baselines/gru.py (94%) rename {src => s4/src}/models/baselines/lipschitzrnn.py (99%) rename {src => s4/src}/models/baselines/lstm.py (96%) rename {src => s4/src}/models/baselines/nonaka/LICENSE (100%) rename {src => s4/src}/models/baselines/nonaka/README.md (100%) rename {src => s4/src}/models/baselines/nonaka/basic_conv1d.py (100%) rename {src => s4/src}/models/baselines/nonaka/inception.py (97%) rename {src => s4/src}/models/baselines/nonaka/resnet.py (99%) rename {src => s4/src}/models/baselines/nonaka/xresnet.py (99%) rename {src => s4/src}/models/baselines/nrde.py (100%) rename {src => s4/src}/models/baselines/odelstm.py (100%) rename {src => s4/src}/models/baselines/resnet.py (100%) rename {src => s4/src}/models/baselines/resnet_timm.py (100%) rename {src => s4/src}/models/baselines/samplernn.py (98%) rename {src => s4/src}/models/baselines/transformer.py (100%) rename {src => s4/src}/models/baselines/unicornn.py (99%) rename {src => s4/src}/models/baselines/vit.py (98%) rename {src => s4/src}/models/baselines/vit_all.py (98%) rename {src => s4/src}/models/baselines/wavenet.py (99%) rename {src => s4/src}/models/functional/cauchy.py (100%) rename {src => s4/src}/models/functional/krylov.py (98%) rename {src => s4/src}/models/functional/toeplitz.py (100%) rename {src => s4/src}/models/functional/unroll.py (98%) rename {src => s4/src}/models/functional/vandermonde.py (100%) rename {src => s4/src}/models/hippo/hippo.py (100%) rename {src => s4/src}/models/hippo/transition.py (98%) rename {src => s4/src}/models/hippo/visualizations.py (99%) rename {src => s4/src}/models/nn/__init__.py (100%) rename {src => s4/src}/models/nn/activation.py (100%) rename {src => s4/src}/models/nn/adaptive_softmax.py (99%) rename {src => s4/src}/models/nn/dropout.py (100%) rename {src => s4/src}/models/nn/dxt.py (100%) rename {src => s4/src}/models/nn/exprnn/README.md (100%) rename {src => s4/src}/models/nn/exprnn/expm32.py (100%) rename {src => s4/src}/models/nn/exprnn/initialization.py (100%) rename {src => s4/src}/models/nn/exprnn/orthogonal.py (98%) rename {src => s4/src}/models/nn/exprnn/parametrization.py (100%) rename {src => s4/src}/models/nn/exprnn/trivializations.py (100%) rename {src => s4/src}/models/nn/gate.py (100%) rename {src => s4/src}/models/nn/initialization.py (100%) rename {src => s4/src}/models/nn/linear.py (98%) rename {src => s4/src}/models/nn/normalization.py (100%) rename {src => s4/src}/models/nn/orthogonal.py (100%) rename {src => s4/src}/models/nn/residual.py (100%) rename {src => s4/src}/models/nn/utils.py (100%) rename {src => s4/src}/models/s4/README.md (100%) rename {src => s4/src}/models/sequence/README.md (100%) rename {src => s4/src}/models/sequence/__init__.py (100%) rename {src => s4/src}/models/sequence/attention/linear.py (98%) rename {src => s4/src}/models/sequence/attention/mha.py (97%) rename {src => s4/src}/models/sequence/attention/performer.py (100%) create mode 100644 s4/src/models/sequence/backbones/__init__.py rename {src => s4/src}/models/sequence/backbones/block.py (93%) rename {src => s4/src}/models/sequence/backbones/model.py (96%) rename {src => s4/src}/models/sequence/backbones/sashimi.py (97%) rename {src => s4/src}/models/sequence/backbones/unet.py (96%) rename {src => s4/src}/models/sequence/base.py (100%) rename {src => s4/src}/models/sequence/convs/conv1d.py (90%) rename {src => s4/src}/models/sequence/convs/conv2d.py (94%) rename {src => s4/src}/models/sequence/kernels/__init__.py (100%) rename {src => s4/src}/models/sequence/kernels/dplr.py (98%) rename {src => s4/src}/models/sequence/kernels/fftconv.py (97%) rename {src => s4/src}/models/sequence/kernels/kernel.py (99%) rename {src => s4/src}/models/sequence/kernels/ssm.py (98%) create mode 100644 s4/src/models/sequence/modules/__init__.py rename {src => s4/src}/models/sequence/modules/ffn.py (93%) rename {src => s4/src}/models/sequence/modules/lssl.py (97%) rename {src => s4/src}/models/sequence/modules/megablock.py (99%) rename {src => s4/src}/models/sequence/modules/pool.py (99%) rename {src => s4/src}/models/sequence/modules/s4block.py (96%) rename {src => s4/src}/models/sequence/modules/s4nd.py (97%) rename {src => s4/src}/models/sequence/rnns/__init__.py (100%) rename {src => s4/src}/models/sequence/rnns/cells/__init__.py (100%) rename {src => s4/src}/models/sequence/rnns/cells/basic.py (96%) rename {src => s4/src}/models/sequence/rnns/cells/hippo.py (96%) rename {src => s4/src}/models/sequence/rnns/cells/memory.py (98%) rename {src => s4/src}/models/sequence/rnns/cells/minimalrnn.py (92%) rename {src => s4/src}/models/sequence/rnns/cells/timestamp.py (96%) rename {src => s4/src}/models/sequence/rnns/qrnn.py (95%) rename {src => s4/src}/models/sequence/rnns/rnn.py (97%) rename {src => s4/src}/models/sequence/rnns/sru.py (95%) rename {src => s4/src}/tasks/decoders.py (99%) rename {src => s4/src}/tasks/encoders.py (98%) rename {src => s4/src}/tasks/metrics.py (100%) rename {src => s4/src}/tasks/tasks.py (98%) rename {src => s4/src}/utils/__init__.py (100%) rename {src => s4/src}/utils/config.py (100%) rename {src => s4/src}/utils/distributed.py (100%) rename {src => s4/src}/utils/optim/ema.py (100%) rename {src => s4/src}/utils/optim/lamb.py (100%) rename {src => s4/src}/utils/optim/schedulers.py (100%) rename {src => s4/src}/utils/optim_groups.py (100%) rename {src => s4/src}/utils/permutations.py (100%) create mode 100644 s4/src/utils/registry.py rename {src => s4/src}/utils/train.py (99%) rename train.py => s4/train.py (98%) delete mode 100644 src/utils/registry.py diff --git a/example.py b/example.py index b4d5e2d6..57964048 100644 --- a/example.py +++ b/example.py @@ -32,8 +32,8 @@ import os import argparse -from models.s4.s4 import S4Block as S4 # Can use full version instead of minimal S4D standalone below -from models.s4.s4d import S4D +from s4.models.s4.s4 import S4Block as S4 # Can use full version instead of minimal S4D standalone below +from s4.models.s4.s4d import S4D from tqdm.auto import tqdm # Dropout broke in PyTorch 1.11 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..2c96dec7 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,35 @@ +[build-system] +requires = ["setuptools>=45", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "s4" +version = "0.1.0" +description = "Structured State Space Models" +dependencies = [ + "numpy", + "scipy", + "pandas", + "scikit-learn", + "matplotlib", + "tqdm", + "rich", + "torchtext", + "lit", + "pytorch-lightning==2.0.4", + "hydra-core", + "omegaconf", + "wandb", + "einops", + "cmake", + "transformers", + "datasets", + "sktime", + "numba", + "gluonts", + "timm==0.5.4", +] + +[tool.setuptools.packages.find] +where = ["."] +include = ["s4*"] \ No newline at end of file diff --git a/models/s4/lssl.md b/s4/__init__.py similarity index 100% rename from models/s4/lssl.md rename to s4/__init__.py diff --git a/checkpoints/README.md b/s4/checkpoints/README.md similarity index 100% rename from checkpoints/README.md rename to s4/checkpoints/README.md diff --git a/checkpoints/convert_pl_to_pt.py b/s4/checkpoints/convert_pl_to_pt.py similarity index 87% rename from checkpoints/convert_pl_to_pt.py rename to s4/checkpoints/convert_pl_to_pt.py index b609260c..374d1dff 100644 --- a/checkpoints/convert_pl_to_pt.py +++ b/s4/checkpoints/convert_pl_to_pt.py @@ -2,7 +2,7 @@ import torch from pathlib import Path -from train import SequenceLightningModule +from s4.train import SequenceLightningModule parser = argparse.ArgumentParser() diff --git a/checkpoints/convert_v3_to_v4.py b/s4/checkpoints/convert_v3_to_v4.py similarity index 98% rename from checkpoints/convert_v3_to_v4.py rename to s4/checkpoints/convert_v3_to_v4.py index 15369bc9..6da2c27d 100644 --- a/checkpoints/convert_v3_to_v4.py +++ b/s4/checkpoints/convert_v3_to_v4.py @@ -10,10 +10,10 @@ from torch.nn.modules import module import torch.nn.functional as F from torch.distributions import Categorical -from src import utils +from s4.src import utils from einops import rearrange, repeat, reduce -from train import SequenceLightningModule +from s4.train import SequenceLightningModule from omegaconf import OmegaConf diff --git a/checkpoints/evaluate.py b/s4/checkpoints/evaluate.py similarity index 97% rename from checkpoints/evaluate.py rename to s4/checkpoints/evaluate.py index 74181544..4a4c883f 100644 --- a/checkpoints/evaluate.py +++ b/s4/checkpoints/evaluate.py @@ -8,10 +8,10 @@ from torch.nn.modules import module import torch.nn.functional as F from torch.distributions import Categorical -from src import utils +from s4.src import utils from einops import rearrange, repeat, reduce -from train import SequenceLightningModule +from s4.train import SequenceLightningModule from omegaconf import OmegaConf @hydra.main(config_path="../configs", config_name="generate.yaml") diff --git a/configs/README.md b/s4/configs/README.md similarity index 100% rename from configs/README.md rename to s4/configs/README.md diff --git a/configs/callbacks/base.yaml b/s4/configs/callbacks/base.yaml similarity index 100% rename from configs/callbacks/base.yaml rename to s4/configs/callbacks/base.yaml diff --git a/configs/callbacks/checkpoint.yaml b/s4/configs/callbacks/checkpoint.yaml similarity index 100% rename from configs/callbacks/checkpoint.yaml rename to s4/configs/callbacks/checkpoint.yaml diff --git a/configs/callbacks/progressive_resizing.yaml b/s4/configs/callbacks/progressive_resizing.yaml similarity index 100% rename from configs/callbacks/progressive_resizing.yaml rename to s4/configs/callbacks/progressive_resizing.yaml diff --git a/configs/callbacks/rich.yaml b/s4/configs/callbacks/rich.yaml similarity index 100% rename from configs/callbacks/rich.yaml rename to s4/configs/callbacks/rich.yaml diff --git a/configs/callbacks/swa.yaml b/s4/configs/callbacks/swa.yaml similarity index 100% rename from configs/callbacks/swa.yaml rename to s4/configs/callbacks/swa.yaml diff --git a/configs/callbacks/wandb.yaml b/s4/configs/callbacks/wandb.yaml similarity index 100% rename from configs/callbacks/wandb.yaml rename to s4/configs/callbacks/wandb.yaml diff --git a/configs/config.yaml b/s4/configs/config.yaml similarity index 100% rename from configs/config.yaml rename to s4/configs/config.yaml diff --git a/configs/dataset/aan.yaml b/s4/configs/dataset/aan.yaml similarity index 100% rename from configs/dataset/aan.yaml rename to s4/configs/dataset/aan.yaml diff --git a/configs/dataset/adding.yaml b/s4/configs/dataset/adding.yaml similarity index 100% rename from configs/dataset/adding.yaml rename to s4/configs/dataset/adding.yaml diff --git a/configs/dataset/beethoven.yaml b/s4/configs/dataset/beethoven.yaml similarity index 100% rename from configs/dataset/beethoven.yaml rename to s4/configs/dataset/beethoven.yaml diff --git a/configs/dataset/bidmc.yaml b/s4/configs/dataset/bidmc.yaml similarity index 100% rename from configs/dataset/bidmc.yaml rename to s4/configs/dataset/bidmc.yaml diff --git a/configs/dataset/celeba-all.yaml b/s4/configs/dataset/celeba-all.yaml similarity index 100% rename from configs/dataset/celeba-all.yaml rename to s4/configs/dataset/celeba-all.yaml diff --git a/configs/dataset/cifar.yaml b/s4/configs/dataset/cifar.yaml similarity index 100% rename from configs/dataset/cifar.yaml rename to s4/configs/dataset/cifar.yaml diff --git a/configs/dataset/copying.yaml b/s4/configs/dataset/copying.yaml similarity index 100% rename from configs/dataset/copying.yaml rename to s4/configs/dataset/copying.yaml diff --git a/configs/dataset/delay.yaml b/s4/configs/dataset/delay.yaml similarity index 100% rename from configs/dataset/delay.yaml rename to s4/configs/dataset/delay.yaml diff --git a/configs/dataset/ecl.yaml b/s4/configs/dataset/ecl.yaml similarity index 100% rename from configs/dataset/ecl.yaml rename to s4/configs/dataset/ecl.yaml diff --git a/configs/dataset/etth.yaml b/s4/configs/dataset/etth.yaml similarity index 100% rename from configs/dataset/etth.yaml rename to s4/configs/dataset/etth.yaml diff --git a/configs/dataset/ettm.yaml b/s4/configs/dataset/ettm.yaml similarity index 100% rename from configs/dataset/ettm.yaml rename to s4/configs/dataset/ettm.yaml diff --git a/configs/dataset/hmdb51.yaml b/s4/configs/dataset/hmdb51.yaml similarity index 100% rename from configs/dataset/hmdb51.yaml rename to s4/configs/dataset/hmdb51.yaml diff --git a/configs/dataset/imagenet.yaml b/s4/configs/dataset/imagenet.yaml similarity index 100% rename from configs/dataset/imagenet.yaml rename to s4/configs/dataset/imagenet.yaml diff --git a/configs/dataset/imdb.yaml b/s4/configs/dataset/imdb.yaml similarity index 100% rename from configs/dataset/imdb.yaml rename to s4/configs/dataset/imdb.yaml diff --git a/configs/dataset/listops.yaml b/s4/configs/dataset/listops.yaml similarity index 100% rename from configs/dataset/listops.yaml rename to s4/configs/dataset/listops.yaml diff --git a/configs/dataset/ljspeech.yaml b/s4/configs/dataset/ljspeech.yaml similarity index 100% rename from configs/dataset/ljspeech.yaml rename to s4/configs/dataset/ljspeech.yaml diff --git a/configs/dataset/mnist.yaml b/s4/configs/dataset/mnist.yaml similarity index 100% rename from configs/dataset/mnist.yaml rename to s4/configs/dataset/mnist.yaml diff --git a/configs/dataset/music.yaml b/s4/configs/dataset/music.yaml similarity index 100% rename from configs/dataset/music.yaml rename to s4/configs/dataset/music.yaml diff --git a/configs/dataset/pathfinder.yaml b/s4/configs/dataset/pathfinder.yaml similarity index 100% rename from configs/dataset/pathfinder.yaml rename to s4/configs/dataset/pathfinder.yaml diff --git a/configs/dataset/qautomusic.yaml b/s4/configs/dataset/qautomusic.yaml similarity index 100% rename from configs/dataset/qautomusic.yaml rename to s4/configs/dataset/qautomusic.yaml diff --git a/configs/dataset/reconstruct.yaml b/s4/configs/dataset/reconstruct.yaml similarity index 100% rename from configs/dataset/reconstruct.yaml rename to s4/configs/dataset/reconstruct.yaml diff --git a/configs/dataset/sc.yaml b/s4/configs/dataset/sc.yaml similarity index 100% rename from configs/dataset/sc.yaml rename to s4/configs/dataset/sc.yaml diff --git a/configs/dataset/sc09.yaml b/s4/configs/dataset/sc09.yaml similarity index 100% rename from configs/dataset/sc09.yaml rename to s4/configs/dataset/sc09.yaml diff --git a/configs/dataset/sc10.yaml b/s4/configs/dataset/sc10.yaml similarity index 100% rename from configs/dataset/sc10.yaml rename to s4/configs/dataset/sc10.yaml diff --git a/configs/dataset/weather.yaml b/s4/configs/dataset/weather.yaml similarity index 100% rename from configs/dataset/weather.yaml rename to s4/configs/dataset/weather.yaml diff --git a/configs/dataset/wt103.yaml b/s4/configs/dataset/wt103.yaml similarity index 100% rename from configs/dataset/wt103.yaml rename to s4/configs/dataset/wt103.yaml diff --git a/configs/dataset/youtubemix.yaml b/s4/configs/dataset/youtubemix.yaml similarity index 100% rename from configs/dataset/youtubemix.yaml rename to s4/configs/dataset/youtubemix.yaml diff --git a/configs/experiment/README.md b/s4/configs/experiment/README.md similarity index 100% rename from configs/experiment/README.md rename to s4/configs/experiment/README.md diff --git a/configs/experiment/audio/samplernn-beethoven.yaml b/s4/configs/experiment/audio/samplernn-beethoven.yaml similarity index 100% rename from configs/experiment/audio/samplernn-beethoven.yaml rename to s4/configs/experiment/audio/samplernn-beethoven.yaml diff --git a/configs/experiment/audio/samplernn-qautomusic.yaml b/s4/configs/experiment/audio/samplernn-qautomusic.yaml similarity index 100% rename from configs/experiment/audio/samplernn-qautomusic.yaml rename to s4/configs/experiment/audio/samplernn-qautomusic.yaml diff --git a/configs/experiment/audio/samplernn-sc09.yaml b/s4/configs/experiment/audio/samplernn-sc09.yaml similarity index 100% rename from configs/experiment/audio/samplernn-sc09.yaml rename to s4/configs/experiment/audio/samplernn-sc09.yaml diff --git a/configs/experiment/audio/samplernn-scg.yaml b/s4/configs/experiment/audio/samplernn-scg.yaml similarity index 100% rename from configs/experiment/audio/samplernn-scg.yaml rename to s4/configs/experiment/audio/samplernn-scg.yaml diff --git a/configs/experiment/audio/samplernn-youtubemix.yaml b/s4/configs/experiment/audio/samplernn-youtubemix.yaml similarity index 100% rename from configs/experiment/audio/samplernn-youtubemix.yaml rename to s4/configs/experiment/audio/samplernn-youtubemix.yaml diff --git a/configs/experiment/audio/sashimi-beethoven.yaml b/s4/configs/experiment/audio/sashimi-beethoven.yaml similarity index 100% rename from configs/experiment/audio/sashimi-beethoven.yaml rename to s4/configs/experiment/audio/sashimi-beethoven.yaml diff --git a/configs/experiment/audio/sashimi-sc09-unet.yaml b/s4/configs/experiment/audio/sashimi-sc09-unet.yaml similarity index 100% rename from configs/experiment/audio/sashimi-sc09-unet.yaml rename to s4/configs/experiment/audio/sashimi-sc09-unet.yaml diff --git a/configs/experiment/audio/sashimi-sc09.yaml b/s4/configs/experiment/audio/sashimi-sc09.yaml similarity index 100% rename from configs/experiment/audio/sashimi-sc09.yaml rename to s4/configs/experiment/audio/sashimi-sc09.yaml diff --git a/configs/experiment/audio/sashimi-standalone.yaml b/s4/configs/experiment/audio/sashimi-standalone.yaml similarity index 100% rename from configs/experiment/audio/sashimi-standalone.yaml rename to s4/configs/experiment/audio/sashimi-standalone.yaml diff --git a/configs/experiment/audio/sashimi-youtubemix.yaml b/s4/configs/experiment/audio/sashimi-youtubemix.yaml similarity index 100% rename from configs/experiment/audio/sashimi-youtubemix.yaml rename to s4/configs/experiment/audio/sashimi-youtubemix.yaml diff --git a/configs/experiment/audio/wavenet-beethoven.yaml b/s4/configs/experiment/audio/wavenet-beethoven.yaml similarity index 100% rename from configs/experiment/audio/wavenet-beethoven.yaml rename to s4/configs/experiment/audio/wavenet-beethoven.yaml diff --git a/configs/experiment/audio/wavenet-qautomusic.yaml b/s4/configs/experiment/audio/wavenet-qautomusic.yaml similarity index 100% rename from configs/experiment/audio/wavenet-qautomusic.yaml rename to s4/configs/experiment/audio/wavenet-qautomusic.yaml diff --git a/configs/experiment/audio/wavenet-sc09.yaml b/s4/configs/experiment/audio/wavenet-sc09.yaml similarity index 100% rename from configs/experiment/audio/wavenet-sc09.yaml rename to s4/configs/experiment/audio/wavenet-sc09.yaml diff --git a/configs/experiment/audio/wavenet-youtubemix.yaml b/s4/configs/experiment/audio/wavenet-youtubemix.yaml similarity index 100% rename from configs/experiment/audio/wavenet-youtubemix.yaml rename to s4/configs/experiment/audio/wavenet-youtubemix.yaml diff --git a/configs/experiment/base.yaml b/s4/configs/experiment/base.yaml similarity index 100% rename from configs/experiment/base.yaml rename to s4/configs/experiment/base.yaml diff --git a/configs/experiment/bidmc/ckconv-bidmc.yaml b/s4/configs/experiment/bidmc/ckconv-bidmc.yaml similarity index 100% rename from configs/experiment/bidmc/ckconv-bidmc.yaml rename to s4/configs/experiment/bidmc/ckconv-bidmc.yaml diff --git a/configs/experiment/bidmc/resnet-bidmc.yaml b/s4/configs/experiment/bidmc/resnet-bidmc.yaml similarity index 100% rename from configs/experiment/bidmc/resnet-bidmc.yaml rename to s4/configs/experiment/bidmc/resnet-bidmc.yaml diff --git a/configs/experiment/bidmc/s4-bidmc-ablation.yaml b/s4/configs/experiment/bidmc/s4-bidmc-ablation.yaml similarity index 100% rename from configs/experiment/bidmc/s4-bidmc-ablation.yaml rename to s4/configs/experiment/bidmc/s4-bidmc-ablation.yaml diff --git a/configs/experiment/bidmc/s4-bidmc.yaml b/s4/configs/experiment/bidmc/s4-bidmc.yaml similarity index 100% rename from configs/experiment/bidmc/s4-bidmc.yaml rename to s4/configs/experiment/bidmc/s4-bidmc.yaml diff --git a/configs/experiment/cifar/cnn-cifar-2d.yaml b/s4/configs/experiment/cifar/cnn-cifar-2d.yaml similarity index 100% rename from configs/experiment/cifar/cnn-cifar-2d.yaml rename to s4/configs/experiment/cifar/cnn-cifar-2d.yaml diff --git a/configs/experiment/cifar/resnet-cifar.yaml b/s4/configs/experiment/cifar/resnet-cifar.yaml similarity index 100% rename from configs/experiment/cifar/resnet-cifar.yaml rename to s4/configs/experiment/cifar/resnet-cifar.yaml diff --git a/configs/experiment/cifar/s4-cifar-ablation.yaml b/s4/configs/experiment/cifar/s4-cifar-ablation.yaml similarity index 100% rename from configs/experiment/cifar/s4-cifar-ablation.yaml rename to s4/configs/experiment/cifar/s4-cifar-ablation.yaml diff --git a/configs/experiment/cifar/s4-cifar.yaml b/s4/configs/experiment/cifar/s4-cifar.yaml similarity index 100% rename from configs/experiment/cifar/s4-cifar.yaml rename to s4/configs/experiment/cifar/s4-cifar.yaml diff --git a/configs/experiment/cifar/s4d-minimal-cifar.yaml b/s4/configs/experiment/cifar/s4d-minimal-cifar.yaml similarity index 100% rename from configs/experiment/cifar/s4d-minimal-cifar.yaml rename to s4/configs/experiment/cifar/s4d-minimal-cifar.yaml diff --git a/configs/experiment/forecasting/s4-informer-ecl.yaml b/s4/configs/experiment/forecasting/s4-informer-ecl.yaml similarity index 100% rename from configs/experiment/forecasting/s4-informer-ecl.yaml rename to s4/configs/experiment/forecasting/s4-informer-ecl.yaml diff --git a/configs/experiment/forecasting/s4-informer-etth.yaml b/s4/configs/experiment/forecasting/s4-informer-etth.yaml similarity index 100% rename from configs/experiment/forecasting/s4-informer-etth.yaml rename to s4/configs/experiment/forecasting/s4-informer-etth.yaml diff --git a/configs/experiment/forecasting/s4-informer-ettm.yaml b/s4/configs/experiment/forecasting/s4-informer-ettm.yaml similarity index 100% rename from configs/experiment/forecasting/s4-informer-ettm.yaml rename to s4/configs/experiment/forecasting/s4-informer-ettm.yaml diff --git a/configs/experiment/forecasting/s4-informer-weather.yaml b/s4/configs/experiment/forecasting/s4-informer-weather.yaml similarity index 100% rename from configs/experiment/forecasting/s4-informer-weather.yaml rename to s4/configs/experiment/forecasting/s4-informer-weather.yaml diff --git a/configs/experiment/lm/s4-wt103.yaml b/s4/configs/experiment/lm/s4-wt103.yaml similarity index 100% rename from configs/experiment/lm/s4-wt103.yaml rename to s4/configs/experiment/lm/s4-wt103.yaml diff --git a/configs/experiment/lm/transformer-wt103.yaml b/s4/configs/experiment/lm/transformer-wt103.yaml similarity index 100% rename from configs/experiment/lm/transformer-wt103.yaml rename to s4/configs/experiment/lm/transformer-wt103.yaml diff --git a/configs/experiment/lra/lra-cifar.yaml b/s4/configs/experiment/lra/lra-cifar.yaml similarity index 100% rename from configs/experiment/lra/lra-cifar.yaml rename to s4/configs/experiment/lra/lra-cifar.yaml diff --git a/configs/experiment/lra/lra-listops.yaml b/s4/configs/experiment/lra/lra-listops.yaml similarity index 100% rename from configs/experiment/lra/lra-listops.yaml rename to s4/configs/experiment/lra/lra-listops.yaml diff --git a/configs/experiment/lra/old/s4-lra-aan.yaml b/s4/configs/experiment/lra/old/s4-lra-aan.yaml similarity index 100% rename from configs/experiment/lra/old/s4-lra-aan.yaml rename to s4/configs/experiment/lra/old/s4-lra-aan.yaml diff --git a/configs/experiment/lra/old/s4-lra-cifar.yaml b/s4/configs/experiment/lra/old/s4-lra-cifar.yaml similarity index 100% rename from configs/experiment/lra/old/s4-lra-cifar.yaml rename to s4/configs/experiment/lra/old/s4-lra-cifar.yaml diff --git a/configs/experiment/lra/old/s4-lra-imdb.yaml b/s4/configs/experiment/lra/old/s4-lra-imdb.yaml similarity index 100% rename from configs/experiment/lra/old/s4-lra-imdb.yaml rename to s4/configs/experiment/lra/old/s4-lra-imdb.yaml diff --git a/configs/experiment/lra/old/s4-lra-listops.yaml b/s4/configs/experiment/lra/old/s4-lra-listops.yaml similarity index 100% rename from configs/experiment/lra/old/s4-lra-listops.yaml rename to s4/configs/experiment/lra/old/s4-lra-listops.yaml diff --git a/configs/experiment/lra/old/s4-lra-pathfinder.yaml b/s4/configs/experiment/lra/old/s4-lra-pathfinder.yaml similarity index 100% rename from configs/experiment/lra/old/s4-lra-pathfinder.yaml rename to s4/configs/experiment/lra/old/s4-lra-pathfinder.yaml diff --git a/configs/experiment/lra/old/s4-lra-pathx.yaml b/s4/configs/experiment/lra/old/s4-lra-pathx.yaml similarity index 100% rename from configs/experiment/lra/old/s4-lra-pathx.yaml rename to s4/configs/experiment/lra/old/s4-lra-pathx.yaml diff --git a/configs/experiment/lra/old/v3-s4-aan.yaml b/s4/configs/experiment/lra/old/v3-s4-aan.yaml similarity index 100% rename from configs/experiment/lra/old/v3-s4-aan.yaml rename to s4/configs/experiment/lra/old/v3-s4-aan.yaml diff --git a/configs/experiment/lra/old/v3-s4-cifar.yaml b/s4/configs/experiment/lra/old/v3-s4-cifar.yaml similarity index 100% rename from configs/experiment/lra/old/v3-s4-cifar.yaml rename to s4/configs/experiment/lra/old/v3-s4-cifar.yaml diff --git a/configs/experiment/lra/old/v3-s4-imdb-small.yaml b/s4/configs/experiment/lra/old/v3-s4-imdb-small.yaml similarity index 100% rename from configs/experiment/lra/old/v3-s4-imdb-small.yaml rename to s4/configs/experiment/lra/old/v3-s4-imdb-small.yaml diff --git a/configs/experiment/lra/old/v3-s4-imdb.yaml b/s4/configs/experiment/lra/old/v3-s4-imdb.yaml similarity index 100% rename from configs/experiment/lra/old/v3-s4-imdb.yaml rename to s4/configs/experiment/lra/old/v3-s4-imdb.yaml diff --git a/configs/experiment/lra/old/v3-s4-listops-small.yaml b/s4/configs/experiment/lra/old/v3-s4-listops-small.yaml similarity index 100% rename from configs/experiment/lra/old/v3-s4-listops-small.yaml rename to s4/configs/experiment/lra/old/v3-s4-listops-small.yaml diff --git a/configs/experiment/lra/old/v3-s4-listops.yaml b/s4/configs/experiment/lra/old/v3-s4-listops.yaml similarity index 100% rename from configs/experiment/lra/old/v3-s4-listops.yaml rename to s4/configs/experiment/lra/old/v3-s4-listops.yaml diff --git a/configs/experiment/lra/old/v3-s4-pathfinder-small.yaml b/s4/configs/experiment/lra/old/v3-s4-pathfinder-small.yaml similarity index 100% rename from configs/experiment/lra/old/v3-s4-pathfinder-small.yaml rename to s4/configs/experiment/lra/old/v3-s4-pathfinder-small.yaml diff --git a/configs/experiment/lra/old/v3-s4-pathfinder.yaml b/s4/configs/experiment/lra/old/v3-s4-pathfinder.yaml similarity index 100% rename from configs/experiment/lra/old/v3-s4-pathfinder.yaml rename to s4/configs/experiment/lra/old/v3-s4-pathfinder.yaml diff --git a/configs/experiment/lra/old/v3-s4-pathx-small.yaml b/s4/configs/experiment/lra/old/v3-s4-pathx-small.yaml similarity index 100% rename from configs/experiment/lra/old/v3-s4-pathx-small.yaml rename to s4/configs/experiment/lra/old/v3-s4-pathx-small.yaml diff --git a/configs/experiment/lra/old/v3-s4-pathx.yaml b/s4/configs/experiment/lra/old/v3-s4-pathx.yaml similarity index 100% rename from configs/experiment/lra/old/v3-s4-pathx.yaml rename to s4/configs/experiment/lra/old/v3-s4-pathx.yaml diff --git a/configs/experiment/lra/resnet-pathx.yaml b/s4/configs/experiment/lra/resnet-pathx.yaml similarity index 100% rename from configs/experiment/lra/resnet-pathx.yaml rename to s4/configs/experiment/lra/resnet-pathx.yaml diff --git a/configs/experiment/lra/s4-aan.yaml b/s4/configs/experiment/lra/s4-aan.yaml similarity index 100% rename from configs/experiment/lra/s4-aan.yaml rename to s4/configs/experiment/lra/s4-aan.yaml diff --git a/configs/experiment/lra/s4-cifar.yaml b/s4/configs/experiment/lra/s4-cifar.yaml similarity index 100% rename from configs/experiment/lra/s4-cifar.yaml rename to s4/configs/experiment/lra/s4-cifar.yaml diff --git a/configs/experiment/lra/s4-imdb.yaml b/s4/configs/experiment/lra/s4-imdb.yaml similarity index 100% rename from configs/experiment/lra/s4-imdb.yaml rename to s4/configs/experiment/lra/s4-imdb.yaml diff --git a/configs/experiment/lra/s4-listops.yaml b/s4/configs/experiment/lra/s4-listops.yaml similarity index 100% rename from configs/experiment/lra/s4-listops.yaml rename to s4/configs/experiment/lra/s4-listops.yaml diff --git a/configs/experiment/lra/s4-pathfinder.yaml b/s4/configs/experiment/lra/s4-pathfinder.yaml similarity index 100% rename from configs/experiment/lra/s4-pathfinder.yaml rename to s4/configs/experiment/lra/s4-pathfinder.yaml diff --git a/configs/experiment/lra/s4-pathx.yaml b/s4/configs/experiment/lra/s4-pathx.yaml similarity index 100% rename from configs/experiment/lra/s4-pathx.yaml rename to s4/configs/experiment/lra/s4-pathx.yaml diff --git a/configs/experiment/mega/lra-image/README.md b/s4/configs/experiment/mega/lra-image/README.md similarity index 100% rename from configs/experiment/mega/lra-image/README.md rename to s4/configs/experiment/mega/lra-image/README.md diff --git a/configs/experiment/mega/lra-image/large-ema-with-s4.yaml b/s4/configs/experiment/mega/lra-image/large-ema-with-s4.yaml similarity index 100% rename from configs/experiment/mega/lra-image/large-ema-with-s4.yaml rename to s4/configs/experiment/mega/lra-image/large-ema-with-s4.yaml diff --git a/configs/experiment/mega/lra-image/large-ema.yaml b/s4/configs/experiment/mega/lra-image/large-ema.yaml similarity index 100% rename from configs/experiment/mega/lra-image/large-ema.yaml rename to s4/configs/experiment/mega/lra-image/large-ema.yaml diff --git a/configs/experiment/mega/lra-image/large-mega-ema-with-s4.yaml b/s4/configs/experiment/mega/lra-image/large-mega-ema-with-s4.yaml similarity index 100% rename from configs/experiment/mega/lra-image/large-mega-ema-with-s4.yaml rename to s4/configs/experiment/mega/lra-image/large-mega-ema-with-s4.yaml diff --git a/configs/experiment/mega/lra-image/large-mega-ema.yaml b/s4/configs/experiment/mega/lra-image/large-mega-ema.yaml similarity index 100% rename from configs/experiment/mega/lra-image/large-mega-ema.yaml rename to s4/configs/experiment/mega/lra-image/large-mega-ema.yaml diff --git a/configs/experiment/mega/lra-image/large-mega-s4d-real.yaml b/s4/configs/experiment/mega/lra-image/large-mega-s4d-real.yaml similarity index 100% rename from configs/experiment/mega/lra-image/large-mega-s4d-real.yaml rename to s4/configs/experiment/mega/lra-image/large-mega-s4d-real.yaml diff --git a/configs/experiment/mega/lra-image/large-mega-s4d.yaml b/s4/configs/experiment/mega/lra-image/large-mega-s4d.yaml similarity index 100% rename from configs/experiment/mega/lra-image/large-mega-s4d.yaml rename to s4/configs/experiment/mega/lra-image/large-mega-s4d.yaml diff --git a/configs/experiment/mega/lra-image/large-s4d-real.yaml b/s4/configs/experiment/mega/lra-image/large-s4d-real.yaml similarity index 100% rename from configs/experiment/mega/lra-image/large-s4d-real.yaml rename to s4/configs/experiment/mega/lra-image/large-s4d-real.yaml diff --git a/configs/experiment/mega/lra-image/large-s4d.yaml b/s4/configs/experiment/mega/lra-image/large-s4d.yaml similarity index 100% rename from configs/experiment/mega/lra-image/large-s4d.yaml rename to s4/configs/experiment/mega/lra-image/large-s4d.yaml diff --git a/configs/experiment/mega/lra-image/mega_ablations_10000_warmup_all.pdf b/s4/configs/experiment/mega/lra-image/mega_ablations_10000_warmup_all.pdf similarity index 100% rename from configs/experiment/mega/lra-image/mega_ablations_10000_warmup_all.pdf rename to s4/configs/experiment/mega/lra-image/mega_ablations_10000_warmup_all.pdf diff --git a/configs/experiment/mega/lra-image/mega_ablations_1000_warmup_all.pdf b/s4/configs/experiment/mega/lra-image/mega_ablations_1000_warmup_all.pdf similarity index 100% rename from configs/experiment/mega/lra-image/mega_ablations_1000_warmup_all.pdf rename to s4/configs/experiment/mega/lra-image/mega_ablations_1000_warmup_all.pdf diff --git a/configs/experiment/mega/lra-image/mega_ablations_mega_repo.pdf b/s4/configs/experiment/mega/lra-image/mega_ablations_mega_repo.pdf similarity index 100% rename from configs/experiment/mega/lra-image/mega_ablations_mega_repo.pdf rename to s4/configs/experiment/mega/lra-image/mega_ablations_mega_repo.pdf diff --git a/configs/experiment/mega/lra-image/small-ema-with-s4.yaml b/s4/configs/experiment/mega/lra-image/small-ema-with-s4.yaml similarity index 100% rename from configs/experiment/mega/lra-image/small-ema-with-s4.yaml rename to s4/configs/experiment/mega/lra-image/small-ema-with-s4.yaml diff --git a/configs/experiment/mega/lra-image/small-ema.yaml b/s4/configs/experiment/mega/lra-image/small-ema.yaml similarity index 100% rename from configs/experiment/mega/lra-image/small-ema.yaml rename to s4/configs/experiment/mega/lra-image/small-ema.yaml diff --git a/configs/experiment/mega/lra-image/small-mega-ema-with-s4.yaml b/s4/configs/experiment/mega/lra-image/small-mega-ema-with-s4.yaml similarity index 100% rename from configs/experiment/mega/lra-image/small-mega-ema-with-s4.yaml rename to s4/configs/experiment/mega/lra-image/small-mega-ema-with-s4.yaml diff --git a/configs/experiment/mega/lra-image/small-mega-ema.yaml b/s4/configs/experiment/mega/lra-image/small-mega-ema.yaml similarity index 100% rename from configs/experiment/mega/lra-image/small-mega-ema.yaml rename to s4/configs/experiment/mega/lra-image/small-mega-ema.yaml diff --git a/configs/experiment/mega/lra-image/small-mega-s4d-real.yaml b/s4/configs/experiment/mega/lra-image/small-mega-s4d-real.yaml similarity index 100% rename from configs/experiment/mega/lra-image/small-mega-s4d-real.yaml rename to s4/configs/experiment/mega/lra-image/small-mega-s4d-real.yaml diff --git a/configs/experiment/mega/lra-image/small-mega-s4d.yaml b/s4/configs/experiment/mega/lra-image/small-mega-s4d.yaml similarity index 100% rename from configs/experiment/mega/lra-image/small-mega-s4d.yaml rename to s4/configs/experiment/mega/lra-image/small-mega-s4d.yaml diff --git a/configs/experiment/mega/lra-image/small-s4d-real.yaml b/s4/configs/experiment/mega/lra-image/small-s4d-real.yaml similarity index 100% rename from configs/experiment/mega/lra-image/small-s4d-real.yaml rename to s4/configs/experiment/mega/lra-image/small-s4d-real.yaml diff --git a/configs/experiment/mega/lra-image/small-s4d.yaml b/s4/configs/experiment/mega/lra-image/small-s4d.yaml similarity index 100% rename from configs/experiment/mega/lra-image/small-s4d.yaml rename to s4/configs/experiment/mega/lra-image/small-s4d.yaml diff --git a/configs/experiment/rnn.yaml b/s4/configs/experiment/rnn.yaml similarity index 100% rename from configs/experiment/rnn.yaml rename to s4/configs/experiment/rnn.yaml diff --git a/configs/experiment/s4nd/README.md b/s4/configs/experiment/s4nd/README.md similarity index 100% rename from configs/experiment/s4nd/README.md rename to s4/configs/experiment/s4nd/README.md diff --git a/configs/experiment/s4nd/celeba/convnext-celeba-all.yaml b/s4/configs/experiment/s4nd/celeba/convnext-celeba-all.yaml similarity index 100% rename from configs/experiment/s4nd/celeba/convnext-celeba-all.yaml rename to s4/configs/experiment/s4nd/celeba/convnext-celeba-all.yaml diff --git a/configs/experiment/s4nd/celeba/convnext-s4nd-celeba-all.yaml b/s4/configs/experiment/s4nd/celeba/convnext-s4nd-celeba-all.yaml similarity index 100% rename from configs/experiment/s4nd/celeba/convnext-s4nd-celeba-all.yaml rename to s4/configs/experiment/s4nd/celeba/convnext-s4nd-celeba-all.yaml diff --git a/configs/experiment/s4nd/cifar/cnn-cifar-2d.yaml b/s4/configs/experiment/s4nd/cifar/cnn-cifar-2d.yaml similarity index 100% rename from configs/experiment/s4nd/cifar/cnn-cifar-2d.yaml rename to s4/configs/experiment/s4nd/cifar/cnn-cifar-2d.yaml diff --git a/configs/experiment/s4nd/cifar/s4-cifar-2d-16x16.yaml b/s4/configs/experiment/s4nd/cifar/s4-cifar-2d-16x16.yaml similarity index 100% rename from configs/experiment/s4nd/cifar/s4-cifar-2d-16x16.yaml rename to s4/configs/experiment/s4nd/cifar/s4-cifar-2d-16x16.yaml diff --git a/configs/experiment/s4nd/cifar/s4-cifar-2d.yaml b/s4/configs/experiment/s4nd/cifar/s4-cifar-2d.yaml similarity index 100% rename from configs/experiment/s4nd/cifar/s4-cifar-2d.yaml rename to s4/configs/experiment/s4nd/cifar/s4-cifar-2d.yaml diff --git a/configs/experiment/s4nd/convnext/convnext_timm_tiny_imagenet.yaml b/s4/configs/experiment/s4nd/convnext/convnext_timm_tiny_imagenet.yaml similarity index 100% rename from configs/experiment/s4nd/convnext/convnext_timm_tiny_imagenet.yaml rename to s4/configs/experiment/s4nd/convnext/convnext_timm_tiny_imagenet.yaml diff --git a/configs/experiment/s4nd/convnext/convnext_timm_tiny_inflate3d_hmdb.yaml b/s4/configs/experiment/s4nd/convnext/convnext_timm_tiny_inflate3d_hmdb.yaml similarity index 100% rename from configs/experiment/s4nd/convnext/convnext_timm_tiny_inflate3d_hmdb.yaml rename to s4/configs/experiment/s4nd/convnext/convnext_timm_tiny_inflate3d_hmdb.yaml diff --git a/configs/experiment/s4nd/convnext/convnext_timm_tiny_inflate3d_s4nd_hmdb.yaml b/s4/configs/experiment/s4nd/convnext/convnext_timm_tiny_inflate3d_s4nd_hmdb.yaml similarity index 100% rename from configs/experiment/s4nd/convnext/convnext_timm_tiny_inflate3d_s4nd_hmdb.yaml rename to s4/configs/experiment/s4nd/convnext/convnext_timm_tiny_inflate3d_s4nd_hmdb.yaml diff --git a/configs/experiment/s4nd/convnext/convnext_timm_tiny_s4nd_imagenet.yaml b/s4/configs/experiment/s4nd/convnext/convnext_timm_tiny_s4nd_imagenet.yaml similarity index 100% rename from configs/experiment/s4nd/convnext/convnext_timm_tiny_s4nd_imagenet.yaml rename to s4/configs/experiment/s4nd/convnext/convnext_timm_tiny_s4nd_imagenet.yaml diff --git a/configs/experiment/s4nd/progres/cnn-cifar-2d.yaml b/s4/configs/experiment/s4nd/progres/cnn-cifar-2d.yaml similarity index 100% rename from configs/experiment/s4nd/progres/cnn-cifar-2d.yaml rename to s4/configs/experiment/s4nd/progres/cnn-cifar-2d.yaml diff --git a/configs/experiment/s4nd/progres/s4-cifar-2d.yaml b/s4/configs/experiment/s4nd/progres/s4-cifar-2d.yaml similarity index 100% rename from configs/experiment/s4nd/progres/s4-cifar-2d.yaml rename to s4/configs/experiment/s4nd/progres/s4-cifar-2d.yaml diff --git a/configs/experiment/s4nd/vit/vit_b_16_imagenet.yaml b/s4/configs/experiment/s4nd/vit/vit_b_16_imagenet.yaml similarity index 100% rename from configs/experiment/s4nd/vit/vit_b_16_imagenet.yaml rename to s4/configs/experiment/s4nd/vit/vit_b_16_imagenet.yaml diff --git a/configs/experiment/s4nd/vit/vit_b_16_s4_imagenet_v2.yaml b/s4/configs/experiment/s4nd/vit/vit_b_16_s4_imagenet_v2.yaml similarity index 100% rename from configs/experiment/s4nd/vit/vit_b_16_s4_imagenet_v2.yaml rename to s4/configs/experiment/s4nd/vit/vit_b_16_s4_imagenet_v2.yaml diff --git a/configs/experiment/sc/convnet-sc.yaml b/s4/configs/experiment/sc/convnet-sc.yaml similarity index 100% rename from configs/experiment/sc/convnet-sc.yaml rename to s4/configs/experiment/sc/convnet-sc.yaml diff --git a/configs/experiment/sc/resnet-sc.yaml b/s4/configs/experiment/sc/resnet-sc.yaml similarity index 100% rename from configs/experiment/sc/resnet-sc.yaml rename to s4/configs/experiment/sc/resnet-sc.yaml diff --git a/configs/experiment/sc/s4-sc-ablation.yaml b/s4/configs/experiment/sc/s4-sc-ablation.yaml similarity index 100% rename from configs/experiment/sc/s4-sc-ablation.yaml rename to s4/configs/experiment/sc/s4-sc-ablation.yaml diff --git a/configs/experiment/sc/s4-sc.yaml b/s4/configs/experiment/sc/s4-sc.yaml similarity index 100% rename from configs/experiment/sc/s4-sc.yaml rename to s4/configs/experiment/sc/s4-sc.yaml diff --git a/configs/experiment/sc/transformer-sc.yaml b/s4/configs/experiment/sc/transformer-sc.yaml similarity index 100% rename from configs/experiment/sc/transformer-sc.yaml rename to s4/configs/experiment/sc/transformer-sc.yaml diff --git a/configs/experiment/synthetic/s4-copying.yaml b/s4/configs/experiment/synthetic/s4-copying.yaml similarity index 100% rename from configs/experiment/synthetic/s4-copying.yaml rename to s4/configs/experiment/synthetic/s4-copying.yaml diff --git a/configs/experiment/synthetic/s4-delay.yaml b/s4/configs/experiment/synthetic/s4-delay.yaml similarity index 100% rename from configs/experiment/synthetic/s4-delay.yaml rename to s4/configs/experiment/synthetic/s4-delay.yaml diff --git a/configs/experiment/synthetic/s4-reconstruct.yaml b/s4/configs/experiment/synthetic/s4-reconstruct.yaml similarity index 100% rename from configs/experiment/synthetic/s4-reconstruct.yaml rename to s4/configs/experiment/synthetic/s4-reconstruct.yaml diff --git a/configs/generate.yaml b/s4/configs/generate.yaml similarity index 100% rename from configs/generate.yaml rename to s4/configs/generate.yaml diff --git a/configs/loader/default.yaml b/s4/configs/loader/default.yaml similarity index 100% rename from configs/loader/default.yaml rename to s4/configs/loader/default.yaml diff --git a/configs/loader/imresolution.yaml b/s4/configs/loader/imresolution.yaml similarity index 100% rename from configs/loader/imresolution.yaml rename to s4/configs/loader/imresolution.yaml diff --git a/configs/loader/lm.yaml b/s4/configs/loader/lm.yaml similarity index 100% rename from configs/loader/lm.yaml rename to s4/configs/loader/lm.yaml diff --git a/configs/loader/resolution.yaml b/s4/configs/loader/resolution.yaml similarity index 100% rename from configs/loader/resolution.yaml rename to s4/configs/loader/resolution.yaml diff --git a/configs/loader/tbptt.yaml b/s4/configs/loader/tbptt.yaml similarity index 100% rename from configs/loader/tbptt.yaml rename to s4/configs/loader/tbptt.yaml diff --git a/configs/model/README.md b/s4/configs/model/README.md similarity index 100% rename from configs/model/README.md rename to s4/configs/model/README.md diff --git a/configs/model/base.yaml b/s4/configs/model/base.yaml similarity index 100% rename from configs/model/base.yaml rename to s4/configs/model/base.yaml diff --git a/configs/model/baseline/ckconv.yaml b/s4/configs/model/baseline/ckconv.yaml similarity index 100% rename from configs/model/baseline/ckconv.yaml rename to s4/configs/model/baseline/ckconv.yaml diff --git a/configs/model/baseline/lipschitzrnn.yaml b/s4/configs/model/baseline/lipschitzrnn.yaml similarity index 100% rename from configs/model/baseline/lipschitzrnn.yaml rename to s4/configs/model/baseline/lipschitzrnn.yaml diff --git a/configs/model/baseline/lstm.yaml b/s4/configs/model/baseline/lstm.yaml similarity index 100% rename from configs/model/baseline/lstm.yaml rename to s4/configs/model/baseline/lstm.yaml diff --git a/configs/model/baseline/odelstm.yaml b/s4/configs/model/baseline/odelstm.yaml similarity index 100% rename from configs/model/baseline/odelstm.yaml rename to s4/configs/model/baseline/odelstm.yaml diff --git a/configs/model/baseline/resnet2d.yaml b/s4/configs/model/baseline/resnet2d.yaml similarity index 100% rename from configs/model/baseline/resnet2d.yaml rename to s4/configs/model/baseline/resnet2d.yaml diff --git a/configs/model/baseline/samplernn.yaml b/s4/configs/model/baseline/samplernn.yaml similarity index 100% rename from configs/model/baseline/samplernn.yaml rename to s4/configs/model/baseline/samplernn.yaml diff --git a/configs/model/baseline/stackedrnn_baseline.yaml b/s4/configs/model/baseline/stackedrnn_baseline.yaml similarity index 100% rename from configs/model/baseline/stackedrnn_baseline.yaml rename to s4/configs/model/baseline/stackedrnn_baseline.yaml diff --git a/configs/model/baseline/unicornn.yaml b/s4/configs/model/baseline/unicornn.yaml similarity index 100% rename from configs/model/baseline/unicornn.yaml rename to s4/configs/model/baseline/unicornn.yaml diff --git a/configs/model/baseline/wavenet.yaml b/s4/configs/model/baseline/wavenet.yaml similarity index 100% rename from configs/model/baseline/wavenet.yaml rename to s4/configs/model/baseline/wavenet.yaml diff --git a/configs/model/convnet1d.yaml b/s4/configs/model/convnet1d.yaml similarity index 100% rename from configs/model/convnet1d.yaml rename to s4/configs/model/convnet1d.yaml diff --git a/configs/model/convnet2d.yaml b/s4/configs/model/convnet2d.yaml similarity index 100% rename from configs/model/convnet2d.yaml rename to s4/configs/model/convnet2d.yaml diff --git a/configs/model/layer/cell/exprnn.yaml b/s4/configs/model/layer/cell/exprnn.yaml similarity index 100% rename from configs/model/layer/cell/exprnn.yaml rename to s4/configs/model/layer/cell/exprnn.yaml diff --git a/configs/model/layer/cell/goru.yaml b/s4/configs/model/layer/cell/goru.yaml similarity index 100% rename from configs/model/layer/cell/goru.yaml rename to s4/configs/model/layer/cell/goru.yaml diff --git a/configs/model/layer/cell/gru.yaml b/s4/configs/model/layer/cell/gru.yaml similarity index 100% rename from configs/model/layer/cell/gru.yaml rename to s4/configs/model/layer/cell/gru.yaml diff --git a/configs/model/layer/cell/hippo-glagt.yaml b/s4/configs/model/layer/cell/hippo-glagt.yaml similarity index 100% rename from configs/model/layer/cell/hippo-glagt.yaml rename to s4/configs/model/layer/cell/hippo-glagt.yaml diff --git a/configs/model/layer/cell/hippo-lagt.yaml b/s4/configs/model/layer/cell/hippo-lagt.yaml similarity index 100% rename from configs/model/layer/cell/hippo-lagt.yaml rename to s4/configs/model/layer/cell/hippo-lagt.yaml diff --git a/configs/model/layer/cell/hippo-legs.yaml b/s4/configs/model/layer/cell/hippo-legs.yaml similarity index 100% rename from configs/model/layer/cell/hippo-legs.yaml rename to s4/configs/model/layer/cell/hippo-legs.yaml diff --git a/configs/model/layer/cell/hippo-legt.yaml b/s4/configs/model/layer/cell/hippo-legt.yaml similarity index 100% rename from configs/model/layer/cell/hippo-legt.yaml rename to s4/configs/model/layer/cell/hippo-legt.yaml diff --git a/configs/model/layer/cell/hippo-timestamp.yaml b/s4/configs/model/layer/cell/hippo-timestamp.yaml similarity index 100% rename from configs/model/layer/cell/hippo-timestamp.yaml rename to s4/configs/model/layer/cell/hippo-timestamp.yaml diff --git a/configs/model/layer/cell/lmu.yaml b/s4/configs/model/layer/cell/lmu.yaml similarity index 100% rename from configs/model/layer/cell/lmu.yaml rename to s4/configs/model/layer/cell/lmu.yaml diff --git a/configs/model/layer/cell/rnn.yaml b/s4/configs/model/layer/cell/rnn.yaml similarity index 100% rename from configs/model/layer/cell/rnn.yaml rename to s4/configs/model/layer/cell/rnn.yaml diff --git a/configs/model/layer/cell/sru.yaml b/s4/configs/model/layer/cell/sru.yaml similarity index 100% rename from configs/model/layer/cell/sru.yaml rename to s4/configs/model/layer/cell/sru.yaml diff --git a/configs/model/layer/conv1d.yaml b/s4/configs/model/layer/conv1d.yaml similarity index 100% rename from configs/model/layer/conv1d.yaml rename to s4/configs/model/layer/conv1d.yaml diff --git a/configs/model/layer/conv2d.yaml b/s4/configs/model/layer/conv2d.yaml similarity index 100% rename from configs/model/layer/conv2d.yaml rename to s4/configs/model/layer/conv2d.yaml diff --git a/configs/model/layer/ff.yaml b/s4/configs/model/layer/ff.yaml similarity index 100% rename from configs/model/layer/ff.yaml rename to s4/configs/model/layer/ff.yaml diff --git a/configs/model/layer/id.yaml b/s4/configs/model/layer/id.yaml similarity index 100% rename from configs/model/layer/id.yaml rename to s4/configs/model/layer/id.yaml diff --git a/configs/model/layer/lssl.yaml b/s4/configs/model/layer/lssl.yaml similarity index 100% rename from configs/model/layer/lssl.yaml rename to s4/configs/model/layer/lssl.yaml diff --git a/configs/model/layer/lstm.yaml b/s4/configs/model/layer/lstm.yaml similarity index 100% rename from configs/model/layer/lstm.yaml rename to s4/configs/model/layer/lstm.yaml diff --git a/configs/model/layer/mega.yaml b/s4/configs/model/layer/mega.yaml similarity index 100% rename from configs/model/layer/mega.yaml rename to s4/configs/model/layer/mega.yaml diff --git a/configs/model/layer/mha.yaml b/s4/configs/model/layer/mha.yaml similarity index 100% rename from configs/model/layer/mha.yaml rename to s4/configs/model/layer/mha.yaml diff --git a/configs/model/layer/performer.yaml b/s4/configs/model/layer/performer.yaml similarity index 100% rename from configs/model/layer/performer.yaml rename to s4/configs/model/layer/performer.yaml diff --git a/configs/model/layer/rnn.yaml b/s4/configs/model/layer/rnn.yaml similarity index 100% rename from configs/model/layer/rnn.yaml rename to s4/configs/model/layer/rnn.yaml diff --git a/configs/model/layer/s4.yaml b/s4/configs/model/layer/s4.yaml similarity index 100% rename from configs/model/layer/s4.yaml rename to s4/configs/model/layer/s4.yaml diff --git a/configs/model/layer/s4d.yaml b/s4/configs/model/layer/s4d.yaml similarity index 100% rename from configs/model/layer/s4d.yaml rename to s4/configs/model/layer/s4d.yaml diff --git a/configs/model/layer/s4d_example.yaml b/s4/configs/model/layer/s4d_example.yaml similarity index 100% rename from configs/model/layer/s4d_example.yaml rename to s4/configs/model/layer/s4d_example.yaml diff --git a/configs/model/layer/s4ff.yaml b/s4/configs/model/layer/s4ff.yaml similarity index 100% rename from configs/model/layer/s4ff.yaml rename to s4/configs/model/layer/s4ff.yaml diff --git a/configs/model/layer/s4nd.yaml b/s4/configs/model/layer/s4nd.yaml similarity index 100% rename from configs/model/layer/s4nd.yaml rename to s4/configs/model/layer/s4nd.yaml diff --git a/configs/model/layer/s4s4ff.yaml b/s4/configs/model/layer/s4s4ff.yaml similarity index 100% rename from configs/model/layer/s4s4ff.yaml rename to s4/configs/model/layer/s4s4ff.yaml diff --git a/configs/model/layer/sru.yaml b/s4/configs/model/layer/sru.yaml similarity index 100% rename from configs/model/layer/sru.yaml rename to s4/configs/model/layer/sru.yaml diff --git a/configs/model/layer/standalone.yaml b/s4/configs/model/layer/standalone.yaml similarity index 100% rename from configs/model/layer/standalone.yaml rename to s4/configs/model/layer/standalone.yaml diff --git a/configs/model/layer/transformer.yaml b/s4/configs/model/layer/transformer.yaml similarity index 100% rename from configs/model/layer/transformer.yaml rename to s4/configs/model/layer/transformer.yaml diff --git a/configs/model/layer/vit.yaml b/s4/configs/model/layer/vit.yaml similarity index 100% rename from configs/model/layer/vit.yaml rename to s4/configs/model/layer/vit.yaml diff --git a/configs/model/mega.yaml b/s4/configs/model/mega.yaml similarity index 100% rename from configs/model/mega.yaml rename to s4/configs/model/mega.yaml diff --git a/configs/model/nonaka/inception.yaml b/s4/configs/model/nonaka/inception.yaml similarity index 100% rename from configs/model/nonaka/inception.yaml rename to s4/configs/model/nonaka/inception.yaml diff --git a/configs/model/nonaka/resnet.yaml b/s4/configs/model/nonaka/resnet.yaml similarity index 100% rename from configs/model/nonaka/resnet.yaml rename to s4/configs/model/nonaka/resnet.yaml diff --git a/configs/model/nonaka/xresnet.yaml b/s4/configs/model/nonaka/xresnet.yaml similarity index 100% rename from configs/model/nonaka/xresnet.yaml rename to s4/configs/model/nonaka/xresnet.yaml diff --git a/configs/model/s4.yaml b/s4/configs/model/s4.yaml similarity index 100% rename from configs/model/s4.yaml rename to s4/configs/model/s4.yaml diff --git a/configs/model/sashimi-standalone.yaml b/s4/configs/model/sashimi-standalone.yaml similarity index 100% rename from configs/model/sashimi-standalone.yaml rename to s4/configs/model/sashimi-standalone.yaml diff --git a/configs/model/sashimi-transformer.yaml b/s4/configs/model/sashimi-transformer.yaml similarity index 100% rename from configs/model/sashimi-transformer.yaml rename to s4/configs/model/sashimi-transformer.yaml diff --git a/configs/model/sashimi.yaml b/s4/configs/model/sashimi.yaml similarity index 100% rename from configs/model/sashimi.yaml rename to s4/configs/model/sashimi.yaml diff --git a/configs/model/transformer.yaml b/s4/configs/model/transformer.yaml similarity index 100% rename from configs/model/transformer.yaml rename to s4/configs/model/transformer.yaml diff --git a/configs/model/unet.yaml b/s4/configs/model/unet.yaml similarity index 100% rename from configs/model/unet.yaml rename to s4/configs/model/unet.yaml diff --git a/configs/model/vit/vit.yaml b/s4/configs/model/vit/vit.yaml similarity index 100% rename from configs/model/vit/vit.yaml rename to s4/configs/model/vit/vit.yaml diff --git a/configs/model/vit/vit_b_16.yaml b/s4/configs/model/vit/vit_b_16.yaml similarity index 100% rename from configs/model/vit/vit_b_16.yaml rename to s4/configs/model/vit/vit_b_16.yaml diff --git a/configs/model/vit/vit_s_16.yaml b/s4/configs/model/vit/vit_s_16.yaml similarity index 100% rename from configs/model/vit/vit_s_16.yaml rename to s4/configs/model/vit/vit_s_16.yaml diff --git a/configs/optimizer/adam.yaml b/s4/configs/optimizer/adam.yaml similarity index 100% rename from configs/optimizer/adam.yaml rename to s4/configs/optimizer/adam.yaml diff --git a/configs/optimizer/adamw.yaml b/s4/configs/optimizer/adamw.yaml similarity index 100% rename from configs/optimizer/adamw.yaml rename to s4/configs/optimizer/adamw.yaml diff --git a/configs/optimizer/lamb.yaml b/s4/configs/optimizer/lamb.yaml similarity index 100% rename from configs/optimizer/lamb.yaml rename to s4/configs/optimizer/lamb.yaml diff --git a/configs/optimizer/sgd.yaml b/s4/configs/optimizer/sgd.yaml similarity index 100% rename from configs/optimizer/sgd.yaml rename to s4/configs/optimizer/sgd.yaml diff --git a/configs/pipeline/aan.yaml b/s4/configs/pipeline/aan.yaml similarity index 100% rename from configs/pipeline/aan.yaml rename to s4/configs/pipeline/aan.yaml diff --git a/configs/pipeline/adding.yaml b/s4/configs/pipeline/adding.yaml similarity index 100% rename from configs/pipeline/adding.yaml rename to s4/configs/pipeline/adding.yaml diff --git a/configs/pipeline/celeba-all-2d.yaml b/s4/configs/pipeline/celeba-all-2d.yaml similarity index 100% rename from configs/pipeline/celeba-all-2d.yaml rename to s4/configs/pipeline/celeba-all-2d.yaml diff --git a/configs/pipeline/cifar-2d.yaml b/s4/configs/pipeline/cifar-2d.yaml similarity index 100% rename from configs/pipeline/cifar-2d.yaml rename to s4/configs/pipeline/cifar-2d.yaml diff --git a/configs/pipeline/cifar.yaml b/s4/configs/pipeline/cifar.yaml similarity index 100% rename from configs/pipeline/cifar.yaml rename to s4/configs/pipeline/cifar.yaml diff --git a/configs/pipeline/copying.yaml b/s4/configs/pipeline/copying.yaml similarity index 100% rename from configs/pipeline/copying.yaml rename to s4/configs/pipeline/copying.yaml diff --git a/configs/pipeline/delay.yaml b/s4/configs/pipeline/delay.yaml similarity index 100% rename from configs/pipeline/delay.yaml rename to s4/configs/pipeline/delay.yaml diff --git a/configs/pipeline/ema.yaml b/s4/configs/pipeline/ema.yaml similarity index 100% rename from configs/pipeline/ema.yaml rename to s4/configs/pipeline/ema.yaml diff --git a/configs/pipeline/hmdb51_convnext.yaml b/s4/configs/pipeline/hmdb51_convnext.yaml similarity index 100% rename from configs/pipeline/hmdb51_convnext.yaml rename to s4/configs/pipeline/hmdb51_convnext.yaml diff --git a/configs/pipeline/imagenet.yaml b/s4/configs/pipeline/imagenet.yaml similarity index 100% rename from configs/pipeline/imagenet.yaml rename to s4/configs/pipeline/imagenet.yaml diff --git a/configs/pipeline/imdb.yaml b/s4/configs/pipeline/imdb.yaml similarity index 100% rename from configs/pipeline/imdb.yaml rename to s4/configs/pipeline/imdb.yaml diff --git a/configs/pipeline/informer.yaml b/s4/configs/pipeline/informer.yaml similarity index 100% rename from configs/pipeline/informer.yaml rename to s4/configs/pipeline/informer.yaml diff --git a/configs/pipeline/listops.yaml b/s4/configs/pipeline/listops.yaml similarity index 100% rename from configs/pipeline/listops.yaml rename to s4/configs/pipeline/listops.yaml diff --git a/configs/pipeline/mnist.yaml b/s4/configs/pipeline/mnist.yaml similarity index 100% rename from configs/pipeline/mnist.yaml rename to s4/configs/pipeline/mnist.yaml diff --git a/configs/pipeline/pathfinder.yaml b/s4/configs/pipeline/pathfinder.yaml similarity index 100% rename from configs/pipeline/pathfinder.yaml rename to s4/configs/pipeline/pathfinder.yaml diff --git a/configs/pipeline/pathx.yaml b/s4/configs/pipeline/pathx.yaml similarity index 100% rename from configs/pipeline/pathx.yaml rename to s4/configs/pipeline/pathx.yaml diff --git a/configs/pipeline/reconstruct.yaml b/s4/configs/pipeline/reconstruct.yaml similarity index 100% rename from configs/pipeline/reconstruct.yaml rename to s4/configs/pipeline/reconstruct.yaml diff --git a/configs/pipeline/sc.yaml b/s4/configs/pipeline/sc.yaml similarity index 100% rename from configs/pipeline/sc.yaml rename to s4/configs/pipeline/sc.yaml diff --git a/configs/pipeline/wt103.yaml b/s4/configs/pipeline/wt103.yaml similarity index 100% rename from configs/pipeline/wt103.yaml rename to s4/configs/pipeline/wt103.yaml diff --git a/configs/scheduler/constant.yaml b/s4/configs/scheduler/constant.yaml similarity index 100% rename from configs/scheduler/constant.yaml rename to s4/configs/scheduler/constant.yaml diff --git a/configs/scheduler/constant_warmup.yaml b/s4/configs/scheduler/constant_warmup.yaml similarity index 100% rename from configs/scheduler/constant_warmup.yaml rename to s4/configs/scheduler/constant_warmup.yaml diff --git a/configs/scheduler/cosine.yaml b/s4/configs/scheduler/cosine.yaml similarity index 100% rename from configs/scheduler/cosine.yaml rename to s4/configs/scheduler/cosine.yaml diff --git a/configs/scheduler/cosine_warmup.yaml b/s4/configs/scheduler/cosine_warmup.yaml similarity index 100% rename from configs/scheduler/cosine_warmup.yaml rename to s4/configs/scheduler/cosine_warmup.yaml diff --git a/configs/scheduler/linear_warmup.yaml b/s4/configs/scheduler/linear_warmup.yaml similarity index 100% rename from configs/scheduler/linear_warmup.yaml rename to s4/configs/scheduler/linear_warmup.yaml diff --git a/configs/scheduler/multistep.yaml b/s4/configs/scheduler/multistep.yaml similarity index 100% rename from configs/scheduler/multistep.yaml rename to s4/configs/scheduler/multistep.yaml diff --git a/configs/scheduler/plateau.yaml b/s4/configs/scheduler/plateau.yaml similarity index 100% rename from configs/scheduler/plateau.yaml rename to s4/configs/scheduler/plateau.yaml diff --git a/configs/scheduler/step.yaml b/s4/configs/scheduler/step.yaml similarity index 100% rename from configs/scheduler/step.yaml rename to s4/configs/scheduler/step.yaml diff --git a/configs/scheduler/timm_cosine.yaml b/s4/configs/scheduler/timm_cosine.yaml similarity index 100% rename from configs/scheduler/timm_cosine.yaml rename to s4/configs/scheduler/timm_cosine.yaml diff --git a/configs/task/forecasting.yaml b/s4/configs/task/forecasting.yaml similarity index 100% rename from configs/task/forecasting.yaml rename to s4/configs/task/forecasting.yaml diff --git a/configs/task/lm.yaml b/s4/configs/task/lm.yaml similarity index 100% rename from configs/task/lm.yaml rename to s4/configs/task/lm.yaml diff --git a/configs/task/multiclass_classification.yaml b/s4/configs/task/multiclass_classification.yaml similarity index 100% rename from configs/task/multiclass_classification.yaml rename to s4/configs/task/multiclass_classification.yaml diff --git a/configs/task/multilabel_classification.yaml b/s4/configs/task/multilabel_classification.yaml similarity index 100% rename from configs/task/multilabel_classification.yaml rename to s4/configs/task/multilabel_classification.yaml diff --git a/configs/task/regression.yaml b/s4/configs/task/regression.yaml similarity index 100% rename from configs/task/regression.yaml rename to s4/configs/task/regression.yaml diff --git a/configs/task/video.yaml b/s4/configs/task/video.yaml similarity index 100% rename from configs/task/video.yaml rename to s4/configs/task/video.yaml diff --git a/configs/trainer/debug.yaml b/s4/configs/trainer/debug.yaml similarity index 100% rename from configs/trainer/debug.yaml rename to s4/configs/trainer/debug.yaml diff --git a/configs/trainer/default.yaml b/s4/configs/trainer/default.yaml similarity index 100% rename from configs/trainer/default.yaml rename to s4/configs/trainer/default.yaml diff --git a/configs/trainer/lm.yaml b/s4/configs/trainer/lm.yaml similarity index 100% rename from configs/trainer/lm.yaml rename to s4/configs/trainer/lm.yaml diff --git a/s4/extensions/__init__.py b/s4/extensions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/extensions/kernels/README.md b/s4/extensions/kernels/README.md similarity index 100% rename from extensions/kernels/README.md rename to s4/extensions/kernels/README.md diff --git a/s4/extensions/kernels/__init__.py b/s4/extensions/kernels/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/extensions/kernels/benchmark_cauchy.py b/s4/extensions/kernels/benchmark_cauchy.py similarity index 100% rename from extensions/kernels/benchmark_cauchy.py rename to s4/extensions/kernels/benchmark_cauchy.py diff --git a/extensions/kernels/benchmark_cauchy_tune.py b/s4/extensions/kernels/benchmark_cauchy_tune.py similarity index 100% rename from extensions/kernels/benchmark_cauchy_tune.py rename to s4/extensions/kernels/benchmark_cauchy_tune.py diff --git a/extensions/kernels/cauchy.cpp b/s4/extensions/kernels/cauchy.cpp similarity index 100% rename from extensions/kernels/cauchy.cpp rename to s4/extensions/kernels/cauchy.cpp diff --git a/extensions/kernels/cauchy.py b/s4/extensions/kernels/cauchy.py similarity index 100% rename from extensions/kernels/cauchy.py rename to s4/extensions/kernels/cauchy.py diff --git a/extensions/kernels/cauchy_cuda.cu b/s4/extensions/kernels/cauchy_cuda.cu similarity index 100% rename from extensions/kernels/cauchy_cuda.cu rename to s4/extensions/kernels/cauchy_cuda.cu diff --git a/extensions/kernels/map.h b/s4/extensions/kernels/map.h similarity index 100% rename from extensions/kernels/map.h rename to s4/extensions/kernels/map.h diff --git a/extensions/kernels/setup.py b/s4/extensions/kernels/setup.py similarity index 100% rename from extensions/kernels/setup.py rename to s4/extensions/kernels/setup.py diff --git a/extensions/kernels/test_cauchy.py b/s4/extensions/kernels/test_cauchy.py similarity index 100% rename from extensions/kernels/test_cauchy.py rename to s4/extensions/kernels/test_cauchy.py diff --git a/extensions/kernels/test_vandermonde.py b/s4/extensions/kernels/test_vandermonde.py similarity index 96% rename from extensions/kernels/test_vandermonde.py rename to s4/extensions/kernels/test_vandermonde.py index 4ffff619..5bdc68f9 100644 --- a/extensions/kernels/test_vandermonde.py +++ b/s4/extensions/kernels/test_vandermonde.py @@ -6,7 +6,7 @@ from einops import rearrange -from src.ops.vandermonde import log_vandermonde, log_vandermonde_fast +from s4.src.ops.vandermonde import log_vandermonde, log_vandermonde_fast @pytest.mark.parametrize('L', [3, 17, 489, 2**10, 1047, 2**11, 2**12]) diff --git a/extensions/kernels/tune_cauchy.py b/s4/extensions/kernels/tune_cauchy.py similarity index 100% rename from extensions/kernels/tune_cauchy.py rename to s4/extensions/kernels/tune_cauchy.py diff --git a/extensions/kernels/tune_cauchy.sh b/s4/extensions/kernels/tune_cauchy.sh similarity index 100% rename from extensions/kernels/tune_cauchy.sh rename to s4/extensions/kernels/tune_cauchy.sh diff --git a/extensions/kernels/tuner.py b/s4/extensions/kernels/tuner.py similarity index 100% rename from extensions/kernels/tuner.py rename to s4/extensions/kernels/tuner.py diff --git a/extensions/kernels/tuning_setup.py b/s4/extensions/kernels/tuning_setup.py similarity index 100% rename from extensions/kernels/tuning_setup.py rename to s4/extensions/kernels/tuning_setup.py diff --git a/extensions/kernels/vandermonde.py b/s4/extensions/kernels/vandermonde.py similarity index 100% rename from extensions/kernels/vandermonde.py rename to s4/extensions/kernels/vandermonde.py diff --git a/generate.py b/s4/generate.py similarity index 97% rename from generate.py rename to s4/generate.py index c47f2d23..9bc659ed 100644 --- a/generate.py +++ b/s4/generate.py @@ -12,10 +12,10 @@ from torch.distributions import Categorical from tqdm.auto import tqdm -from src import utils -from src.dataloaders.audio import mu_law_decode -from src.models.baselines.wavenet import WaveNetModel -from train import SequenceLightningModule +from s4.src import utils +from s4.src.dataloaders.audio import mu_law_decode +from s4.src.models.baselines.wavenet import WaveNetModel +from s4.train import SequenceLightningModule def test_step(model): B, L = 2, 64 diff --git a/models/README.md b/s4/models/README.md similarity index 100% rename from models/README.md rename to s4/models/README.md diff --git a/s4/models/__init__.py b/s4/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/dss/README.md b/s4/models/dss/README.md similarity index 100% rename from models/dss/README.md rename to s4/models/dss/README.md diff --git a/s4/models/dss/__init__.py b/s4/models/dss/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/hippo/README.md b/s4/models/hippo/README.md similarity index 100% rename from models/hippo/README.md rename to s4/models/hippo/README.md diff --git a/s4/models/hippo/__init__.py b/s4/models/hippo/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/related/README.md b/s4/models/related/README.md similarity index 100% rename from models/related/README.md rename to s4/models/related/README.md diff --git a/s4/models/related/__init__.py b/s4/models/related/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/s4/README.md b/s4/models/s4/README.md similarity index 100% rename from models/s4/README.md rename to s4/models/s4/README.md diff --git a/s4/models/s4/__init__.py b/s4/models/s4/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/s4/experiments.md b/s4/models/s4/experiments.md similarity index 100% rename from models/s4/experiments.md rename to s4/models/s4/experiments.md diff --git a/s4/models/s4/lssl.md b/s4/models/s4/lssl.md new file mode 100644 index 00000000..e69de29b diff --git a/models/s4/s4.py b/s4/models/s4/s4.py similarity index 99% rename from models/s4/s4.py rename to s4/models/s4/s4.py index deb53419..60783c41 100644 --- a/models/s4/s4.py +++ b/s4/models/s4/s4.py @@ -43,8 +43,8 @@ def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: # Try CUDA extension try: - from extensions.kernels.cauchy import cauchy_mult as cauchy_cuda - from extensions.kernels.vandermonde import log_vandermonde_cuda + from s4.extensions.kernels.cauchy import cauchy_mult as cauchy_cuda + from s4.extensions.kernels.vandermonde import log_vandermonde_cuda has_cuda_extension = True log.info("CUDA extension for structured kernels (Cauchy and Vandermonde multiplication) found.") except: diff --git a/models/s4/s4d.py b/s4/models/s4/s4d.py similarity index 98% rename from models/s4/s4d.py rename to s4/models/s4/s4d.py index 49ef2a9b..b77338e2 100644 --- a/models/s4/s4d.py +++ b/s4/models/s4/s4d.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from einops import rearrange, repeat -from src.models.nn import DropoutNd +from s4.src.models.nn import DropoutNd class S4DKernel(nn.Module): """Generate convolution kernel from diagonal SSM parameters.""" diff --git a/models/s4nd/README.md b/s4/models/s4nd/README.md similarity index 100% rename from models/s4nd/README.md rename to s4/models/s4nd/README.md diff --git a/s4/models/s4nd/__init__.py b/s4/models/s4nd/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/sashimi/README.md b/s4/models/sashimi/README.md similarity index 100% rename from models/sashimi/README.md rename to s4/models/sashimi/README.md diff --git a/s4/models/sashimi/__init__.py b/s4/models/sashimi/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/sashimi/metrics.py b/s4/models/sashimi/metrics.py similarity index 100% rename from models/sashimi/metrics.py rename to s4/models/sashimi/metrics.py diff --git a/s4/models/sashimi/mturk/__init__.py b/s4/models/sashimi/mturk/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/sashimi/mturk/mos/MTurk SC09 MOS.ipynb b/s4/models/sashimi/mturk/mos/MTurk SC09 MOS.ipynb similarity index 100% rename from models/sashimi/mturk/mos/MTurk SC09 MOS.ipynb rename to s4/models/sashimi/mturk/mos/MTurk SC09 MOS.ipynb diff --git a/models/sashimi/mturk/mos/MTurk YouTubeMix MOS.ipynb b/s4/models/sashimi/mturk/mos/MTurk YouTubeMix MOS.ipynb similarity index 100% rename from models/sashimi/mturk/mos/MTurk YouTubeMix MOS.ipynb rename to s4/models/sashimi/mturk/mos/MTurk YouTubeMix MOS.ipynb diff --git a/s4/models/sashimi/mturk/mos/__init__.py b/s4/models/sashimi/mturk/mos/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/sashimi/mturk/prepare_sc09.py b/s4/models/sashimi/mturk/prepare_sc09.py similarity index 100% rename from models/sashimi/mturk/prepare_sc09.py rename to s4/models/sashimi/mturk/prepare_sc09.py diff --git a/models/sashimi/mturk/template_music.py b/s4/models/sashimi/mturk/template_music.py similarity index 100% rename from models/sashimi/mturk/template_music.py rename to s4/models/sashimi/mturk/template_music.py diff --git a/models/sashimi/mturk/template_speech.py b/s4/models/sashimi/mturk/template_speech.py similarity index 100% rename from models/sashimi/mturk/template_speech.py rename to s4/models/sashimi/mturk/template_speech.py diff --git a/models/sashimi/mturk/turk_create_batch.py b/s4/models/sashimi/mturk/turk_create_batch.py similarity index 100% rename from models/sashimi/mturk/turk_create_batch.py rename to s4/models/sashimi/mturk/turk_create_batch.py diff --git a/models/sashimi/sashimi.py b/s4/models/sashimi/sashimi.py similarity index 99% rename from models/sashimi/sashimi.py rename to s4/models/sashimi/sashimi.py index 9f337f1c..6ae20734 100644 --- a/models/sashimi/sashimi.py +++ b/s4/models/sashimi/sashimi.py @@ -16,7 +16,7 @@ from einops import rearrange -from models.s4.s4 import LinearActivation, S4Block as S4 +from s4.models.s4.s4 import LinearActivation, S4Block as S4 class DownPool(nn.Module): def __init__(self, d_input, expand, pool): diff --git a/models/sashimi/sc09_classifier/README.md b/s4/models/sashimi/sc09_classifier/README.md similarity index 100% rename from models/sashimi/sc09_classifier/README.md rename to s4/models/sashimi/sc09_classifier/README.md diff --git a/s4/models/sashimi/sc09_classifier/__init__.py b/s4/models/sashimi/sc09_classifier/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/s4/models/sashimi/sc09_classifier/datasets/__init__.py b/s4/models/sashimi/sc09_classifier/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/s4/models/sashimi/sc09_classifier/datasets/speech_commands/__init__.py b/s4/models/sashimi/sc09_classifier/datasets/speech_commands/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/sashimi/sc09_classifier/datasets/speech_commands/split_dataset.py b/s4/models/sashimi/sc09_classifier/datasets/speech_commands/split_dataset.py similarity index 100% rename from models/sashimi/sc09_classifier/datasets/speech_commands/split_dataset.py rename to s4/models/sashimi/sc09_classifier/datasets/speech_commands/split_dataset.py diff --git a/models/sashimi/sc09_classifier/download_speech_commands_dataset.sh b/s4/models/sashimi/sc09_classifier/download_speech_commands_dataset.sh similarity index 100% rename from models/sashimi/sc09_classifier/download_speech_commands_dataset.sh rename to s4/models/sashimi/sc09_classifier/download_speech_commands_dataset.sh diff --git a/s4/models/sashimi/sc09_classifier/models/__init__.py b/s4/models/sashimi/sc09_classifier/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/sashimi/sc09_classifier/models/resnext.py b/s4/models/sashimi/sc09_classifier/models/resnext.py similarity index 100% rename from models/sashimi/sc09_classifier/models/resnext.py rename to s4/models/sashimi/sc09_classifier/models/resnext.py diff --git a/models/sashimi/sc09_classifier/requirements.txt b/s4/models/sashimi/sc09_classifier/requirements.txt similarity index 100% rename from models/sashimi/sc09_classifier/requirements.txt rename to s4/models/sashimi/sc09_classifier/requirements.txt diff --git a/models/sashimi/sc09_classifier/speech_commands_dataset.py b/s4/models/sashimi/sc09_classifier/speech_commands_dataset.py similarity index 100% rename from models/sashimi/sc09_classifier/speech_commands_dataset.py rename to s4/models/sashimi/sc09_classifier/speech_commands_dataset.py diff --git a/models/sashimi/sc09_classifier/test_speech_commands.py b/s4/models/sashimi/sc09_classifier/test_speech_commands.py similarity index 100% rename from models/sashimi/sc09_classifier/test_speech_commands.py rename to s4/models/sashimi/sc09_classifier/test_speech_commands.py diff --git a/models/sashimi/sc09_classifier/train_speech_commands.py b/s4/models/sashimi/sc09_classifier/train_speech_commands.py similarity index 99% rename from models/sashimi/sc09_classifier/train_speech_commands.py rename to s4/models/sashimi/sc09_classifier/train_speech_commands.py index aac58adf..d8a8794e 100755 --- a/models/sashimi/sc09_classifier/train_speech_commands.py +++ b/s4/models/sashimi/sc09_classifier/train_speech_commands.py @@ -17,7 +17,7 @@ from torchvision.transforms import Compose from tqdm.auto import tqdm -from models.resnext import CifarResNeXt +from s4.models.resnext import CifarResNeXt from speech_commands_dataset import BackgroundNoiseDataset, CLASSES, SpeechCommandsDataset from transforms import ( AddBackgroundNoiseOnSTFT, diff --git a/models/sashimi/sc09_classifier/transforms/__init__.py b/s4/models/sashimi/sc09_classifier/transforms/__init__.py similarity index 100% rename from models/sashimi/sc09_classifier/transforms/__init__.py rename to s4/models/sashimi/sc09_classifier/transforms/__init__.py diff --git a/models/sashimi/sc09_classifier/transforms/transforms_stft.py b/s4/models/sashimi/sc09_classifier/transforms/transforms_stft.py similarity index 100% rename from models/sashimi/sc09_classifier/transforms/transforms_stft.py rename to s4/models/sashimi/sc09_classifier/transforms/transforms_stft.py diff --git a/models/sashimi/sc09_classifier/transforms/transforms_wav.py b/s4/models/sashimi/sc09_classifier/transforms/transforms_wav.py similarity index 100% rename from models/sashimi/sc09_classifier/transforms/transforms_wav.py rename to s4/models/sashimi/sc09_classifier/transforms/transforms_wav.py diff --git a/s4/src/__init__.py b/s4/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/callbacks/norms.py b/s4/src/callbacks/norms.py similarity index 100% rename from src/callbacks/norms.py rename to s4/src/callbacks/norms.py diff --git a/src/callbacks/params.py b/s4/src/callbacks/params.py similarity index 100% rename from src/callbacks/params.py rename to s4/src/callbacks/params.py diff --git a/src/callbacks/progressive_resizing.py b/s4/src/callbacks/progressive_resizing.py similarity index 98% rename from src/callbacks/progressive_resizing.py rename to s4/src/callbacks/progressive_resizing.py index a96b7130..910fdc28 100644 --- a/src/callbacks/progressive_resizing.py +++ b/s4/src/callbacks/progressive_resizing.py @@ -3,8 +3,8 @@ import numpy as np from pytorch_lightning.callbacks import Callback -import src.utils as utils -from src.utils import registry +import s4.src.utils as utils +from s4.src.utils import registry class ProgressiveResizing(Callback): diff --git a/src/callbacks/timer.py b/s4/src/callbacks/timer.py similarity index 100% rename from src/callbacks/timer.py rename to s4/src/callbacks/timer.py diff --git a/src/callbacks/wandb.py b/s4/src/callbacks/wandb.py similarity index 100% rename from src/callbacks/wandb.py rename to s4/src/callbacks/wandb.py diff --git a/src/dataloaders/README.md b/s4/src/dataloaders/README.md similarity index 100% rename from src/dataloaders/README.md rename to s4/src/dataloaders/README.md diff --git a/src/dataloaders/__init__.py b/s4/src/dataloaders/__init__.py similarity index 100% rename from src/dataloaders/__init__.py rename to s4/src/dataloaders/__init__.py diff --git a/src/dataloaders/audio.py b/s4/src/dataloaders/audio.py similarity index 98% rename from src/dataloaders/audio.py rename to s4/src/dataloaders/audio.py index 5eac760f..850ad2e1 100644 --- a/src/dataloaders/audio.py +++ b/s4/src/dataloaders/audio.py @@ -9,7 +9,7 @@ from torch import nn from torch.nn import functional as F -from src.dataloaders.base import default_data_path, SequenceDataset, deprecated +from s4.src.dataloaders.base import default_data_path, SequenceDataset, deprecated def minmax_scale(tensor, range_min=0, range_max=1): @@ -306,7 +306,7 @@ def init_defaults(self): } def setup(self): - from src.dataloaders.audio import QuantizedAudioDataset + from s4.src.dataloaders.audio import QuantizedAudioDataset assert self.path is not None or self.data_dir is not None, "Pass a path to a folder of audio: either `data_dir` for full directory or `path` for relative path." if self.data_dir is None: self.data_dir = default_data_path / self.path @@ -490,7 +490,7 @@ def init_defaults(self): } def setup(self): - from src.dataloaders.audio import SpeechCommands09 + from s4.src.dataloaders.audio import SpeechCommands09 self.data_dir = self.data_dir or default_data_path / self._name_ self.dataset_train = SpeechCommands09( @@ -617,7 +617,7 @@ def init_defaults(self): } def setup(self): - from src.dataloaders.audio import MaestroDataset + from s4.src.dataloaders.audio import MaestroDataset self.data_dir = self.data_dir or default_data_path / self._name_ / 'maestro-v3.0.0' self.dataset_train = MaestroDataset( @@ -765,7 +765,7 @@ def init_defaults(self): } def setup(self): - from src.dataloaders.audio import LJSpeech + from s4.src.dataloaders.audio import LJSpeech self.data_dir = self.data_dir or default_data_path / self._name_ / 'LJSpeech-1.1' / 'wavs' self.dataset_train = LJSpeech( @@ -893,7 +893,7 @@ def init_defaults(self): } def setup(self): - from src.dataloaders.audio import _SpeechCommands09Classification + from s4.src.dataloaders.audio import _SpeechCommands09Classification self.data_dir = self.data_dir or default_data_path / 'sc09' self.dataset_train = _SpeechCommands09Classification( @@ -965,7 +965,7 @@ def init(self): self.l_output = self.length def setup(self): - from src.dataloaders.datasets.sc import _SpeechCommandsGeneration + from s4.src.dataloaders.datasets.sc import _SpeechCommandsGeneration # TODO refactor with data_dir argument self.dataset_train = _SpeechCommandsGeneration( @@ -1039,7 +1039,7 @@ def init(self): return def setup(self): - from src.dataloaders.music import _Music + from s4.src.dataloaders.music import _Music self.music_class = _Music( path=default_data_path, diff --git a/src/dataloaders/base.py b/s4/src/dataloaders/base.py similarity index 99% rename from src/dataloaders/base.py rename to s4/src/dataloaders/base.py index fb1982ab..272b6e64 100644 --- a/src/dataloaders/base.py +++ b/s4/src/dataloaders/base.py @@ -11,7 +11,7 @@ import torchvision from einops import rearrange from einops.layers.torch import Rearrange -from src.utils import is_list, permutations +from s4.src.utils import is_list, permutations from torch.nn import functional as F def deprecated(cls_or_func): diff --git a/src/dataloaders/basic.py b/s4/src/dataloaders/basic.py similarity index 97% rename from src/dataloaders/basic.py rename to s4/src/dataloaders/basic.py index b07206d8..a3c2d4fb 100644 --- a/src/dataloaders/basic.py +++ b/s4/src/dataloaders/basic.py @@ -4,9 +4,9 @@ import torch import torchvision from einops.layers.torch import Rearrange -from src.utils import permutations +from s4.src.utils import permutations -from src.dataloaders.base import default_data_path, ImageResolutionSequenceDataset, ResolutionSequenceDataset, SequenceDataset +from s4.src.dataloaders.base import default_data_path, ImageResolutionSequenceDataset, ResolutionSequenceDataset, SequenceDataset class MNIST(SequenceDataset): @@ -238,7 +238,7 @@ def L(self): def setup(self): self.data_dir = self.data_dir or default_data_path # TODO make same logic as other classes - from src.dataloaders.datasets.sc import _SpeechCommands + from s4.src.dataloaders.datasets.sc import _SpeechCommands # TODO refactor with data_dir argument self.dataset_train = _SpeechCommands( diff --git a/src/dataloaders/datasets/adding.py b/s4/src/dataloaders/datasets/adding.py similarity index 100% rename from src/dataloaders/datasets/adding.py rename to s4/src/dataloaders/datasets/adding.py diff --git a/src/dataloaders/datasets/celeba.py b/s4/src/dataloaders/datasets/celeba.py similarity index 100% rename from src/dataloaders/datasets/celeba.py rename to s4/src/dataloaders/datasets/celeba.py diff --git a/src/dataloaders/datasets/copying.py b/s4/src/dataloaders/datasets/copying.py similarity index 99% rename from src/dataloaders/datasets/copying.py rename to s4/src/dataloaders/datasets/copying.py index 514cebf3..b718cb84 100644 --- a/src/dataloaders/datasets/copying.py +++ b/s4/src/dataloaders/datasets/copying.py @@ -9,7 +9,7 @@ import torch.nn.functional as F import numpy as np -from src.utils import distributed +from s4.src.utils import distributed def np_copying_data(L, M, A, batch_shape=()): diff --git a/src/dataloaders/datasets/delay.py b/s4/src/dataloaders/datasets/delay.py similarity index 96% rename from src/dataloaders/datasets/delay.py rename to s4/src/dataloaders/datasets/delay.py index 06d514a3..ea442421 100644 --- a/src/dataloaders/datasets/delay.py +++ b/s4/src/dataloaders/datasets/delay.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from src.dataloaders.utils.signal import whitesignal +from s4.src.dataloaders.utils.signal import whitesignal class DelayTrainDataset(torch.utils.data.Dataset): diff --git a/src/dataloaders/datasets/music.py b/s4/src/dataloaders/datasets/music.py similarity index 100% rename from src/dataloaders/datasets/music.py rename to s4/src/dataloaders/datasets/music.py diff --git a/src/dataloaders/datasets/reconstruct.py b/s4/src/dataloaders/datasets/reconstruct.py similarity index 95% rename from src/dataloaders/datasets/reconstruct.py rename to s4/src/dataloaders/datasets/reconstruct.py index 4b95c2a7..17a881a1 100644 --- a/src/dataloaders/datasets/reconstruct.py +++ b/s4/src/dataloaders/datasets/reconstruct.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from src.dataloaders.utils.signal import whitesignal +from s4.src.dataloaders.utils.signal import whitesignal class ReconstructTrainDataset(torch.utils.data.Dataset): diff --git a/src/dataloaders/datasets/sc.py b/s4/src/dataloaders/datasets/sc.py similarity index 100% rename from src/dataloaders/datasets/sc.py rename to s4/src/dataloaders/datasets/sc.py diff --git a/src/dataloaders/et.py b/s4/src/dataloaders/et.py similarity index 99% rename from src/dataloaders/et.py rename to s4/src/dataloaders/et.py index f7cc1691..a1d4f1d9 100644 --- a/src/dataloaders/et.py +++ b/s4/src/dataloaders/et.py @@ -17,7 +17,7 @@ import warnings warnings.filterwarnings("ignore") -from src.dataloaders.base import SequenceDataset, default_data_path +from s4.src.dataloaders.base import SequenceDataset, default_data_path class TimeFeature: diff --git a/src/dataloaders/lm.py b/s4/src/dataloaders/lm.py similarity index 98% rename from src/dataloaders/lm.py rename to s4/src/dataloaders/lm.py index 9fc40263..67aadc06 100644 --- a/src/dataloaders/lm.py +++ b/s4/src/dataloaders/lm.py @@ -26,14 +26,14 @@ import torch.nn.functional as F -from src.utils import distributed -import src.utils.train +from s4.src.utils import distributed +import s4.src.utils.train; import s4; src = s4.src log = src.utils.train.get_logger(__name__) -from src.dataloaders.base import SequenceDataset, default_data_path -from src.dataloaders.utils.vocabulary import OpenAIVocab, Vocab -import src.utils as utils +from s4.src.dataloaders.base import SequenceDataset, default_data_path +from s4.src.dataloaders.utils.vocabulary import OpenAIVocab, Vocab +import s4.src.utils as utils # TODO: create a package so we don't have to mess with sys.path? project_root = Path(__file__).parent.parent.absolute() diff --git a/src/dataloaders/lra.py b/s4/src/dataloaders/lra.py similarity index 99% rename from src/dataloaders/lra.py rename to s4/src/dataloaders/lra.py index 154203de..732b8015 100644 --- a/src/dataloaders/lra.py +++ b/s4/src/dataloaders/lra.py @@ -15,7 +15,7 @@ from PIL import Image # Only used for Pathfinder from datasets import DatasetDict, Value, load_dataset -from src.dataloaders.base import default_data_path, SequenceDataset, ImageResolutionSequenceDataset +from s4.src.dataloaders.base import default_data_path, SequenceDataset, ImageResolutionSequenceDataset class IMDB(SequenceDataset): diff --git a/src/dataloaders/prepare/bidmc/README.md b/s4/src/dataloaders/prepare/bidmc/README.md similarity index 100% rename from src/dataloaders/prepare/bidmc/README.md rename to s4/src/dataloaders/prepare/bidmc/README.md diff --git a/src/dataloaders/prepare/bidmc/data.ipynb b/s4/src/dataloaders/prepare/bidmc/data.ipynb similarity index 100% rename from src/dataloaders/prepare/bidmc/data.ipynb rename to s4/src/dataloaders/prepare/bidmc/data.ipynb diff --git a/src/dataloaders/prepare/bidmc/data_loader.py b/s4/src/dataloaders/prepare/bidmc/data_loader.py similarity index 100% rename from src/dataloaders/prepare/bidmc/data_loader.py rename to s4/src/dataloaders/prepare/bidmc/data_loader.py diff --git a/src/dataloaders/prepare/bidmc/process_data.py b/s4/src/dataloaders/prepare/bidmc/process_data.py similarity index 100% rename from src/dataloaders/prepare/bidmc/process_data.py rename to s4/src/dataloaders/prepare/bidmc/process_data.py diff --git a/src/dataloaders/synthetic.py b/s4/src/dataloaders/synthetic.py similarity index 98% rename from src/dataloaders/synthetic.py rename to s4/src/dataloaders/synthetic.py index 8aecc549..7df0b704 100644 --- a/src/dataloaders/synthetic.py +++ b/s4/src/dataloaders/synthetic.py @@ -4,9 +4,9 @@ import torch import torchvision from einops.layers.torch import Rearrange -from src.utils import permutations +from s4.src.utils import permutations -from src.dataloaders.base import SequenceDataset +from s4.src.dataloaders.base import SequenceDataset class Copying(SequenceDataset): diff --git a/src/dataloaders/ts.py b/s4/src/dataloaders/ts.py similarity index 99% rename from src/dataloaders/ts.py rename to s4/src/dataloaders/ts.py index 0caae1f5..55275f26 100644 --- a/src/dataloaders/ts.py +++ b/s4/src/dataloaders/ts.py @@ -6,7 +6,7 @@ from torch import nn from torch.nn import functional as F -from src.dataloaders.base import default_data_path, SequenceDataset, deprecated +from s4.src.dataloaders.base import default_data_path, SequenceDataset, deprecated class BIDMC(SequenceDataset): """BIDMC datasets for Respiratory Rate / Heart Rate / Oxygen Saturation regression""" @@ -95,7 +95,7 @@ def setup(self): assert self.sz_label_sensitivity <= self.clip_len - # from src.dataloaders.eegseizure import balance_dp, split_dp, merge_in_split + # from s4.src.dataloaders.eegseizure import balance_dp, split_dp, merge_in_split if self.machine == "gemini": data_dir = "/media/4tb_hdd" data_dir_tuh = "/media/nvme_data/siyitang/TUH_eeg_seq_v1.5.2" diff --git a/src/dataloaders/utils/cifar_augmentations.py b/s4/src/dataloaders/utils/cifar_augmentations.py similarity index 100% rename from src/dataloaders/utils/cifar_augmentations.py rename to s4/src/dataloaders/utils/cifar_augmentations.py diff --git a/src/dataloaders/utils/signal.py b/s4/src/dataloaders/utils/signal.py similarity index 100% rename from src/dataloaders/utils/signal.py rename to s4/src/dataloaders/utils/signal.py diff --git a/src/dataloaders/utils/timm_mixup.py b/s4/src/dataloaders/utils/timm_mixup.py similarity index 100% rename from src/dataloaders/utils/timm_mixup.py rename to s4/src/dataloaders/utils/timm_mixup.py diff --git a/src/dataloaders/utils/video_loader.py b/s4/src/dataloaders/utils/video_loader.py similarity index 100% rename from src/dataloaders/utils/video_loader.py rename to s4/src/dataloaders/utils/video_loader.py diff --git a/src/dataloaders/utils/vocabulary.py b/s4/src/dataloaders/utils/vocabulary.py similarity index 99% rename from src/dataloaders/utils/vocabulary.py rename to s4/src/dataloaders/utils/vocabulary.py index b2fa7bf0..3e6d281a 100644 --- a/src/dataloaders/utils/vocabulary.py +++ b/s4/src/dataloaders/utils/vocabulary.py @@ -19,7 +19,7 @@ import torch -import src.utils as utils +import s4.src.utils as utils class Vocab(object): diff --git a/src/dataloaders/vision.py b/s4/src/dataloaders/vision.py similarity index 99% rename from src/dataloaders/vision.py rename to s4/src/dataloaders/vision.py index 4b806079..5e949c22 100644 --- a/src/dataloaders/vision.py +++ b/s4/src/dataloaders/vision.py @@ -7,7 +7,7 @@ from torch.nn import functional as F import torchvision -from src.dataloaders.base import default_data_path, SequenceDataset +from s4.src.dataloaders.base import default_data_path, SequenceDataset class CIFAR100(SequenceDataset): _name_ = "cifar100" @@ -175,7 +175,7 @@ def d_input(self): return 3 def setup(self): - from src.dataloaders.datasets.cifarc import _CIFAR10C + from s4.src.dataloaders.datasets.cifarc import _CIFAR10C self.data_dir = self.data_dir or default_data_path / "CIFAR-10-C" # make sure self.corruptions was specified and is a valid choice @@ -943,7 +943,7 @@ class ImageNetP(ImageNet): def setup(self): from pl_bolts.transforms.dataset_normalizations import \ imagenet_normalization - from src.dataloaders.utils.video_loader import VideoFolder + from s4.src.dataloaders.utils.video_loader import VideoFolder from torch.utils.data.dataloader import default_collate self.imagenet_normalization = imagenet_normalization self.default_collate = default_collate diff --git a/src/models/README.md b/s4/src/models/README.md similarity index 100% rename from src/models/README.md rename to s4/src/models/README.md diff --git a/s4/src/models/__init__.py b/s4/src/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/models/baselines/ckconv.py b/s4/src/models/baselines/ckconv.py similarity index 100% rename from src/models/baselines/ckconv.py rename to s4/src/models/baselines/ckconv.py diff --git a/src/models/baselines/convnext_timm.py b/s4/src/models/baselines/convnext_timm.py similarity index 99% rename from src/models/baselines/convnext_timm.py rename to s4/src/models/baselines/convnext_timm.py index 677b4e88..af16d078 100644 --- a/src/models/baselines/convnext_timm.py +++ b/s4/src/models/baselines/convnext_timm.py @@ -29,9 +29,9 @@ from omegaconf import OmegaConf # S4 imports -import src.utils as utils -import src.utils.registry as registry -from src.models.nn import TransposedLinear +import s4.src.utils as utils +import s4.src.utils.registry as registry +from s4.src.models.nn import TransposedLinear __all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this diff --git a/src/models/baselines/gru.py b/s4/src/models/baselines/gru.py similarity index 94% rename from src/models/baselines/gru.py rename to s4/src/models/baselines/gru.py index 0022fea3..03089de9 100644 --- a/src/models/baselines/gru.py +++ b/s4/src/models/baselines/gru.py @@ -2,9 +2,9 @@ import torch from torch import nn -from src.models.sequence import SequenceModule, TransposedModule +from s4.src.models.sequence import SequenceModule, TransposedModule from einops import rearrange -import src.models.nn.utils as U +import s4.src.models.nn.utils as U @TransposedModule class TorchGRU(nn.GRU, SequenceModule): diff --git a/src/models/baselines/lipschitzrnn.py b/s4/src/models/baselines/lipschitzrnn.py similarity index 99% rename from src/models/baselines/lipschitzrnn.py rename to s4/src/models/baselines/lipschitzrnn.py index 8da47a7b..c26ec1d8 100644 --- a/src/models/baselines/lipschitzrnn.py +++ b/s4/src/models/baselines/lipschitzrnn.py @@ -8,7 +8,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable -from src.models.sequence.base import SequenceModule +from s4.src.models.sequence.base import SequenceModule from copy import deepcopy diff --git a/src/models/baselines/lstm.py b/s4/src/models/baselines/lstm.py similarity index 96% rename from src/models/baselines/lstm.py rename to s4/src/models/baselines/lstm.py index 7d23a0e2..d4bbd29a 100644 --- a/src/models/baselines/lstm.py +++ b/s4/src/models/baselines/lstm.py @@ -2,9 +2,9 @@ import torch from torch import nn -from src.models.sequence import SequenceModule, TransposedModule +from s4.src.models.sequence import SequenceModule, TransposedModule from einops import rearrange -import src.models.nn.utils as U +import s4.src.models.nn.utils as U @TransposedModule class TorchLSTM(nn.LSTM, SequenceModule): diff --git a/src/models/baselines/nonaka/LICENSE b/s4/src/models/baselines/nonaka/LICENSE similarity index 100% rename from src/models/baselines/nonaka/LICENSE rename to s4/src/models/baselines/nonaka/LICENSE diff --git a/src/models/baselines/nonaka/README.md b/s4/src/models/baselines/nonaka/README.md similarity index 100% rename from src/models/baselines/nonaka/README.md rename to s4/src/models/baselines/nonaka/README.md diff --git a/src/models/baselines/nonaka/basic_conv1d.py b/s4/src/models/baselines/nonaka/basic_conv1d.py similarity index 100% rename from src/models/baselines/nonaka/basic_conv1d.py rename to s4/src/models/baselines/nonaka/basic_conv1d.py diff --git a/src/models/baselines/nonaka/inception.py b/s4/src/models/baselines/nonaka/inception.py similarity index 97% rename from src/models/baselines/nonaka/inception.py rename to s4/src/models/baselines/nonaka/inception.py index 979f11a2..492799f7 100644 --- a/src/models/baselines/nonaka/inception.py +++ b/s4/src/models/baselines/nonaka/inception.py @@ -3,7 +3,7 @@ import torch.nn.functional as F import math -from src.models.baselines.nonaka.basic_conv1d import AdaptiveConcatPool1d, create_head1d +from s4.src.models.baselines.nonaka.basic_conv1d import AdaptiveConcatPool1d, create_head1d ######################################################################################################## # Inception time inspired by https://github.com/hfawaz/InceptionTime/blob/master/classifiers/inception.py and https://github.com/tcapelle/TimeSeries_fastai/blob/master/inception.py diff --git a/src/models/baselines/nonaka/resnet.py b/s4/src/models/baselines/nonaka/resnet.py similarity index 99% rename from src/models/baselines/nonaka/resnet.py rename to s4/src/models/baselines/nonaka/resnet.py index f35d631c..85e9f591 100644 --- a/src/models/baselines/nonaka/resnet.py +++ b/s4/src/models/baselines/nonaka/resnet.py @@ -3,7 +3,7 @@ import torch.nn.functional as F import math -from src.models.baselines.nonaka.basic_conv1d import create_head1d, Flatten +from s4.src.models.baselines.nonaka.basic_conv1d import create_head1d, Flatten ############################################################################################### # Standard resnet diff --git a/src/models/baselines/nonaka/xresnet.py b/s4/src/models/baselines/nonaka/xresnet.py similarity index 99% rename from src/models/baselines/nonaka/xresnet.py rename to s4/src/models/baselines/nonaka/xresnet.py index 96d319dc..b63888ec 100644 --- a/src/models/baselines/nonaka/xresnet.py +++ b/s4/src/models/baselines/nonaka/xresnet.py @@ -2,7 +2,7 @@ import torch.nn as nn import torch.nn.functional as F -from src.models.baselines.nonaka.basic_conv1d import create_head1d, Flatten +from s4.src.models.baselines.nonaka.basic_conv1d import create_head1d, Flatten from enum import Enum import re diff --git a/src/models/baselines/nrde.py b/s4/src/models/baselines/nrde.py similarity index 100% rename from src/models/baselines/nrde.py rename to s4/src/models/baselines/nrde.py diff --git a/src/models/baselines/odelstm.py b/s4/src/models/baselines/odelstm.py similarity index 100% rename from src/models/baselines/odelstm.py rename to s4/src/models/baselines/odelstm.py diff --git a/src/models/baselines/resnet.py b/s4/src/models/baselines/resnet.py similarity index 100% rename from src/models/baselines/resnet.py rename to s4/src/models/baselines/resnet.py diff --git a/src/models/baselines/resnet_timm.py b/s4/src/models/baselines/resnet_timm.py similarity index 100% rename from src/models/baselines/resnet_timm.py rename to s4/src/models/baselines/resnet_timm.py diff --git a/src/models/baselines/samplernn.py b/s4/src/models/baselines/samplernn.py similarity index 98% rename from src/models/baselines/samplernn.py rename to s4/src/models/baselines/samplernn.py index 859d828b..6b7ff546 100644 --- a/src/models/baselines/samplernn.py +++ b/s4/src/models/baselines/samplernn.py @@ -9,11 +9,11 @@ import math import numpy as np -from src.models.baselines.lstm import TorchLSTM -from src.models.baselines.gru import TorchGRU -from src.models.sequence.base import SequenceModule -from src.models.sequence.modules.s4block import S4Block -from src.dataloaders.audio import mu_law_decode, linear_decode, q_zero +from s4.src.models.baselines.lstm import TorchLSTM +from s4.src.models.baselines.gru import TorchGRU +from s4.src.models.sequence.base import SequenceModule +from s4.src.models.sequence.modules.s4block import S4Block +from s4.src.dataloaders.audio import mu_law_decode, linear_decode, q_zero class StackedRNN(SequenceModule): """ diff --git a/src/models/baselines/transformer.py b/s4/src/models/baselines/transformer.py similarity index 100% rename from src/models/baselines/transformer.py rename to s4/src/models/baselines/transformer.py diff --git a/src/models/baselines/unicornn.py b/s4/src/models/baselines/unicornn.py similarity index 99% rename from src/models/baselines/unicornn.py rename to s4/src/models/baselines/unicornn.py index ecdb9fab..667a96d7 100644 --- a/src/models/baselines/unicornn.py +++ b/s4/src/models/baselines/unicornn.py @@ -13,7 +13,7 @@ from torch.nn import Parameter from collections import namedtuple -from src.models.sequence.base import SequenceModule, TransposedModule +from s4.src.models.sequence.base import SequenceModule, TransposedModule try: from cupy.cuda import function diff --git a/src/models/baselines/vit.py b/s4/src/models/baselines/vit.py similarity index 98% rename from src/models/baselines/vit.py rename to s4/src/models/baselines/vit.py index 1a67d345..580136c9 100644 --- a/src/models/baselines/vit.py +++ b/s4/src/models/baselines/vit.py @@ -7,7 +7,7 @@ from einops import rearrange, repeat from einops.layers.torch import Rearrange -from src.models.sequence.base import SequenceModule +from s4.src.models.sequence.base import SequenceModule class Residual(nn.Module): def __init__(self, fn): diff --git a/src/models/baselines/vit_all.py b/s4/src/models/baselines/vit_all.py similarity index 98% rename from src/models/baselines/vit_all.py rename to s4/src/models/baselines/vit_all.py index 9cb6072a..3529478d 100644 --- a/src/models/baselines/vit_all.py +++ b/s4/src/models/baselines/vit_all.py @@ -17,10 +17,10 @@ from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg from timm.models.layers import PatchEmbed, Mlp, trunc_normal_, lecun_normal_ -from src.models.sequence.base import SequenceModule -from src.models.nn import Normalization -from src.models.sequence.backbones.block import SequenceResidualBlock -from src.utils.config import to_list, to_dict +from s4.src.models.sequence.base import SequenceModule +from s4.src.models.nn import Normalization +from s4.src.models.sequence.backbones.block import SequenceResidualBlock +from s4.src.utils.config import to_list, to_dict _logger = logging.getLogger(__name__) diff --git a/src/models/baselines/wavenet.py b/s4/src/models/baselines/wavenet.py similarity index 99% rename from src/models/baselines/wavenet.py rename to s4/src/models/baselines/wavenet.py index 38eb09dd..14fba6df 100644 --- a/src/models/baselines/wavenet.py +++ b/s4/src/models/baselines/wavenet.py @@ -10,7 +10,7 @@ from torch.autograd import Variable, Function import numpy as np -from src.models.sequence.base import SequenceModule +from s4.src.models.sequence.base import SequenceModule def mu_law_expansion(data, mu): s = np.sign(data) * (np.exp(np.abs(data) * np.log(mu + 1)) - 1) / mu diff --git a/src/models/functional/cauchy.py b/s4/src/models/functional/cauchy.py similarity index 100% rename from src/models/functional/cauchy.py rename to s4/src/models/functional/cauchy.py diff --git a/src/models/functional/krylov.py b/s4/src/models/functional/krylov.py similarity index 98% rename from src/models/functional/krylov.py rename to s4/src/models/functional/krylov.py index 498e38ca..b5ed5deb 100644 --- a/src/models/functional/krylov.py +++ b/s4/src/models/functional/krylov.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from einops import rearrange, repeat -from src.models.functional.toeplitz import causal_convolution +from s4.src.models.functional.toeplitz import causal_convolution def krylov_sequential(L, A, b, c=None): """Compute the krylov function naively by sequential powering. diff --git a/src/models/functional/toeplitz.py b/s4/src/models/functional/toeplitz.py similarity index 100% rename from src/models/functional/toeplitz.py rename to s4/src/models/functional/toeplitz.py diff --git a/src/models/functional/unroll.py b/s4/src/models/functional/unroll.py similarity index 98% rename from src/models/functional/unroll.py rename to s4/src/models/functional/unroll.py index 14712c92..8de0d8d9 100644 --- a/src/models/functional/unroll.py +++ b/s4/src/models/functional/unroll.py @@ -7,8 +7,8 @@ import numpy as np import math -from src.models.functional.toeplitz import triangular_toeplitz_multiply, triangular_toeplitz_multiply_padded -from src.utils.permutations import bitreversal_po2, bitreversal_permutation +from s4.src.models.functional.toeplitz import triangular_toeplitz_multiply, triangular_toeplitz_multiply_padded +from s4.src.utils.permutations import bitreversal_po2, bitreversal_permutation ### Utilities diff --git a/src/models/functional/vandermonde.py b/s4/src/models/functional/vandermonde.py similarity index 100% rename from src/models/functional/vandermonde.py rename to s4/src/models/functional/vandermonde.py diff --git a/src/models/hippo/hippo.py b/s4/src/models/hippo/hippo.py similarity index 100% rename from src/models/hippo/hippo.py rename to s4/src/models/hippo/hippo.py diff --git a/src/models/hippo/transition.py b/s4/src/models/hippo/transition.py similarity index 98% rename from src/models/hippo/transition.py rename to s4/src/models/hippo/transition.py index 1a81ef72..d7d9a13f 100644 --- a/src/models/hippo/transition.py +++ b/s4/src/models/hippo/transition.py @@ -10,17 +10,17 @@ from scipy import special as ss from einops import rearrange -from src.models.hippo.hippo import transition -from src.models.functional.toeplitz import causal_convolution, causal_convolution_inverse, construct_toeplitz +from s4.src.models.hippo.hippo import transition +from s4.src.models.functional.toeplitz import causal_convolution, causal_convolution_inverse, construct_toeplitz # TODO figure out if we actually need this try: - from extensions.legt.legt import legt_gbt_forward, legt_gbt_backward, legt_gbt_forward_t, legt_gbt_backward_t + from s4.extensions.legt.legt import legt_gbt_forward, legt_gbt_backward, legt_gbt_forward_t, legt_gbt_backward_t except: pass try: - from extensions.trid.trid import trid_gbt_forward, trid_gbt_backward, trid_solve + from s4.extensions.trid.trid import trid_gbt_forward, trid_gbt_backward, trid_solve except: pass # from pytorch_memlab import profile diff --git a/src/models/hippo/visualizations.py b/s4/src/models/hippo/visualizations.py similarity index 99% rename from src/models/hippo/visualizations.py rename to s4/src/models/hippo/visualizations.py index 10244f9a..25a65c7c 100644 --- a/src/models/hippo/visualizations.py +++ b/s4/src/models/hippo/visualizations.py @@ -18,7 +18,7 @@ from scipy import special as ss from einops import rearrange, repeat, reduce -import src.models.functional.unroll as unroll # Not necessary, can comment out and set fast=False in HiPPO modules +import s4.src.models.functional.unroll as unroll # Not necessary, can comment out and set fast=False in HiPPO modules import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation diff --git a/src/models/nn/__init__.py b/s4/src/models/nn/__init__.py similarity index 100% rename from src/models/nn/__init__.py rename to s4/src/models/nn/__init__.py diff --git a/src/models/nn/activation.py b/s4/src/models/nn/activation.py similarity index 100% rename from src/models/nn/activation.py rename to s4/src/models/nn/activation.py diff --git a/src/models/nn/adaptive_softmax.py b/s4/src/models/nn/adaptive_softmax.py similarity index 99% rename from src/models/nn/adaptive_softmax.py rename to s4/src/models/nn/adaptive_softmax.py index 2c5d13f8..223faf90 100644 --- a/src/models/nn/adaptive_softmax.py +++ b/s4/src/models/nn/adaptive_softmax.py @@ -20,7 +20,7 @@ import torch.nn as nn import torch.nn.functional as F -import src.models.nn.utils as U +import s4.src.models.nn.utils as U class OptionalParameterList(nn.ParameterList): diff --git a/src/models/nn/dropout.py b/s4/src/models/nn/dropout.py similarity index 100% rename from src/models/nn/dropout.py rename to s4/src/models/nn/dropout.py diff --git a/src/models/nn/dxt.py b/s4/src/models/nn/dxt.py similarity index 100% rename from src/models/nn/dxt.py rename to s4/src/models/nn/dxt.py diff --git a/src/models/nn/exprnn/README.md b/s4/src/models/nn/exprnn/README.md similarity index 100% rename from src/models/nn/exprnn/README.md rename to s4/src/models/nn/exprnn/README.md diff --git a/src/models/nn/exprnn/expm32.py b/s4/src/models/nn/exprnn/expm32.py similarity index 100% rename from src/models/nn/exprnn/expm32.py rename to s4/src/models/nn/exprnn/expm32.py diff --git a/src/models/nn/exprnn/initialization.py b/s4/src/models/nn/exprnn/initialization.py similarity index 100% rename from src/models/nn/exprnn/initialization.py rename to s4/src/models/nn/exprnn/initialization.py diff --git a/src/models/nn/exprnn/orthogonal.py b/s4/src/models/nn/exprnn/orthogonal.py similarity index 98% rename from src/models/nn/exprnn/orthogonal.py rename to s4/src/models/nn/exprnn/orthogonal.py index c92493e9..83ed9fd4 100644 --- a/src/models/nn/exprnn/orthogonal.py +++ b/s4/src/models/nn/exprnn/orthogonal.py @@ -4,7 +4,7 @@ import torch.nn as nn from .parametrization import Parametrization -from src.models.nn.activation import ModReLU +from s4.src.models.nn.activation import ModReLU class Orthogonal(Parametrization): diff --git a/src/models/nn/exprnn/parametrization.py b/s4/src/models/nn/exprnn/parametrization.py similarity index 100% rename from src/models/nn/exprnn/parametrization.py rename to s4/src/models/nn/exprnn/parametrization.py diff --git a/src/models/nn/exprnn/trivializations.py b/s4/src/models/nn/exprnn/trivializations.py similarity index 100% rename from src/models/nn/exprnn/trivializations.py rename to s4/src/models/nn/exprnn/trivializations.py diff --git a/src/models/nn/gate.py b/s4/src/models/nn/gate.py similarity index 100% rename from src/models/nn/gate.py rename to s4/src/models/nn/gate.py diff --git a/src/models/nn/initialization.py b/s4/src/models/nn/initialization.py similarity index 100% rename from src/models/nn/initialization.py rename to s4/src/models/nn/initialization.py diff --git a/src/models/nn/linear.py b/s4/src/models/nn/linear.py similarity index 98% rename from src/models/nn/linear.py rename to s4/src/models/nn/linear.py index 24172c43..6847ae4a 100644 --- a/src/models/nn/linear.py +++ b/s4/src/models/nn/linear.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from einops import rearrange -from src.models.nn.activation import Activation +from s4.src.models.nn.activation import Activation contract = torch.einsum diff --git a/src/models/nn/normalization.py b/s4/src/models/nn/normalization.py similarity index 100% rename from src/models/nn/normalization.py rename to s4/src/models/nn/normalization.py diff --git a/src/models/nn/orthogonal.py b/s4/src/models/nn/orthogonal.py similarity index 100% rename from src/models/nn/orthogonal.py rename to s4/src/models/nn/orthogonal.py diff --git a/src/models/nn/residual.py b/s4/src/models/nn/residual.py similarity index 100% rename from src/models/nn/residual.py rename to s4/src/models/nn/residual.py diff --git a/src/models/nn/utils.py b/s4/src/models/nn/utils.py similarity index 100% rename from src/models/nn/utils.py rename to s4/src/models/nn/utils.py diff --git a/src/models/s4/README.md b/s4/src/models/s4/README.md similarity index 100% rename from src/models/s4/README.md rename to s4/src/models/s4/README.md diff --git a/src/models/sequence/README.md b/s4/src/models/sequence/README.md similarity index 100% rename from src/models/sequence/README.md rename to s4/src/models/sequence/README.md diff --git a/src/models/sequence/__init__.py b/s4/src/models/sequence/__init__.py similarity index 100% rename from src/models/sequence/__init__.py rename to s4/src/models/sequence/__init__.py diff --git a/src/models/sequence/attention/linear.py b/s4/src/models/sequence/attention/linear.py similarity index 98% rename from src/models/sequence/attention/linear.py rename to s4/src/models/sequence/attention/linear.py index 16354cd3..63b2a440 100644 --- a/src/models/sequence/attention/linear.py +++ b/s4/src/models/sequence/attention/linear.py @@ -12,8 +12,8 @@ from fast_transformers.feature_maps import elu_feature_map from fast_transformers.masking import TriangularCausalMask -from models.sequence.base import SequenceModule, TransposedModule -import src.models.nn.utils as U +from s4.models.sequence.base import SequenceModule, TransposedModule +import s4.src.models.nn.utils as U try: from apex import amp diff --git a/src/models/sequence/attention/mha.py b/s4/src/models/sequence/attention/mha.py similarity index 97% rename from src/models/sequence/attention/mha.py rename to s4/src/models/sequence/attention/mha.py index a12d2791..a645449e 100644 --- a/src/models/sequence/attention/mha.py +++ b/s4/src/models/sequence/attention/mha.py @@ -4,8 +4,8 @@ import torch.nn.functional as F from torch import nn import hydra -from models.sequence.base import SequenceModule, TransposedModule -import src.models.nn.utils as U +from s4.models.sequence.base import SequenceModule, TransposedModule +import s4.src.models.nn.utils as U from einops import rearrange @TransposedModule diff --git a/src/models/sequence/attention/performer.py b/s4/src/models/sequence/attention/performer.py similarity index 100% rename from src/models/sequence/attention/performer.py rename to s4/src/models/sequence/attention/performer.py diff --git a/s4/src/models/sequence/backbones/__init__.py b/s4/src/models/sequence/backbones/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/models/sequence/backbones/block.py b/s4/src/models/sequence/backbones/block.py similarity index 93% rename from src/models/sequence/backbones/block.py rename to s4/src/models/sequence/backbones/block.py index a6f9c48c..5bbb9492 100644 --- a/src/models/sequence/backbones/block.py +++ b/s4/src/models/sequence/backbones/block.py @@ -11,12 +11,12 @@ import torch from torch import nn -from src.models.nn import Normalization, StochasticDepth, DropoutNd -from src.models.sequence import SequenceModule -from src.models.sequence.modules.pool import registry as pool_registry -from src.models.nn.residual import registry as residual_registry -import src.utils as utils -import src.utils.registry as registry +from s4.src.models.nn import Normalization, StochasticDepth, DropoutNd +from s4.src.models.sequence import SequenceModule +from s4.src.models.sequence.modules.pool import registry as pool_registry +from s4.src.models.nn.residual import registry as residual_registry +import s4.src.utils as utils +import s4.src.utils.registry as registry class SequenceResidualBlock(SequenceModule): diff --git a/src/models/sequence/backbones/model.py b/s4/src/models/sequence/backbones/model.py similarity index 96% rename from src/models/sequence/backbones/model.py rename to s4/src/models/sequence/backbones/model.py index 109494fd..906e96d1 100644 --- a/src/models/sequence/backbones/model.py +++ b/s4/src/models/sequence/backbones/model.py @@ -10,10 +10,10 @@ import torch.nn as nn from einops import rearrange -from src.utils.config import to_list, to_dict -from src.models.sequence.backbones.block import SequenceResidualBlock -from src.models.sequence.base import SequenceModule -from src.models.nn import Normalization, DropoutNd +from s4.src.utils.config import to_list, to_dict +from s4.src.models.sequence.backbones.block import SequenceResidualBlock +from s4.src.models.sequence.base import SequenceModule +from s4.src.models.nn import Normalization, DropoutNd class SequenceModel(SequenceModule): diff --git a/src/models/sequence/backbones/sashimi.py b/s4/src/models/sequence/backbones/sashimi.py similarity index 97% rename from src/models/sequence/backbones/sashimi.py rename to s4/src/models/sequence/backbones/sashimi.py index ff0cb954..2184d644 100644 --- a/src/models/sequence/backbones/sashimi.py +++ b/s4/src/models/sequence/backbones/sashimi.py @@ -3,9 +3,9 @@ import torch.nn as nn import torch.nn.functional as F -from src.models.sequence.base import SequenceModule -from src.models.sequence.modules.pool import DownPool, UpPool -from src.models.sequence.backbones.block import SequenceResidualBlock +from s4.src.models.sequence.base import SequenceModule +from s4.src.models.sequence.modules.pool import DownPool, UpPool +from s4.src.models.sequence.backbones.block import SequenceResidualBlock class Sashimi(SequenceModule): diff --git a/src/models/sequence/backbones/unet.py b/s4/src/models/sequence/backbones/unet.py similarity index 96% rename from src/models/sequence/backbones/unet.py rename to s4/src/models/sequence/backbones/unet.py index bdefb27c..8d9e4ee2 100644 --- a/src/models/sequence/backbones/unet.py +++ b/s4/src/models/sequence/backbones/unet.py @@ -11,10 +11,10 @@ from omegaconf import DictConfig from einops import rearrange, repeat, reduce -import src.utils as utils -from src.models.sequence.base import SequenceModule -from src.models.sequence.modules.pool import DownPool, UpPool, up_registry, registry as down_registry -from src.models.sequence.backbones.block import SequenceResidualBlock +import s4.src.utils as utils +from s4.src.models.sequence.base import SequenceModule +from s4.src.models.sequence.modules.pool import DownPool, UpPool, up_registry, registry as down_registry +from s4.src.models.sequence.backbones.block import SequenceResidualBlock contract = torch.einsum diff --git a/src/models/sequence/base.py b/s4/src/models/sequence/base.py similarity index 100% rename from src/models/sequence/base.py rename to s4/src/models/sequence/base.py diff --git a/src/models/sequence/convs/conv1d.py b/s4/src/models/sequence/convs/conv1d.py similarity index 90% rename from src/models/sequence/convs/conv1d.py rename to s4/src/models/sequence/convs/conv1d.py index a9d42eaa..42250693 100644 --- a/src/models/sequence/convs/conv1d.py +++ b/s4/src/models/sequence/convs/conv1d.py @@ -4,11 +4,11 @@ import torch.nn.functional as F from torch import nn import hydra -from models.sequence.base import SequenceModule +from s4.models.sequence.base import SequenceModule from einops import rearrange -import src.models.nn.utils as U -from src.models.nn import Activation +import s4.src.models.nn.utils as U +from s4.src.models.nn import Activation class Conv1d(SequenceModule): """ Simple wrapper for nn.Conv1d """ diff --git a/src/models/sequence/convs/conv2d.py b/s4/src/models/sequence/convs/conv2d.py similarity index 94% rename from src/models/sequence/convs/conv2d.py rename to s4/src/models/sequence/convs/conv2d.py index 80737aab..8cf3cfd5 100644 --- a/src/models/sequence/convs/conv2d.py +++ b/s4/src/models/sequence/convs/conv2d.py @@ -3,8 +3,8 @@ import torch from torch import nn -from src.models.sequence.base import SequenceModule -from src.models.nn import Activation, DropoutNd +from s4.src.models.sequence.base import SequenceModule +from s4.src.models.nn import Activation, DropoutNd class Conv2d(SequenceModule): """ Simple wrapper for nn.Conv1d """ diff --git a/src/models/sequence/kernels/__init__.py b/s4/src/models/sequence/kernels/__init__.py similarity index 100% rename from src/models/sequence/kernels/__init__.py rename to s4/src/models/sequence/kernels/__init__.py diff --git a/src/models/sequence/kernels/dplr.py b/s4/src/models/sequence/kernels/dplr.py similarity index 98% rename from src/models/sequence/kernels/dplr.py rename to s4/src/models/sequence/kernels/dplr.py index 2f28aa5d..b9d49372 100644 --- a/src/models/sequence/kernels/dplr.py +++ b/s4/src/models/sequence/kernels/dplr.py @@ -4,9 +4,9 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat -import src.models.hippo.hippo as hippo +import s4.src.models.hippo.hippo as hippo -import src.utils.train +import s4.src.utils.train; import s4; src = s4.src log = src.utils.train.get_logger(__name__) def dplr( diff --git a/src/models/sequence/kernels/fftconv.py b/s4/src/models/sequence/kernels/fftconv.py similarity index 97% rename from src/models/sequence/kernels/fftconv.py rename to s4/src/models/sequence/kernels/fftconv.py index 6396f96c..809b8788 100644 --- a/src/models/sequence/kernels/fftconv.py +++ b/s4/src/models/sequence/kernels/fftconv.py @@ -5,9 +5,9 @@ import torch.nn.functional as F from einops import rearrange, repeat -from src.models.sequence import SequenceModule -from src.models.sequence.kernels import registry as kernel_registry -from src.models.nn import Activation, DropoutNd +from s4.src.models.sequence import SequenceModule +from s4.src.models.sequence.kernels import registry as kernel_registry +from s4.src.models.nn import Activation, DropoutNd contract = torch.einsum diff --git a/src/models/sequence/kernels/kernel.py b/s4/src/models/sequence/kernels/kernel.py similarity index 99% rename from src/models/sequence/kernels/kernel.py rename to s4/src/models/sequence/kernels/kernel.py index 52eaba5b..187bae17 100644 --- a/src/models/sequence/kernels/kernel.py +++ b/s4/src/models/sequence/kernels/kernel.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn -import src.utils.train +import s4.src.utils.train; import s4; src = s4.src log = src.utils.train.get_logger(__name__) diff --git a/src/models/sequence/kernels/ssm.py b/s4/src/models/sequence/kernels/ssm.py similarity index 98% rename from src/models/sequence/kernels/ssm.py rename to s4/src/models/sequence/kernels/ssm.py index 269ccb03..f0bc2c72 100644 --- a/src/models/sequence/kernels/ssm.py +++ b/s4/src/models/sequence/kernels/ssm.py @@ -20,17 +20,17 @@ import numpy as np from einops import rearrange, repeat -import src.models.hippo.hippo as hippo -import src.models.sequence.kernels.dplr as dplr -from src.models.functional.krylov import krylov, power -import src.utils.train +import s4.src.models.hippo.hippo as hippo +import s4.src.models.sequence.kernels.dplr as dplr +from s4.src.models.functional.krylov import krylov, power +import s4.src.utils.train; import s4; src = s4.src log = src.utils.train.get_logger(__name__) # Try CUDA extension try: - from extensions.kernels.cauchy import cauchy_mult as cauchy_cuda - from extensions.kernels.vandermonde import log_vandermonde_cuda + from s4.extensions.kernels.cauchy import cauchy_mult as cauchy_cuda + from s4.extensions.kernels.vandermonde import log_vandermonde_cuda has_cuda_extension = True log.info("CUDA extension for structured kernels (Cauchy and Vandermonde multiplication) found.") except: @@ -41,8 +41,8 @@ try: import pykeops - from src.models.functional.cauchy import cauchy_conj as cauchy_keops - from src.models.functional.vandermonde import log_vandermonde as log_vandermonde_keops, log_vandermonde_transpose as log_vandermonde_transpose_keops + from s4.src.models.functional.cauchy import cauchy_conj as cauchy_keops + from s4.src.models.functional.vandermonde import log_vandermonde as log_vandermonde_keops, log_vandermonde_transpose as log_vandermonde_transpose_keops has_pykeops = True log.info("Pykeops installation found.") @@ -54,12 +54,12 @@ ) # Fallback versions -from src.models.functional.cauchy import cauchy_naive -from src.models.functional.vandermonde import log_vandermonde_naive -from src.models.functional.vandermonde import log_vandermonde_transpose_naive +from s4.src.models.functional.cauchy import cauchy_naive +from s4.src.models.functional.vandermonde import log_vandermonde_naive +from s4.src.models.functional.vandermonde import log_vandermonde_transpose_naive # Base Kernel class -from src.models.sequence.kernels.kernel import Kernel +from s4.src.models.sequence.kernels.kernel import Kernel # Alias torch.einsum; can easily swap to opt_einsum if desired contract = torch.einsum diff --git a/s4/src/models/sequence/modules/__init__.py b/s4/src/models/sequence/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/models/sequence/modules/ffn.py b/s4/src/models/sequence/modules/ffn.py similarity index 93% rename from src/models/sequence/modules/ffn.py rename to s4/src/models/sequence/modules/ffn.py index 5c4b74fd..57537d1c 100644 --- a/src/models/sequence/modules/ffn.py +++ b/s4/src/models/sequence/modules/ffn.py @@ -2,8 +2,8 @@ from functools import partial from torch import nn -from src.models.sequence.base import SequenceModule -from src.models.nn import LinearActivation, DropoutNd +from s4.src.models.sequence.base import SequenceModule +from s4.src.models.nn import LinearActivation, DropoutNd class FFN(SequenceModule): def __init__( diff --git a/src/models/sequence/modules/lssl.py b/s4/src/models/sequence/modules/lssl.py similarity index 97% rename from src/models/sequence/modules/lssl.py rename to s4/src/models/sequence/modules/lssl.py index 6659fd7c..307f399a 100644 --- a/src/models/sequence/modules/lssl.py +++ b/s4/src/models/sequence/modules/lssl.py @@ -7,12 +7,12 @@ from einops import rearrange, repeat from omegaconf import DictConfig -from src.models.nn import Activation -from src.models.functional.krylov import krylov -from src.models.hippo import transition, hippo -from src.models.functional.toeplitz import causal_convolution -from src.models.sequence.base import SequenceModule, TransposedModule -import src.models.nn.utils as U +from s4.src.models.nn import Activation +from s4.src.models.functional.krylov import krylov +from s4.src.models.hippo import transition, hippo +from s4.src.models.functional.toeplitz import causal_convolution +from s4.src.models.sequence.base import SequenceModule, TransposedModule +import s4.src.models.nn.utils as U def linear_system_from_krylov(u, C, D, k): """ diff --git a/src/models/sequence/modules/megablock.py b/s4/src/models/sequence/modules/megablock.py similarity index 99% rename from src/models/sequence/modules/megablock.py rename to s4/src/models/sequence/modules/megablock.py index 1b381670..481a6dec 100644 --- a/src/models/sequence/modules/megablock.py +++ b/s4/src/models/sequence/modules/megablock.py @@ -15,9 +15,9 @@ import torch.nn.functional as F from einops import rearrange -from src.models.nn import Activation, DropoutNd, Normalization -from src.models.sequence.backbones.block import SequenceResidualBlock -from src.models.sequence.kernels.fftconv import FFTConv +from s4.src.models.nn import Activation, DropoutNd, Normalization +from s4.src.models.sequence.backbones.block import SequenceResidualBlock +from s4.src.models.sequence.kernels.fftconv import FFTConv class MegaBlock(nn.Module): diff --git a/src/models/sequence/modules/pool.py b/s4/src/models/sequence/modules/pool.py similarity index 99% rename from src/models/sequence/modules/pool.py rename to s4/src/models/sequence/modules/pool.py index 31a13595..ca47de7e 100644 --- a/src/models/sequence/modules/pool.py +++ b/s4/src/models/sequence/modules/pool.py @@ -5,8 +5,8 @@ import torch.nn.functional as F from einops import rearrange, repeat, reduce -from src.models.sequence import SequenceModule -from src.models.nn import LinearActivation +from s4.src.models.sequence import SequenceModule +from s4.src.models.nn import LinearActivation """The following pooling modules all subscribe to the same interface. diff --git a/src/models/sequence/modules/s4block.py b/s4/src/models/sequence/modules/s4block.py similarity index 96% rename from src/models/sequence/modules/s4block.py rename to s4/src/models/sequence/modules/s4block.py index 00ce7fde..7747a8db 100644 --- a/src/models/sequence/modules/s4block.py +++ b/s4/src/models/sequence/modules/s4block.py @@ -7,13 +7,13 @@ from functools import partial from einops import rearrange, repeat -from src.models.nn import LinearActivation, Activation, DropoutNd -from src.models.sequence.base import SequenceModule -from src.models.sequence.kernels.fftconv import FFTConv -import src.utils as utils -import src.utils.registry as registry +from s4.src.models.nn import LinearActivation, Activation, DropoutNd +from s4.src.models.sequence.base import SequenceModule +from s4.src.models.sequence.kernels.fftconv import FFTConv +import s4.src.utils as utils +import s4.src.utils.registry as registry -import src.utils.train +import s4.src.utils.train; import s4; src = s4.src log = src.utils.train.get_logger(__name__) contract = torch.einsum diff --git a/src/models/sequence/modules/s4nd.py b/s4/src/models/sequence/modules/s4nd.py similarity index 97% rename from src/models/sequence/modules/s4nd.py rename to s4/src/models/sequence/modules/s4nd.py index d4aaebd9..76dee2d2 100644 --- a/src/models/sequence/modules/s4nd.py +++ b/s4/src/models/sequence/modules/s4nd.py @@ -6,11 +6,11 @@ import torch.nn.functional as F from einops import rearrange, repeat, reduce -from src.models.sequence import SequenceModule -from src.models.sequence.kernels import registry as kernel_registry -from src.models.nn import LinearActivation, Activation, DropoutNd -import src.utils.train -import src.utils as utils +from s4.src.models.sequence import SequenceModule +from s4.src.models.sequence.kernels import registry as kernel_registry +from s4.src.models.nn import LinearActivation, Activation, DropoutNd +import s4.src.utils.train; import s4; src = s4.src +import s4.src.utils as utils log = src.utils.train.get_logger(__name__) @@ -80,7 +80,7 @@ def __init__( super().__init__() if verbose: - import src.utils.train + import s4; src = s4.src log = src.utils.train.get_logger(__name__) log.info(f"Constructing S4ND (H, N, L) = ({d_model}, {d_state}, {l_max})") diff --git a/src/models/sequence/rnns/__init__.py b/s4/src/models/sequence/rnns/__init__.py similarity index 100% rename from src/models/sequence/rnns/__init__.py rename to s4/src/models/sequence/rnns/__init__.py diff --git a/src/models/sequence/rnns/cells/__init__.py b/s4/src/models/sequence/rnns/cells/__init__.py similarity index 100% rename from src/models/sequence/rnns/cells/__init__.py rename to s4/src/models/sequence/rnns/cells/__init__.py diff --git a/src/models/sequence/rnns/cells/basic.py b/s4/src/models/sequence/rnns/cells/basic.py similarity index 96% rename from src/models/sequence/rnns/cells/basic.py rename to s4/src/models/sequence/rnns/cells/basic.py index fba3746c..7a098528 100644 --- a/src/models/sequence/rnns/cells/basic.py +++ b/s4/src/models/sequence/rnns/cells/basic.py @@ -4,10 +4,10 @@ import torch.nn as nn import torch.nn.functional as F -from src.models.nn import LinearActivation, Activation # , get_initializer -from src.models.nn.gate import Gate -from src.models.nn.orthogonal import OrthogonalLinear -from src.models.sequence.base import SequenceModule +from s4.src.models.nn import LinearActivation, Activation # , get_initializer +from s4.src.models.nn.gate import Gate +from s4.src.models.nn.orthogonal import OrthogonalLinear +from s4.src.models.sequence.base import SequenceModule class CellBase(SequenceModule): """Abstract class for our recurrent cell interface. diff --git a/src/models/sequence/rnns/cells/hippo.py b/s4/src/models/sequence/rnns/cells/hippo.py similarity index 96% rename from src/models/sequence/rnns/cells/hippo.py rename to s4/src/models/sequence/rnns/cells/hippo.py index 7b05f3bc..274b65c1 100644 --- a/src/models/sequence/rnns/cells/hippo.py +++ b/s4/src/models/sequence/rnns/cells/hippo.py @@ -5,8 +5,8 @@ from torch.nn import functional as F import numpy as np -from src.models.sequence.rnns.cells.memory import LTICell, LSICell -from src.models.hippo.hippo import transition +from s4.src.models.sequence.rnns.cells.memory import LTICell, LSICell +from s4.src.models.hippo.hippo import transition class HiPPOLTICell(LTICell): diff --git a/src/models/sequence/rnns/cells/memory.py b/s4/src/models/sequence/rnns/cells/memory.py similarity index 98% rename from src/models/sequence/rnns/cells/memory.py rename to s4/src/models/sequence/rnns/cells/memory.py index dbb579c2..d0bda977 100644 --- a/src/models/sequence/rnns/cells/memory.py +++ b/s4/src/models/sequence/rnns/cells/memory.py @@ -8,9 +8,9 @@ from scipy import signal from scipy import linalg as la -from src.models.sequence.rnns.cells.basic import RNNCell -from src.models.nn import LinearActivation, Activation # , get_initializer -from src.models.nn.gate import Gate +from s4.src.models.sequence.rnns.cells.basic import RNNCell +from s4.src.models.nn import LinearActivation, Activation # , get_initializer +from s4.src.models.nn.gate import Gate forward_aliases = ['euler', 'forward_euler', 'forward', 'forward_diff'] diff --git a/src/models/sequence/rnns/cells/minimalrnn.py b/s4/src/models/sequence/rnns/cells/minimalrnn.py similarity index 92% rename from src/models/sequence/rnns/cells/minimalrnn.py rename to s4/src/models/sequence/rnns/cells/minimalrnn.py index e6c81620..0db7814e 100644 --- a/src/models/sequence/rnns/cells/minimalrnn.py +++ b/s4/src/models/sequence/rnns/cells/minimalrnn.py @@ -5,9 +5,9 @@ [21-10-22] I believe this has not been tested in awhile but should work with minimal modifications """ -from src.models.sequence.rnns.cells.basic import CellBase -from src.models.nn import LinearActivation -from src.models.nn.gate import Gate +from s4.src.models.sequence.rnns.cells.basic import CellBase +from s4.src.models.nn import LinearActivation +from s4.src.models.nn.gate import Gate class MinimalRNNCell(CellBase): name = 'mrnn' diff --git a/src/models/sequence/rnns/cells/timestamp.py b/s4/src/models/sequence/rnns/cells/timestamp.py similarity index 96% rename from src/models/sequence/rnns/cells/timestamp.py rename to s4/src/models/sequence/rnns/cells/timestamp.py index badc6a74..9314af02 100644 --- a/src/models/sequence/rnns/cells/timestamp.py +++ b/s4/src/models/sequence/rnns/cells/timestamp.py @@ -5,8 +5,8 @@ import torch.nn.functional as F from functools import partial -from src.models.sequence.rnns.cells.memory import MemoryCell, forward_aliases, backward_aliases, bilinear_aliases, zoh_aliases -from src.models.hippo.transition import ( +from s4.src.models.sequence.rnns.cells.memory import MemoryCell, forward_aliases, backward_aliases, bilinear_aliases, zoh_aliases +from s4.src.models.hippo.transition import ( LegSAdaptiveTransitionManual, LegTAdaptiveTransitionManual, LagTAdaptiveTransitionManual, diff --git a/src/models/sequence/rnns/qrnn.py b/s4/src/models/sequence/rnns/qrnn.py similarity index 95% rename from src/models/sequence/rnns/qrnn.py rename to s4/src/models/sequence/rnns/qrnn.py index 32fb3629..326f9c0a 100644 --- a/src/models/sequence/rnns/qrnn.py +++ b/s4/src/models/sequence/rnns/qrnn.py @@ -9,11 +9,11 @@ import numpy as np from scipy import signal -from src.models.nn import LinearActivation -from src.models.functional import unroll -from src.models.hippo.hippo import transition -from src.models.hippo.transition import TLagTAdaptiveTransitionManual, LagTAdaptiveTransitionManual, LegTAdaptiveTransitionManual, LegSAdaptiveTransitionManual, LagTCumsumAdaptiveTransition, TLagTCumsumAdaptiveTransition -from src.models.sequence.base import SequenceModule +from s4.src.models.nn import LinearActivation +from s4.src.models.functional import unroll +from s4.src.models.hippo.hippo import transition +from s4.src.models.hippo.transition import TLagTAdaptiveTransitionManual, LagTAdaptiveTransitionManual, LegTAdaptiveTransitionManual, LegSAdaptiveTransitionManual, LagTCumsumAdaptiveTransition, TLagTCumsumAdaptiveTransition +from s4.src.models.sequence.base import SequenceModule class MemoryProjection(nn.Module): """Implements the memory projection operator for fixed dt.""" diff --git a/src/models/sequence/rnns/rnn.py b/s4/src/models/sequence/rnns/rnn.py similarity index 97% rename from src/models/sequence/rnns/rnn.py rename to s4/src/models/sequence/rnns/rnn.py index fee70bd9..ce50536a 100644 --- a/src/models/sequence/rnns/rnn.py +++ b/s4/src/models/sequence/rnns/rnn.py @@ -1,9 +1,9 @@ import torch import torch.nn as nn -import src.utils as utils -from src.models.sequence.rnns.cells import CellBase -from src.models.sequence import SequenceModule +import s4.src.utils as utils +from s4.src.models.sequence.rnns.cells import CellBase +from s4.src.models.sequence import SequenceModule # [21-09-12 AG]: We previously set up a way to register RNNCell classes, which gives them a "local" name # To convert this mapping from name to constructor, we use the fact that the str representation of a constructor is "" diff --git a/src/models/sequence/rnns/sru.py b/s4/src/models/sequence/rnns/sru.py similarity index 95% rename from src/models/sequence/rnns/sru.py rename to s4/src/models/sequence/rnns/sru.py index 88600e34..20df115b 100644 --- a/src/models/sequence/rnns/sru.py +++ b/s4/src/models/sequence/rnns/sru.py @@ -8,10 +8,10 @@ import torch.nn.functional as F from einops import rearrange -from src.models.sequence.rnns.cells import CellBase -from src.models.nn import LinearActivation -import src.models.nn.utils as U -from src.models.sequence.base import SequenceModule, TransposedModule +from s4.src.models.sequence.rnns.cells import CellBase +from s4.src.models.nn import LinearActivation +import s4.src.models.nn.utils as U +from s4.src.models.sequence.base import SequenceModule, TransposedModule class SRUCell(CellBase): """Implementation of the pure SRU cell that works with the models.rnns.rnn.RNN class.""" diff --git a/src/tasks/decoders.py b/s4/src/tasks/decoders.py similarity index 99% rename from src/tasks/decoders.py rename to s4/src/tasks/decoders.py index 71404d5b..fc2d179f 100644 --- a/src/tasks/decoders.py +++ b/s4/src/tasks/decoders.py @@ -5,8 +5,8 @@ import torch.nn.functional as F from einops import rearrange, reduce -import src.models.nn.utils as U -import src.utils as utils +import s4.src.models.nn.utils as U +import s4.src.utils as utils class Decoder(nn.Module): diff --git a/src/tasks/encoders.py b/s4/src/tasks/encoders.py similarity index 98% rename from src/tasks/encoders.py rename to s4/src/tasks/encoders.py index 1edb4fdc..4aece371 100644 --- a/src/tasks/encoders.py +++ b/s4/src/tasks/encoders.py @@ -9,11 +9,11 @@ import torch.nn.functional as F from einops import rearrange -import src.models.nn.utils as U -import src.utils as utils -import src.utils.config -from src.models.sequence.backbones.block import SequenceResidualBlock -from src.models.nn import Normalization +import s4.src.models.nn.utils as U +import s4.src.utils as utils +import s4.src.utils.train; import s4; src = s4.src +from s4.src.models.sequence.backbones.block import SequenceResidualBlock +from s4.src.models.nn import Normalization class Encoder(nn.Module): """Encoder abstraction. diff --git a/src/tasks/metrics.py b/s4/src/tasks/metrics.py similarity index 100% rename from src/tasks/metrics.py rename to s4/src/tasks/metrics.py diff --git a/src/tasks/tasks.py b/s4/src/tasks/tasks.py similarity index 98% rename from src/tasks/tasks.py rename to s4/src/tasks/tasks.py index fefcb448..16c8e087 100644 --- a/src/tasks/tasks.py +++ b/s4/src/tasks/tasks.py @@ -9,18 +9,18 @@ import torch.nn.functional as F from einops import rearrange from omegaconf import ListConfig -from src.models.nn.normalization import ( +from s4.src.models.nn.normalization import ( ReversibleInstanceNorm1dInput, ReversibleInstanceNorm1dOutput, TSNormalization, TSInverseNormalization, ) -from src.models.nn.adaptive_softmax import AdaptiveEmbedding, ProjectedAdaptiveLogSoftmax -import src.tasks.metrics as M -import src.models.nn.utils as U +from s4.src.models.nn.adaptive_softmax import AdaptiveEmbedding, ProjectedAdaptiveLogSoftmax +import s4.src.tasks.metrics as M +import s4.src.models.nn.utils as U import torchmetrics as tm -from src.utils.config import to_list, instantiate +from s4.src.utils.config import to_list, instantiate class BaseTask: diff --git a/src/utils/__init__.py b/s4/src/utils/__init__.py similarity index 100% rename from src/utils/__init__.py rename to s4/src/utils/__init__.py diff --git a/src/utils/config.py b/s4/src/utils/config.py similarity index 100% rename from src/utils/config.py rename to s4/src/utils/config.py diff --git a/src/utils/distributed.py b/s4/src/utils/distributed.py similarity index 100% rename from src/utils/distributed.py rename to s4/src/utils/distributed.py diff --git a/src/utils/optim/ema.py b/s4/src/utils/optim/ema.py similarity index 100% rename from src/utils/optim/ema.py rename to s4/src/utils/optim/ema.py diff --git a/src/utils/optim/lamb.py b/s4/src/utils/optim/lamb.py similarity index 100% rename from src/utils/optim/lamb.py rename to s4/src/utils/optim/lamb.py diff --git a/src/utils/optim/schedulers.py b/s4/src/utils/optim/schedulers.py similarity index 100% rename from src/utils/optim/schedulers.py rename to s4/src/utils/optim/schedulers.py diff --git a/src/utils/optim_groups.py b/s4/src/utils/optim_groups.py similarity index 100% rename from src/utils/optim_groups.py rename to s4/src/utils/optim_groups.py diff --git a/src/utils/permutations.py b/s4/src/utils/permutations.py similarity index 100% rename from src/utils/permutations.py rename to s4/src/utils/permutations.py diff --git a/s4/src/utils/registry.py b/s4/src/utils/registry.py new file mode 100644 index 00000000..bf25fb19 --- /dev/null +++ b/s4/src/utils/registry.py @@ -0,0 +1,106 @@ +optimizer = { + "adam": "torch.optim.Adam", + "adamw": "torch.optim.AdamW", + "rmsprop": "torch.optim.RMSprop", + "sgd": "torch.optim.SGD", + "lamb": "s4.src.utils.optim.lamb.JITLamb", +} + +scheduler = { + "constant": "transformers.get_constant_schedule", + "plateau": "torch.optim.lr_scheduler.ReduceLROnPlateau", + "step": "torch.optim.lr_scheduler.StepLR", + "multistep": "torch.optim.lr_scheduler.MultiStepLR", + "cosine": "torch.optim.lr_scheduler.CosineAnnealingLR", + "constant_warmup": "transformers.get_constant_schedule_with_warmup", + "linear_warmup": "transformers.get_linear_schedule_with_warmup", + "cosine_warmup": "transformers.get_cosine_schedule_with_warmup", + "timm_cosine": "s4.src.utils.optim.schedulers.TimmCosineLRScheduler", +} + +callbacks = { + "timer": "s4.src.callbacks.timer.Timer", + "params": "s4.src.callbacks.params.ParamsLog", + "learning_rate_monitor": "pytorch_lightning.callbacks.LearningRateMonitor", + "model_checkpoint": "pytorch_lightning.callbacks.ModelCheckpoint", + "early_stopping": "pytorch_lightning.callbacks.EarlyStopping", + "swa": "pytorch_lightning.callbacks.StochasticWeightAveraging", + "rich_model_summary": "pytorch_lightning.callbacks.RichModelSummary", + "rich_progress_bar": "pytorch_lightning.callbacks.RichProgressBar", + "progressive_resizing": "s4.src.callbacks.progressive_resizing.ProgressiveResizing", + # "profiler": "pytorch_lightning.profilers.PyTorchProfiler", +} + +model = { + # Backbones from this repo + "model": "s4.src.models.sequence.backbones.model.SequenceModel", + "unet": "s4.src.models.sequence.backbones.unet.SequenceUNet", + "sashimi": "s4.src.models.sequence.backbones.sashimi.Sashimi", + "sashimi_standalone": "s4.models.sashimi.sashimi.Sashimi", + # Baseline RNNs + "lstm": "s4.src.models.baselines.lstm.TorchLSTM", + "gru": "s4.src.models.baselines.gru.TorchGRU", + "unicornn": "s4.src.models.baselines.unicornn.UnICORNN", + "odelstm": "s4.src.models.baselines.odelstm.ODELSTM", + "lipschitzrnn": "s4.src.models.baselines.lipschitzrnn.RnnModels", + "stackedrnn": "s4.src.models.baselines.samplernn.StackedRNN", + "stackedrnn_baseline": "s4.src.models.baselines.samplernn.StackedRNNBaseline", + "samplernn": "s4.src.models.baselines.samplernn.SampleRNN", + "dcgru": "s4.src.models.baselines.dcgru.DCRNNModel_classification", + "dcgru_ss": "s4.src.models.baselines.dcgru.DCRNNModel_nextTimePred", + # Baseline CNNs + "ckconv": "s4.src.models.baselines.ckconv.ClassificationCKCNN", + "wavegan": "s4.src.models.baselines.wavegan.WaveGANDiscriminator", # DEPRECATED + "denseinception": "s4.src.models.baselines.dense_inception.DenseInception", + "wavenet": "s4.src.models.baselines.wavenet.WaveNetModel", + "torch/resnet2d": "s4.src.models.baselines.resnet.TorchVisionResnet", # 2D ResNet + # Nonaka 1D CNN baselines + "nonaka/resnet18": "s4.src.models.baselines.nonaka.resnet.resnet1d18", + "nonaka/inception": "s4.src.models.baselines.nonaka.inception.inception1d", + "nonaka/xresnet50": "s4.src.models.baselines.nonaka.xresnet.xresnet1d50", + # ViT Variants (note: small variant is taken from Tri, differs from original) + "vit": "s4.models.baselines.vit.ViT", + "vit_s_16": "s4.src.models.baselines.vit_all.vit_small_patch16_224", + "vit_b_16": "s4.src.models.baselines.vit_all.vit_base_patch16_224", + # Timm models + "timm/convnext_base": "s4.src.models.baselines.convnext_timm.convnext_base", + "timm/convnext_small": "s4.src.models.baselines.convnext_timm.convnext_small", + "timm/convnext_tiny": "s4.src.models.baselines.convnext_timm.convnext_tiny", + "timm/convnext_micro": "s4.src.models.baselines.convnext_timm.convnext_micro", + "timm/resnet50": "s4.src.models.baselines.resnet_timm.resnet50", # Can also register many other variants in resnet_timm + "timm/convnext_tiny_3d": "s4.src.models.baselines.convnext_timm.convnext3d_tiny", + # Segmentation models + "convnext_unet_tiny": "s4.src.models.segmentation.convnext_unet.convnext_tiny_unet", +} + +layer = { + "id": "s4.src.models.sequence.base.SequenceIdentity", + "lstm": "s4.src.models.baselines.lstm.TorchLSTM", + "standalone": "s4.models.s4.s4.S4Block", + "s4d": "s4.models.s4.s4d.S4D", + "ffn": "s4.src.models.sequence.modules.ffn.FFN", + "sru": "s4.src.models.sequence.rnns.sru.SRURNN", + "rnn": "s4.src.models.sequence.rnns.rnn.RNN", # General RNN wrapper + "conv1d": "s4.src.models.sequence.convs.conv1d.Conv1d", + "conv2d": "s4.src.models.sequence.convs.conv2d.Conv2d", + "mha": "s4.src.models.sequence.attention.mha.MultiheadAttention", + "vit": "s4.src.models.sequence.attention.mha.VitAttention", + "performer": "s4.src.models.sequence.attention.linear.Performer", + "lssl": "s4.src.models.sequence.modules.lssl.LSSL", + "s4": "s4.src.models.sequence.modules.s4block.S4Block", + "fftconv": "s4.src.models.sequence.kernels.fftconv.FFTConv", + "s4nd": "s4.src.models.sequence.modules.s4nd.S4ND", + "mega": "s4.src.models.sequence.modules.mega.MegaBlock", + "h3": "s4.src.models.sequence.experimental.h3.H3", + "h4": "s4.src.models.sequence.experimental.h4.H4", + # "packedrnn": "s4.models.sequence.rnns.packedrnn.PackedRNN", +} + +layer_decay = { + "convnext_timm_tiny": "s4.src.models.baselines.convnext_timm.get_num_layer_for_convnext_tiny", +} + +model_state_hook = { + "convnext_timm_tiny_2d_to_3d": "s4.src.models.baselines.convnext_timm.convnext_timm_tiny_2d_to_3d", + "convnext_timm_tiny_s4nd_2d_to_3d": "s4.src.models.baselines.convnext_timm.convnext_timm_tiny_s4nd_2d_to_3d", +} diff --git a/src/utils/train.py b/s4/src/utils/train.py similarity index 99% rename from src/utils/train.py rename to s4/src/utils/train.py index cd6546ce..8116dcd4 100644 --- a/src/utils/train.py +++ b/s4/src/utils/train.py @@ -10,7 +10,7 @@ from omegaconf import DictConfig, OmegaConf from pytorch_lightning.utilities import rank_zero_only -from src.utils.config import omegaconf_filter_keys +from s4.src.utils.config import omegaconf_filter_keys # Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging diff --git a/train.py b/s4/train.py similarity index 98% rename from train.py rename to s4/train.py index 0f72e7fe..a0b75c33 100644 --- a/train.py +++ b/s4/train.py @@ -17,14 +17,14 @@ from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn from tqdm.auto import tqdm -import src.models.nn.utils as U -import src.utils as utils -import src.utils.train -from src.dataloaders import SequenceDataset # TODO make registry -from src.tasks import decoders, encoders, tasks -from src.utils import registry -from src.utils.optim.ema import build_ema_optimizer -from src.utils.optim_groups import add_optimizer_hooks +import s4.src.models.nn.utils as U +import s4.src.utils as utils +import s4.src.utils.train; import s4; src = s4.src +from s4.src.dataloaders import SequenceDataset # TODO make registry +from s4.src.tasks import decoders, encoders, tasks +from s4.src.utils import registry +from s4.src.utils.optim.ema import build_ema_optimizer +from s4.src.utils.optim_groups import add_optimizer_hooks log = src.utils.train.get_logger(__name__) diff --git a/src/utils/registry.py b/src/utils/registry.py deleted file mode 100644 index 9741d05b..00000000 --- a/src/utils/registry.py +++ /dev/null @@ -1,106 +0,0 @@ -optimizer = { - "adam": "torch.optim.Adam", - "adamw": "torch.optim.AdamW", - "rmsprop": "torch.optim.RMSprop", - "sgd": "torch.optim.SGD", - "lamb": "src.utils.optim.lamb.JITLamb", -} - -scheduler = { - "constant": "transformers.get_constant_schedule", - "plateau": "torch.optim.lr_scheduler.ReduceLROnPlateau", - "step": "torch.optim.lr_scheduler.StepLR", - "multistep": "torch.optim.lr_scheduler.MultiStepLR", - "cosine": "torch.optim.lr_scheduler.CosineAnnealingLR", - "constant_warmup": "transformers.get_constant_schedule_with_warmup", - "linear_warmup": "transformers.get_linear_schedule_with_warmup", - "cosine_warmup": "transformers.get_cosine_schedule_with_warmup", - "timm_cosine": "src.utils.optim.schedulers.TimmCosineLRScheduler", -} - -callbacks = { - "timer": "src.callbacks.timer.Timer", - "params": "src.callbacks.params.ParamsLog", - "learning_rate_monitor": "pytorch_lightning.callbacks.LearningRateMonitor", - "model_checkpoint": "pytorch_lightning.callbacks.ModelCheckpoint", - "early_stopping": "pytorch_lightning.callbacks.EarlyStopping", - "swa": "pytorch_lightning.callbacks.StochasticWeightAveraging", - "rich_model_summary": "pytorch_lightning.callbacks.RichModelSummary", - "rich_progress_bar": "pytorch_lightning.callbacks.RichProgressBar", - "progressive_resizing": "src.callbacks.progressive_resizing.ProgressiveResizing", - # "profiler": "pytorch_lightning.profilers.PyTorchProfiler", -} - -model = { - # Backbones from this repo - "model": "src.models.sequence.backbones.model.SequenceModel", - "unet": "src.models.sequence.backbones.unet.SequenceUNet", - "sashimi": "src.models.sequence.backbones.sashimi.Sashimi", - "sashimi_standalone": "models.sashimi.sashimi.Sashimi", - # Baseline RNNs - "lstm": "src.models.baselines.lstm.TorchLSTM", - "gru": "src.models.baselines.gru.TorchGRU", - "unicornn": "src.models.baselines.unicornn.UnICORNN", - "odelstm": "src.models.baselines.odelstm.ODELSTM", - "lipschitzrnn": "src.models.baselines.lipschitzrnn.RnnModels", - "stackedrnn": "src.models.baselines.samplernn.StackedRNN", - "stackedrnn_baseline": "src.models.baselines.samplernn.StackedRNNBaseline", - "samplernn": "src.models.baselines.samplernn.SampleRNN", - "dcgru": "src.models.baselines.dcgru.DCRNNModel_classification", - "dcgru_ss": "src.models.baselines.dcgru.DCRNNModel_nextTimePred", - # Baseline CNNs - "ckconv": "src.models.baselines.ckconv.ClassificationCKCNN", - "wavegan": "src.models.baselines.wavegan.WaveGANDiscriminator", # DEPRECATED - "denseinception": "src.models.baselines.dense_inception.DenseInception", - "wavenet": "src.models.baselines.wavenet.WaveNetModel", - "torch/resnet2d": "src.models.baselines.resnet.TorchVisionResnet", # 2D ResNet - # Nonaka 1D CNN baselines - "nonaka/resnet18": "src.models.baselines.nonaka.resnet.resnet1d18", - "nonaka/inception": "src.models.baselines.nonaka.inception.inception1d", - "nonaka/xresnet50": "src.models.baselines.nonaka.xresnet.xresnet1d50", - # ViT Variants (note: small variant is taken from Tri, differs from original) - "vit": "models.baselines.vit.ViT", - "vit_s_16": "src.models.baselines.vit_all.vit_small_patch16_224", - "vit_b_16": "src.models.baselines.vit_all.vit_base_patch16_224", - # Timm models - "timm/convnext_base": "src.models.baselines.convnext_timm.convnext_base", - "timm/convnext_small": "src.models.baselines.convnext_timm.convnext_small", - "timm/convnext_tiny": "src.models.baselines.convnext_timm.convnext_tiny", - "timm/convnext_micro": "src.models.baselines.convnext_timm.convnext_micro", - "timm/resnet50": "src.models.baselines.resnet_timm.resnet50", # Can also register many other variants in resnet_timm - "timm/convnext_tiny_3d": "src.models.baselines.convnext_timm.convnext3d_tiny", - # Segmentation models - "convnext_unet_tiny": "src.models.segmentation.convnext_unet.convnext_tiny_unet", -} - -layer = { - "id": "src.models.sequence.base.SequenceIdentity", - "lstm": "src.models.baselines.lstm.TorchLSTM", - "standalone": "models.s4.s4.S4Block", - "s4d": "models.s4.s4d.S4D", - "ffn": "src.models.sequence.modules.ffn.FFN", - "sru": "src.models.sequence.rnns.sru.SRURNN", - "rnn": "src.models.sequence.rnns.rnn.RNN", # General RNN wrapper - "conv1d": "src.models.sequence.convs.conv1d.Conv1d", - "conv2d": "src.models.sequence.convs.conv2d.Conv2d", - "mha": "src.models.sequence.attention.mha.MultiheadAttention", - "vit": "src.models.sequence.attention.mha.VitAttention", - "performer": "src.models.sequence.attention.linear.Performer", - "lssl": "src.models.sequence.modules.lssl.LSSL", - "s4": "src.models.sequence.modules.s4block.S4Block", - "fftconv": "src.models.sequence.kernels.fftconv.FFTConv", - "s4nd": "src.models.sequence.modules.s4nd.S4ND", - "mega": "src.models.sequence.modules.mega.MegaBlock", - "h3": "src.models.sequence.experimental.h3.H3", - "h4": "src.models.sequence.experimental.h4.H4", - # 'packedrnn': 'models.sequence.rnns.packedrnn.PackedRNN', -} - -layer_decay = { - 'convnext_timm_tiny': 'src.models.baselines.convnext_timm.get_num_layer_for_convnext_tiny', -} - -model_state_hook = { - 'convnext_timm_tiny_2d_to_3d': 'src.models.baselines.convnext_timm.convnext_timm_tiny_2d_to_3d', - 'convnext_timm_tiny_s4nd_2d_to_3d': 'src.models.baselines.convnext_timm.convnext_timm_tiny_s4nd_2d_to_3d', -} From 3b6e6897a4a5c87e9a7e957cb25a6e825b4d04cd Mon Sep 17 00:00:00 2001 From: Leo Auri Date: Thu, 25 Sep 2025 14:58:10 +0200 Subject: [PATCH 2/4] Make lightning optional dependency --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2c96dec7..bb1ae3ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ dependencies = [ "rich", "torchtext", "lit", - "pytorch-lightning==2.0.4", "hydra-core", "omegaconf", "wandb", @@ -30,6 +29,9 @@ dependencies = [ "timm==0.5.4", ] +[project.optional-dependencies] +train = ["pytorch-lightning==2.0.4"] + [tool.setuptools.packages.find] where = ["."] include = ["s4*"] \ No newline at end of file From 49f09e4e61dc89f10a848b016d4701f39843f59d Mon Sep 17 00:00:00 2001 From: Leo Auri Date: Thu, 25 Sep 2025 15:04:17 +0200 Subject: [PATCH 3/4] Add kernel install CLI command --- pyproject.toml | 5 ++++- s4/extensions/kernels/install.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 s4/extensions/kernels/install.py diff --git a/pyproject.toml b/pyproject.toml index bb1ae3ef..dd00ced0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,4 +34,7 @@ train = ["pytorch-lightning==2.0.4"] [tool.setuptools.packages.find] where = ["."] -include = ["s4*"] \ No newline at end of file +include = ["s4*"] + +[project.scripts] +s4-install-kernels = "s4.extensions.kernels.install:main" \ No newline at end of file diff --git a/s4/extensions/kernels/install.py b/s4/extensions/kernels/install.py new file mode 100644 index 00000000..cc64897f --- /dev/null +++ b/s4/extensions/kernels/install.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +import os +import subprocess +import sys +from pathlib import Path + +def main(): + """Install structured kernels CUDA extensions.""" + kernel_dir = Path(__file__).parent + setup_py = kernel_dir / "setup.py" + + if not setup_py.exists(): + print(f"Error: {setup_py} not found", file=sys.stderr) + sys.exit(1) + + # Change to the kernel directory and run setup.py + original_dir = os.getcwd() + try: + os.chdir(kernel_dir) + result = subprocess.run([ + sys.executable, "setup.py", "build_ext", "--inplace" + ], check=True) + print("Kernels installed successfully!") + except subprocess.CalledProcessError as e: + print(f"Error installing kernels: {e}", file=sys.stderr) + sys.exit(1) + finally: + os.chdir(original_dir) + +if __name__ == "__main__": + main() \ No newline at end of file From 3ab83f17bf0824ca91935f26ef13d70bf1141952 Mon Sep 17 00:00:00 2001 From: Leo Auri Date: Thu, 25 Sep 2025 16:24:41 +0200 Subject: [PATCH 4/4] Install kernels globally --- s4/extensions/kernels/install.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/s4/extensions/kernels/install.py b/s4/extensions/kernels/install.py index cc64897f..5c1e921f 100644 --- a/s4/extensions/kernels/install.py +++ b/s4/extensions/kernels/install.py @@ -18,7 +18,7 @@ def main(): try: os.chdir(kernel_dir) result = subprocess.run([ - sys.executable, "setup.py", "build_ext", "--inplace" + sys.executable, "setup.py", "install" ], check=True) print("Kernels installed successfully!") except subprocess.CalledProcessError as e: