51 changes: 42 additions & 9 deletions docker/Dockerfile
@@ -1,13 +1,46 @@
FROM hpcaitech/pytorch-cuda:1.12.0-11.3.0
FROM image-cloud.luchentech.com/hpcaitech/pytorch-cuda:2.3.0-12.1.0

RUN conda install openmm=7.7.0 pdbfixer -c conda-forge -y \
&& conda install hmmer==3.3.2 hhsuite=3.3.0 kalign2=2.04 -c bioconda -y

RUN pip install biopython==1.79 dm-tree==0.1.6 ml-collections==0.1.0 \
scipy==1.7.1 ray pyarrow pandas einops
RUN apt-get update && \
apt-get install -y gcc-9 g++-9 && \
update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-9 100 \
--slave /usr/bin/g++ g++ /usr/bin/g++-9 && \
rm -rf /var/lib/apt/lists/*

RUN pip install colossalai
ENV CC=/usr/bin/gcc
ENV CXX=/usr/bin/g++

Run git clone https://github.com/hpcaitech/FastFold.git \
&& cd ./FastFold \
&& python setup.py install
RUN conda update -n base -c defaults -y conda openssl ca-certificates && \
conda config --remove-key channels || true && \
conda config --add channels conda-forge && \
conda config --add channels bioconda && \
conda config --add channels defaults && \
conda config --set channel_priority strict && \
conda config --set show_channel_urls yes

RUN conda install -n pytorch -y -c conda-forge -c bioconda \
openmm=8.2.0 \
pdbfixer \
hmmer=3.3.2 \
hhsuite=3.3.0 \
scipy==1.11.4 \
kalign2=2.04 && \
conda clean -afy

RUN /opt/conda/envs/pytorch/bin/pip install --no-cache-dir \
biopython==1.79 \
dm-tree==0.1.6 \
ml-collections==0.1.0 \
ray==2.47.1 \
pyarrow \
pandas \
einops \
colossalai==0.5.0


RUN git clone --depth 1 https://github.com/hpcaitech/FastFold.git && \
cd FastFold && \
/opt/conda/envs/pytorch/bin/python setup.py install


WORKDIR /workspace
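
For a quick sanity check of the rebuilt image, a pin audit can run inside the container. A minimal sketch: the image tag fastfold:dev is an assumption, while /opt/conda/envs/pytorch/bin/python matches the interpreter used above.

# check_pins.py -- hedged smoke test for the pinned packages; run as:
#   docker run --rm fastfold:dev /opt/conda/envs/pytorch/bin/python check_pins.py
import importlib.metadata as md  # stdlib since Python 3.8

PINS = {"biopython": "1.79", "dm-tree": "0.1.6", "ml-collections": "0.1.0",
        "ray": "2.47.1", "colossalai": "0.5.0"}
for name, want in PINS.items():
    have = md.version(name)  # raises PackageNotFoundError if the install step failed
    assert have == want, f"{name}: expected {want}, got {have}"
print("all pins OK")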
6 changes: 3 additions & 3 deletions fastfold/common/residue_constants.py
@@ -1127,10 +1127,10 @@ def _make_rigid_transformation_4x4(ex, ey, translation):
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int)
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
@@ -1286,7 +1286,7 @@ def make_atom14_dists_bounds(

restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
restype_atom14_ambiguous_atoms_swap_idx = np.tile(
np.arange(14, dtype=np.int), (21, 1)
np.arange(14, dtype=int), (21, 1)
)


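Context for the change above: NumPy deprecated its aliases for builtin types (np.int, np.float, np.object, np.bool) in 1.20 and removed them in 1.24, so under the newer stack dtype=np.int raises AttributeError and the builtin int is the drop-in replacement. One nuance worth noting (my illustration, not part of the diff):

import numpy as np

a = np.zeros([21, 37], dtype=int)       # platform default integer, what np.int used to mean
b = np.zeros([21, 37], dtype=np.int64)  # explicit width, stable across platforms
assert a.dtype == np.dtype(int)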
10 changes: 5 additions & 5 deletions fastfold/data/data_pipeline.py
@@ -99,12 +99,12 @@ def make_sequence_features(
)
features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32)
features["domain_name"] = np.array(
[description.encode("utf-8")], dtype=np.object_
[description.encode("utf-8")], dtype=object
)
features["residue_index"] = np.array(range(num_res), dtype=np.int32)
features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32)
features["sequence"] = np.array(
[sequence.encode("utf-8")], dtype=np.object_
[sequence.encode("utf-8")], dtype=object
)
return features

@@ -137,7 +137,7 @@ def make_mmcif_features(
)

mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
[mmcif_object.header["release_date"].encode("utf-8")], dtype=object
)

mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
@@ -238,7 +238,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
features["num_alignments"] = np.array(
[num_alignments] * num_res, dtype=np.int32
)
features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_)
features["msa_species_identifiers"] = np.array(species_ids, dtype=object)
return features

def run_msa_tool(
@@ -681,7 +681,7 @@ def convert_monomer_features(
) -> FeatureDict:
"""Reshapes and modifies monomer features for multimer models."""
converted = {}
converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_)
converted['auth_chain_id'] = np.asarray(chain_id, dtype=object)
unnecessary_leading_dim_feats = {
'sequence', 'domain_name', 'num_alignments', 'seq_length'
}
8 changes: 4 additions & 4 deletions fastfold/data/templates.py
@@ -85,8 +85,8 @@ class LengthError(PrefilterError):
"template_aatype": np.int64,
"template_all_atom_mask": np.float32,
"template_all_atom_positions": np.float32,
"template_domain_names": np.object,
"template_sequence": np.object,
"template_domain_names": object,
"template_sequence": object,
"template_sum_probs": np.float32,
}

@@ -1209,8 +1209,8 @@ def get_templates(
"template_all_atom_positions": np.zeros(
(1, num_res, residue_constants.atom_type_num, 3), np.float32
),
"template_domain_names": np.array([''.encode()], dtype=np.object),
"template_sequence": np.array([''.encode()], dtype=np.object),
"template_domain_names": np.array([''.encode()], dtype=object),
"template_sequence": np.array([''.encode()], dtype=object),
"template_sum_probs": np.array([0], dtype=np.float32),
}

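The dtype=object edits in data_pipeline.py and templates.py are the same NumPy 1.24 cleanup: np.object is gone, and the builtin object is the canonical spelling (np.object_, used in data_pipeline.py, still exists but is just an alias). Object arrays are what make the variable-length encoded strings here work; a small illustration (the values are mine):

import numpy as np

# fixed-width "S" dtypes would truncate; dtype=object stores arbitrary-length bytes
names = np.array(["4hhb_A".encode("utf-8"), b""], dtype=object)
assert names.dtype == np.dtype(object)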
4 changes: 2 additions & 2 deletions fastfold/distributed/comm.py
@@ -4,8 +4,8 @@
import torch.distributed as dist
from torch import Tensor

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc

from .core import ensure_divisibility

4 changes: 2 additions & 2 deletions fastfold/distributed/comm_async.py
@@ -5,8 +5,8 @@
import torch.distributed as dist
from torch import Tensor

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc

from .comm import _split, divide

3 changes: 2 additions & 1 deletion fastfold/distributed/core.py
@@ -1,3 +1,4 @@
from colossalai.legacy import launch_from_torch as _ff_launch_from_torch
import os

import torch
@@ -36,5 +37,5 @@ def init_dap(tensor_model_parallel_size_=None):
set_missing_distributed_environ('MASTER_ADDR', "localhost")
set_missing_distributed_environ('MASTER_PORT', 18417)

colossalai.launch_from_torch(
_ff_launch_from_torch(
config={"parallel": dict(tensor=dict(size=tensor_model_parallel_size_))})
4 changes: 2 additions & 2 deletions fastfold/model/fastnn/evoformer.py
@@ -4,8 +4,8 @@
import torch
import torch.nn as nn

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc

from fastfold.model.fastnn import MSACore, OutProductMean, PairCore
from fastfold.model.fastnn.ops import Linear
4 changes: 2 additions & 2 deletions fastfold/model/fastnn/msa.py
@@ -18,8 +18,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from fastfold.model.fastnn.kernel import LayerNorm, bias_dropout_add
from fastfold.model.fastnn.ops import (ChunkMSARowAttentionWithPairBias, ChunkTransition,
SelfAttention, GlobalAttention, Transition,
4 changes: 2 additions & 2 deletions fastfold/model/fastnn/template.py
@@ -18,8 +18,8 @@
import torch
import torch.nn as nn

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc

from fastfold.model.nn.primitives import Attention
from fastfold.utils.checkpointing import checkpoint_blocks
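The same two-line import rewrite appears in comm.py, comm_async.py, evoformer.py, msa.py, and template.py. If older ColossalAI releases ever need to coexist with 0.5.x, a version-tolerant variant is possible (a sketch, not what this PR does):

try:  # ColossalAI >= 0.3.x, where the old context/core modules moved under .legacy
    from colossalai.legacy.context import ParallelMode
    from colossalai.legacy.core import global_context as gpc
except ImportError:  # earlier releases with the pre-legacy layout
    from colossalai.context.parallel_mode import ParallelMode
    from colossalai.core import global_context as gpc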
1 change: 1 addition & 0 deletions fastfold/utils/import_weights.py
@@ -586,6 +586,7 @@ def FoldIterationParams(sm):


def import_jax_weights_(model, npz_path, version="model_1"):
npz_path = '/' + npz_path
data = np.load(npz_path)

translations = get_translation_dict(model, version)
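The added line roots every incoming path at /, which silently assumes callers always pass a root-relative path; an already-absolute npz_path would come out as //path, and a genuinely cwd-relative one would break. A more defensive variant (a sketch, not what the PR does):

import os

def _absolutize(npz_path: str) -> str:
    # only prefix when the caller handed us a root-relative path
    return npz_path if os.path.isabs(npz_path) else os.path.join("/", npz_path)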
2 changes: 1 addition & 1 deletion setup.py
@@ -111,7 +111,7 @@ def cuda_ext_helper(name, sources, extra_cuda_flags):
cc_flag.append('arch=compute_80,code=sm_80')

extra_cuda_flags = [
'-std=c++14', '-maxrregcount=50', '-U__CUDA_NO_HALF_OPERATORS__',
'-std=c++17', '-maxrregcount=50', '-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', '--expt-extended-lambda'
]

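The -std=c++17 bump lines up with the new PyTorch 2.3.0 / CUDA 12.1 base image: PyTorch 2.x headers require C++17, so extensions built with -std=c++14 no longer compile. For context, a hedged sketch of how flags like these typically reach torch's build helper (the extension name and source files are illustrative):

from torch.utils.cpp_extension import CUDAExtension

ext = CUDAExtension(
    name="fastfold_cuda_ext",                        # illustrative name
    sources=["csrc/ext.cpp", "csrc/ext_kernel.cu"],  # illustrative sources
    extra_compile_args={
        "cxx": ["-O3", "-std=c++17"],
        "nvcc": ["-std=c++17", "-maxrregcount=50",
                 "--expt-relaxed-constexpr", "--expt-extended-lambda"],
    },
)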