Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
4 changes: 2 additions & 2 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
import os
import argparse

from models.s4.s4 import S4Block as S4 # Can use full version instead of minimal S4D standalone below
from models.s4.s4d import S4D
from s4.models.s4.s4 import S4Block as S4 # Can use full version instead of minimal S4D standalone below
from s4.models.s4.s4d import S4D
from tqdm.auto import tqdm

# Dropout broke in PyTorch 1.11
Expand Down
40 changes: 40 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
[build-system]
requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "s4"
version = "0.1.0"
description = "Structured State Space Models"
dependencies = [
"numpy",
"scipy",
"pandas",
"scikit-learn",
"matplotlib",
"tqdm",
"rich",
"torchtext",
"lit",
"hydra-core",
"omegaconf",
"wandb",
"einops",
"cmake",
"transformers",
"datasets",
"sktime",
"numba",
"gluonts",
"timm==0.5.4",
]

[project.optional-dependencies]
train = ["pytorch-lightning==2.0.4"]

[tool.setuptools.packages.find]
where = ["."]
include = ["s4*"]

[project.scripts]
s4-install-kernels = "s4.extensions.kernels.install:main"
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
from pathlib import Path

from train import SequenceLightningModule
from s4.train import SequenceLightningModule


parser = argparse.ArgumentParser()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from torch.nn.modules import module
import torch.nn.functional as F
from torch.distributions import Categorical
from src import utils
from s4.src import utils
from einops import rearrange, repeat, reduce

from train import SequenceLightningModule
from s4.train import SequenceLightningModule
from omegaconf import OmegaConf


Expand Down
4 changes: 2 additions & 2 deletions checkpoints/evaluate.py → s4/checkpoints/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from torch.nn.modules import module
import torch.nn.functional as F
from torch.distributions import Categorical
from src import utils
from s4.src import utils
from einops import rearrange, repeat, reduce

from train import SequenceLightningModule
from s4.train import SequenceLightningModule
from omegaconf import OmegaConf

@hydra.main(config_path="../configs", config_name="generate.yaml")
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file added s4/extensions/__init__.py
Empty file.
File renamed without changes.
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
31 changes: 31 additions & 0 deletions s4/extensions/kernels/install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#!/usr/bin/env python3
import os
import subprocess
import sys
from pathlib import Path

def main():
"""Install structured kernels CUDA extensions."""
kernel_dir = Path(__file__).parent
setup_py = kernel_dir / "setup.py"

if not setup_py.exists():
print(f"Error: {setup_py} not found", file=sys.stderr)
sys.exit(1)

# Change to the kernel directory and run setup.py
original_dir = os.getcwd()
try:
os.chdir(kernel_dir)
result = subprocess.run([
sys.executable, "setup.py", "install"
], check=True)
print("Kernels installed successfully!")
except subprocess.CalledProcessError as e:
print(f"Error installing kernels: {e}", file=sys.stderr)
sys.exit(1)
finally:
os.chdir(original_dir)

if __name__ == "__main__":
main()
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from einops import rearrange

from src.ops.vandermonde import log_vandermonde, log_vandermonde_fast
from s4.src.ops.vandermonde import log_vandermonde, log_vandermonde_fast


@pytest.mark.parametrize('L', [3, 17, 489, 2**10, 1047, 2**11, 2**12])
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
8 changes: 4 additions & 4 deletions generate.py → s4/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from torch.distributions import Categorical
from tqdm.auto import tqdm

from src import utils
from src.dataloaders.audio import mu_law_decode
from src.models.baselines.wavenet import WaveNetModel
from train import SequenceLightningModule
from s4.src import utils
from s4.src.dataloaders.audio import mu_law_decode
from s4.src.models.baselines.wavenet import WaveNetModel
from s4.train import SequenceLightningModule

def test_step(model):
B, L = 2, 64
Expand Down
File renamed without changes.
Empty file added s4/models/__init__.py
Empty file.
File renamed without changes.
Empty file added s4/models/dss/__init__.py
Empty file.
File renamed without changes.
Empty file added s4/models/hippo/__init__.py
Empty file.
File renamed without changes.
Empty file added s4/models/related/__init__.py
Empty file.
File renamed without changes.
Empty file added s4/models/s4/__init__.py
Empty file.
File renamed without changes.
Empty file added s4/models/s4/lssl.md
Empty file.
4 changes: 2 additions & 2 deletions models/s4/s4.py → s4/models/s4/s4.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:

# Try CUDA extension
try:
from extensions.kernels.cauchy import cauchy_mult as cauchy_cuda
from extensions.kernels.vandermonde import log_vandermonde_cuda
from s4.extensions.kernels.cauchy import cauchy_mult as cauchy_cuda
from s4.extensions.kernels.vandermonde import log_vandermonde_cuda
has_cuda_extension = True
log.info("CUDA extension for structured kernels (Cauchy and Vandermonde multiplication) found.")
except:
Expand Down
2 changes: 1 addition & 1 deletion models/s4/s4d.py → s4/models/s4/s4d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch.nn.functional as F
from einops import rearrange, repeat

from src.models.nn import DropoutNd
from s4.src.models.nn import DropoutNd

class S4DKernel(nn.Module):
"""Generate convolution kernel from diagonal SSM parameters."""
Expand Down
File renamed without changes.
Empty file added s4/models/s4nd/__init__.py
Empty file.
File renamed without changes.
Empty file added s4/models/sashimi/__init__.py
Empty file.
File renamed without changes.
Empty file.
Empty file.
2 changes: 1 addition & 1 deletion models/sashimi/sashimi.py → s4/models/sashimi/sashimi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from einops import rearrange

from models.s4.s4 import LinearActivation, S4Block as S4
from s4.models.s4.s4 import LinearActivation, S4Block as S4

class DownPool(nn.Module):
def __init__(self, d_input, expand, pool):
Expand Down
Empty file.
Empty file.
Empty file.
Empty file.
Loading