Commit cf8bc06b by DLA-Ranker

Updates

parent 08302f9a
name: null
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- bokeh=1.4
- cmake=3.16 # ensures that Gloo library extensions will be built
- cudnn=7.6
- cupti=10.1
- cxx-compiler=1.0 # ensures C and C++ compilers are available
- jupyterlab=1.2
- mpi4py=3.0 # installs CUDA-aware Open MPI
- nccl=2.5
- nodejs=13
- nvcc_linux-64=10.1 # configures environment to be "cuda-aware"
- pip=20.0
- pip:
- mxnet-cu101mkl==1.6.* # MXNet is installed prior to Horovod
- -r file:requirements.txt
- python=3.7
- pytorch=1.4
- tensorboard=2.1
- tensorflow-gpu=2.1
- torchvision=0.5
Comp;ch1;ch2;Conf;Class
1AK4;A;B;1AK4_cm-it0_8585;0
1AK4;A;B;1AK4_cm-it0_9733;0
2I25;A;B;2I25_ti5-it1_46;1
# Auto detect text files and perform LF normalization
* text=auto
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
MIT License
Copyright (c) 2022 Simon Crouzet
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# InteractionGNN
Source code to run InteractionGNN on protein-protein interfaces
### Usage
1. Set the desired parameters in *config.ini*
2. In a bash shell, run `python interaction_gnn.py`
### Requirements
##### Packages:
- Manual: install PyTorch, PyTorch Geometric, CUDA Toolkit, scikit-learn, and the packages numpy, pandas, matplotlib, lz4, and tqdm (`conda install -c pytorch -c pyg -c conda-forge python=3.9 numpy pandas matplotlib tqdm pytorch pyg scikit-learn cuda-toolkit lz4`)
- All-in-one: run `conda env create --name interaction_gnn --file interaction_gnn.yml`
\
InteractionGNN uses [PyTorch Geometric](https://github.com/pyg-team/pytorch_geometric).
##### Data files:
Data files should be placed in the `data` folder, organized as in the following example for binary classification:
\
```
InteractionGNN
| interaction_gnn.py
|
|___src
| | ...
|
|___data
|___protein_pair_1
| |___0
| | | file1
| | | file2
| |
| |___1
| | file3
| | file4
|
|___protein_pair_2
| |___0
| | | file5
| | | file6
| |
| |___1
| | file7
| | file8
..........
```
### Citing
If you use this code, please cite the associated paper:
\
```Y. Mohseni Behbahani, S. Crouzet, E. Laine, A. Carbone, "Deep Local Analysis evaluates protein docking conformations with locally oriented cubes"```
[RUNINFO]
save_models = True
data_dir = ./data/intermediate_m5_e30_nonorm_split1/
[MODELINFO]
nn_model_string = GCN
learning_rate = 0.001
nb_negative = -1
nb_positive = -1
nb_features = 22
specific_scr_fea = S,C,R
balance_test_split = True
exclude_last = True
[RUNPARAMS]
use_kfold = True
epochs = 100
batch_size = 500
nb_folds = 5
use_weights = True
name: interaction_gnn
channels:
- pytorch
- pyg
- nvidia
- bioconda
- conda-forge
- defaults
dependencies:
- blas=2.113=mkl
- blas-devel=3.9.0=13_win64_mkl
- brotli=1.0.9=h8ffe710_6
- brotli-bin=1.0.9=h8ffe710_6
- brotlipy=0.7.0=py39hb82d6ee_1003
- bzip2=1.0.8=h8ffe710_4
- ca-certificates=2021.10.8=h5b45459_0
- certifi=2021.10.8=py39hcbf5309_2
- cffi=1.15.0=py39h0878f49_0
- charset-normalizer=2.0.12=pyhd8ed1ab_0
- colorama=0.4.4=pyh9f0ad1d_0
- cryptography=36.0.2=py39h7bc7c5c_0
- cuda-cccl=11.6.55=hd268e57_0
- cuda-command-line-tools=11.6.2=h65bbf44_0
- cuda-compiler=11.6.2=h65bbf44_0
- cuda-cudart=11.6.55=h5fb1900_0
- cuda-cuobjdump=11.6.124=h8654613_0
- cuda-cupti=11.6.124=h532822a_0
- cuda-cuxxfilt=11.6.124=h3f9c74b_0
- cuda-libraries=11.6.2=h65bbf44_0
- cuda-libraries-dev=11.6.2=h65bbf44_0
- cuda-memcheck=11.6.124=hea6bc18_0
- cuda-nvcc=11.6.124=h769bc0d_0
- cuda-nvdisasm=11.6.124=he05ff55_0
- cuda-nvml-dev=11.6.55=h2bb381e_0
- cuda-nvprof=11.6.124=he581663_0
- cuda-nvprune=11.6.124=hb892de1_0
- cuda-nvrtc=11.6.124=h231bd66_0
- cuda-nvrtc-dev=11.6.124=hd7d06dc_0
- cuda-nvtx=11.6.124=hee9d5a4_0
- cuda-nvvp=11.6.124=h6a974fa_0
- cuda-sanitizer-api=11.6.124=ha4888a7_0
- cuda-toolkit=11.6.2=h65bbf44_0
- cuda-tools=11.6.2=h65bbf44_0
- cuda-visual-tools=11.6.2=h65bbf44_0
- cudatoolkit=11.5.0=hfde6d1b_9
- cycler=0.11.0=pyhd8ed1ab_0
- fonttools=4.25.0=pyhd3eb1b0_0
- freetype=2.10.4=h546665d_1
- icu=69.1=h0e60522_0
- idna=3.3=pyhd8ed1ab_0
- intel-openmp=2022.0.0=h57928b3_3663
- jbig=2.1=h8d14728_2003
- jinja2=3.1.1=pyhd8ed1ab_0
- joblib=1.1.0=pyhd8ed1ab_0
- jpeg=9e=h8ffe710_0
- kiwisolver=1.4.2=py39h2e07f2f_0
- lcms2=2.12=h2a16943_0
- lerc=3.0=h0e60522_0
- libblas=3.9.0=13_win64_mkl
- libbrotlicommon=1.0.9=h8ffe710_6
- libbrotlidec=1.0.9=h8ffe710_6
- libbrotlienc=1.0.9=h8ffe710_6
- libcblas=3.9.0=13_win64_mkl
- libclang=13.0.1=default_h81446c8_0
- libcublas=11.9.2.110=hb6c9e1a_0
- libcublas-dev=11.9.2.110=h1662081_0
- libcufft=10.7.2.124=hfa55d67_0
- libcufft-dev=10.7.2.124=hfb516c1_0
- libcurand=10.2.9.124=h99c9f72_0
- libcurand-dev=10.2.9.124=he731d31_0
- libcusolver=11.3.4.124=hc34ff3b_0
- libcusolver-dev=11.3.4.124=h4c0dfd9_0
- libcusparse=11.7.2.124=hfe5ea2b_0
- libcusparse-dev=11.7.2.124=h39db74a_0
- libdeflate=1.10=h8ffe710_0
- libffi=3.4.2=h8ffe710_5
- liblapack=3.9.0=13_win64_mkl
- liblapacke=3.9.0=13_win64_mkl
- libnpp=11.6.3.124=h516bb01_0
- libnpp-dev=11.6.3.124=hc355075_0
- libnvjpeg=11.6.2.124=hf97cc0b_0
- libnvjpeg-dev=11.6.2.124=h6c8d1d7_0
- libpng=1.6.37=h1d00b33_2
- libtiff=4.3.0=hc4061b1_3
- libuv=1.43.0=h8ffe710_0
- libwebp=1.2.2=h57928b3_0
- libwebp-base=1.2.2=h8ffe710_1
- libxcb=1.13=hcd874cb_1004
- libzlib=1.2.11=h8ffe710_1014
- lz4=3.1.10=py39h92e281b_0
- lz4-c=1.9.3=h8ffe710_1
- m2w64-gcc-libgfortran=5.3.0=6
- m2w64-gcc-libs=5.3.0=7
- m2w64-gcc-libs-core=5.3.0=7
- m2w64-gmp=6.1.0=2
- m2w64-libwinpthread-git=5.0.0.4634.697f757=2
- markupsafe=2.1.1=py39hb82d6ee_1
- matplotlib=3.5.1=py39hcbf5309_0
- matplotlib-base=3.5.1=py39h581301d_0
- mkl=2022.0.0=h0e2418a_796
- mkl-devel=2022.0.0=h57928b3_797
- mkl-include=2022.0.0=h0e2418a_796
- msys2-conda-epoch=20160418=1
- munkres=1.0.7=py_1
- networkx=2.7.1=pyhd8ed1ab_1
- numpy=1.22.3=py39h6331f09_0
- openjpeg=2.4.0=hb211442_1
- openssl=1.1.1n=h8ffe710_0
- packaging=21.3=pyhd8ed1ab_0
- pandas=1.4.1=py39h2e25243_0
- pillow=9.0.1=py39ha53f419_2
- pip=22.0.4=pyhd8ed1ab_0
- pthread-stubs=0.4=hcd874cb_1001
- pycparser=2.21=pyhd8ed1ab_0
- pyg=2.0.4=py39_torch_1.11.0_cu115
- pyopenssl=22.0.0=pyhd8ed1ab_0
- pyparsing=3.0.7=pyhd8ed1ab_0
- pyqt=5.12.3=py39hb0d2dfa_4
- pysocks=1.7.1=py39hcbf5309_4
- python=3.9.12=h9a09f29_1_cpython
- python-dateutil=2.8.2=pyhd8ed1ab_0
- python-louvain=0.15=pyhd8ed1ab_1
- python_abi=3.9=2_cp39
- pytorch=1.11.0=py3.9_cuda11.5_cudnn8_0
- pytorch-cluster=1.6.0=py39_torch_1.11.0_cu115
- pytorch-mutex=1.0=cuda
- pytorch-scatter=2.0.9=py39_torch_1.11.0_cu115
- pytorch-sparse=0.6.13=py39_torch_1.11.0_cu115
- pytorch-spline-conv=1.2.1=py39_torch_1.11.0_cu115
- pytz=2022.1=pyhd8ed1ab_0
- pyyaml=6.0=py39hb82d6ee_4
- qt=5.12.9=h556501e_6
- requests=2.27.1=pyhd8ed1ab_0
- scikit-learn=1.0.2=py39he931e04_0
- scipy=1.8.0=py39hc0c34ad_1
- setuptools=61.3.0=py39hcbf5309_0
- six=1.16.0=pyh6c4a22f_0
- sqlite=3.37.1=h8ffe710_0
- tbb=2021.5.0=h2d74725_0
- threadpoolctl=3.1.0=pyh8a188c0_0
- tk=8.6.12=h8ffe710_0
- tornado=6.1=py39hb82d6ee_3
- tqdm=4.63.1=pyhd8ed1ab_0
- typing_extensions=4.1.1=pyha770c72_0
- tzdata=2022a=h191b570_0
- ucrt=10.0.20348.0=h57928b3_0
- urllib3=1.26.9=pyhd8ed1ab_0
- vc=14.2=hb210afc_6
- vs2015_runtime=14.29.30037=h902a5da_6
- wheel=0.37.1=pyhd8ed1ab_0
- win_inet_pton=1.1.0=py39hcbf5309_4
- xorg-libxau=1.0.9=hcd874cb_0
- xorg-libxdmcp=1.1.3=hcd874cb_0
- xz=5.2.5=h62dcd97_1
- yacs=0.1.8=pyhd8ed1ab_0
- yaml=0.2.5=h8ffe710_2
- zlib=1.2.11=h8ffe710_1014
- zstd=1.5.2=h6255e5f_0
- pip:
- pyqt5-sip==4.19.18
- pyqtchart==5.12
- pyqtwebengine==5.12.1
prefix: C:\ProgramData\Anaconda3\envs\interaction_gnn
from src.model import GAT, GCN, LinearNetwork
from configparser import ConfigParser, Error
from datetime import datetime
def read_config(path):
config_object = ConfigParser()
config_object.read(path)
# datetime object containing current date and time
now = datetime.now()
dt_string = now.strftime("%d_%m_%Y_%Hh%Mm%Ss")
model_no = "run_" + str(dt_string)
runinfo = config_object['RUNINFO']
modelinfo = config_object['MODELINFO']
runparams = config_object['RUNPARAMS']
runinfo['model_no'] = model_no
return runinfo, modelinfo, runparams
# functions to read config file
def str_to_bool(s):
    if s == 'True':
        return True
    elif s == 'False':
        return False
    else:
        raise ValueError("Boolean error: {} not recognized as a boolean".format(s))
def str_to_model(s):
if s == "GCN":
return GCN
elif s == "GAT":
return GAT
elif s == "Linear":
return LinearNetwork
else:
raise Error("Model error: {} not recognized as a model".format(s))
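# A minimal usage sketch (assuming the config.ini shown in this commit sits in
# the working directory; the path is an assumption, not fixed by this module):
#
#     runinfo, modelinfo, runparams = read_config('config.ini')
#     model_cls = str_to_model(modelinfo['nn_model_string'])   # e.g. GCN
#     use_kfold = str_to_bool(runparams['use_kfold'])          # 'True' -> True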
import ast, os
def transform_Dockground_name(pair, sample):
    # '\uf03a' is a private-use codepoint that appears in place of ':' in some
    # Dockground file names (':' is not a safe filename character everywhere)
    if '\uf03a' in pair:
pair1, pair2 = pair.split('\uf03a')
pair2 = pair1.split('_')[0] + '_' + pair2
sample_id = sample.split('_')[2]
return pair1 + '--' + pair2, pair1 + '--' + pair2 + '_' + sample_id
elif len(pair.split('_')) == 3:
base, pair1, pair2 = pair.split('_')
pair1 = base + '_' + pair1
pair2 = base + '_' + pair2
sample_id = sample.split('_')[3]
return pair1 + '--' + pair2, pair1 + '--' + pair2 + '_' + sample_id
else:
raise ValueError('Dockground utils: problem with pair {}'.format(pair))
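# Example (hypothetical IDs): pair '1ABC_A_B' with sample '1ABC_A_B_42' is
# split on '_' into base '1ABC' and chains 'A'/'B', giving
# ('1ABC_A--1ABC_B', '1ABC_A--1ABC_B_42').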
def select_Dockground_split(folder_path, dataset):
groups_txt_path = os.path.join(folder_path, 'groups.txt')
with open(groups_txt_path) as f:
data = f.read()
data = data.split('=')[1].replace(' ', '').strip().upper()
split_dict = ast.literal_eval(data)
splits = [[] for f in range(len(split_dict))]
for d in dataset:
pair1, pair2 = d.pair.split('--')
for key_fold, values in enumerate(split_dict.values()): # Use enumerate() to assure presence of correct keys (i.e. from 0 to len(split_dict) - 1)
for v in values:
if v in pair1.upper() and v in pair2.upper():
splits[key_fold].append(d)
elif v not in pair1.upper() and v not in pair2.upper():
continue
else:
raise ValueError('Dockground utils: problem with pair {}'.format(d.pair))
predefined_splits = [{'train':[], 'val':[]} for f in range(len(split_dict))]
for k in range(len(split_dict)):
predefined_splits[k]['val'] = splits[k]
for fold in range(len(split_dict)):
if fold != k:
predefined_splits[k]['train'] += splits[fold]
return predefined_splits
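# For reference, a hypothetical groups.txt accepted by the parsing above: the
# text after '=' must be a Python dict literal whose values list complex
# identifiers, e.g.:
#     groups = {0: ['1ABC', '2DEF'], 1: ['3GHI', '4JKL']}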
import numpy as np
import pandas as pd
class ScoreDataframe():
# Class to extract scores and relevant metrics from a dataframe
def __init__(self):
self.df = pd.DataFrame({'conformation_name': [], 'fold': [], 'empty_cell_1': [], 'empty_cell_2': [], 'score': [], 'time_of_prediction': [], 'target': [], 'epoch': []})
self.df = self.df.astype({'conformation_name': 'str', 'fold': 'int', 'score': 'float', 'time_of_prediction': 'float', 'target': 'int', 'epoch': 'int'})
def add_row(self, conformation_name, fold, score, time_of_prediction, target, epoch):
# Add one result to the dataframe
if score < 0.0 or score > 1.0:
raise ValueError('Score must be between 0 and 1')
        new_row = pd.DataFrame([{'conformation_name': conformation_name, 'fold': fold, 'score': score, 'time_of_prediction': time_of_prediction, 'target': target, 'epoch': epoch}])
        self.df = pd.concat([self.df, new_row], ignore_index=True)  # pd.concat needs a DataFrame, not a raw dict
def add_rows(self, conformation_names, folds, scores, time_of_predictions, targets, epoch):
# Add multiple results to the dataframe
if len(conformation_names) != len(folds) or len(conformation_names) != len(scores) or len(conformation_names) != len(time_of_predictions) or len(conformation_names) != len(targets):
raise ValueError('All lists must be of the same length')
news = {'conformation_name': conformation_names, 'fold': folds, 'score': scores, 'time_of_prediction': time_of_predictions, 'target': targets, 'epoch': [epoch for _ in range(len(conformation_names))]}
new_df = pd.DataFrame(news)
self.df = pd.concat([self.df, new_df], ignore_index=True)
def export(self, savepath):
# Export the dataframe to a csv file
self.df.to_csv(savepath, header=True)
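# A minimal usage sketch (synthetic values, illustrative only):
#
#     sd = ScoreDataframe()
#     sd.add_rows(['conf_a', 'conf_b'], [0, 0], [0.91, 0.12],
#                 [0.03, 0.02], [1, 0], epoch=1)
#     sd.export('scores.csv')  # assumed output path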
import numpy as np
import torch
import json, os
import lz4.frame
import pandas as pd
import matplotlib.pyplot as plt
import sklearn.metrics as skmetrics
from sklearn.metrics import RocCurveDisplay
import shelve
from src.model import Prediction
def as_numpy(values):
    # Convert a Tensor, or a list/tuple possibly containing Tensors, to numpy structures
    if isinstance(values, torch.Tensor):
        return values.cpu().detach().numpy()
    elif isinstance(values, (list, tuple)):
        return [v.cpu().detach().numpy() if isinstance(v, torch.Tensor) else v for v in values]
    else:
        return values
def as_numpy_tuple(pred:Prediction):
# Pass Prediction named tuple from the model to Prediction named tuple made of numpy arrays
return Prediction(as_numpy(pred.loss), as_numpy(pred.pred), as_numpy(pred.gt), as_numpy(pred.correct))
def save_with_compression(obj, path, compression_level=lz4.frame.COMPRESSIONLEVEL_MINHC):
    # Save a python object to a compressed file.
    # Note: the JSON payload is compressed explicitly, then written through
    # lz4.frame.open, which wraps it in a second LZ4 frame;
    # load_with_compression mirrors both layers, so use the two together.
    with lz4.frame.open(path, mode='wb') as fp:
        obj_bytes = json.dumps(obj).encode('utf-8')
        compressed_obj = lz4.frame.compress(obj_bytes, compression_level=compression_level)
        fp.write(compressed_obj)
def load_with_compression(path):
    # Load a python object from a compressed file written by save_with_compression
    with lz4.frame.open(path, mode='rb') as fp:
        output_compressed_data = fp.read()
        obj_bytes = lz4.frame.decompress(output_compressed_data)
        obj = json.loads(obj_bytes.decode('utf-8'))
    return obj
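# Round-trip sketch (any JSON-serializable object; 'obj.lz4' is an assumed path):
#
#     save_with_compression({'a': 1}, 'obj.lz4')
#     assert load_with_compression('obj.lz4') == {'a': 1}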
def euclidian_distance(coords_from, coords_to):
    # Compute the Euclidean distance between two points
    if len(coords_from) != len(coords_to):
        raise ValueError('Euclidean distance: coords don\'t have the same shape')
    return np.linalg.norm(np.asarray(coords_from) - np.asarray(coords_to))
def plot_roc_pr_curves(metrics, savepath, plot_roc_lines=False):
# Plot meaningful ROC and PR curves
    if not plot_roc_lines:
        # Averaging folds requires every TPR array to be sampled on the same grid
        if len(np.unique(np.array([len(a) for a in metrics['tprs']]))) != 1:
            raise ValueError("ROC plotting: per-fold TPR arrays must all have the same length")
# Plot ROC Curve...
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
# Plot ROC baseline
ax1.plot([0, 1], [0, 1], linestyle="--", lw=2, color="r", label="Chance", alpha=0.8)
if plot_roc_lines: # Plot each ROC curve (one per fold)
for fpr,tpr,auc in zip(metrics['fprs'], metrics['tprs'], metrics['aucs']):
ax1.plot(
fpr,
tpr,
label=r"Mean ROC (AUC = %0.2f)" % (auc),
lw=1,
alpha=0.8,
)
else:
# Plot mean ROC curve (one curve for all folds)
# Area between best and worst ROC curves are colored in grey
mean_tpr = np.mean(metrics['tprs'], axis=0)
mean_tpr[-1] = 1.0
mean_auc = skmetrics.auc(metrics['base_fpr'], mean_tpr)
std_auc = np.std(metrics['aucs'])
ax1.plot(
metrics['base_fpr'],
mean_tpr,
color="b",
label=r"Mean ROC (AUC = %0.2f $\pm$ %0.2f)" % (mean_auc, std_auc),
lw=2,
alpha=0.8,
)
std_tpr = np.std(metrics['tprs'], axis=0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
ax1.fill_between(
metrics['base_fpr'],
tprs_lower,
tprs_upper,
color="grey",
alpha=0.2,
label=r"$\pm$ 1 std. dev.",
)
# Label, legend and set axe limits
ax1.set(
xlim=[-0.05, 1.05],
ylim=[-0.05, 1.05],
title="Receiver Operating Characteristic (ROC) Curve",
xlabel='False Positive Rate',
ylabel='True Positive Rate'
)
ax1.legend(loc="lower right")
# Plot PR Curve...
    if len(metrics['precisions']) != len(metrics['recalls']):
        raise ValueError("Precision / recall lists must have one entry per fold")
else:
nb_folds = len(metrics['precisions'])
# Plot each PR curve (one per fold)
for p,r,i in zip(metrics['precisions'], metrics['recalls'], range(nb_folds)):
ax2.plot(
r,
p,
label="fold {}".format(i),
lw=1
)
# Label, legend and set axe limits
ax2.legend(loc="lower right")
ax2.set(
xlim=[-0.05, 1.05],
ylim=[-0.05, 1.05],
title="Precision Recall Curve",
xlabel='Recall',
ylabel='Precision'
)
# Save figure
plt.savefig(savepath)
plt.close()
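# Shape of the `metrics` dict expected above (inferred from this function; the
# values here are placeholders). With plot_roc_lines=False, every entry of
# 'tprs' must be sampled on the common grid 'base_fpr':
#
#     metrics = {
#         'base_fpr': np.linspace(0, 1, 101),
#         'tprs': [...],        # one TPR array per fold, same length as base_fpr
#         'fprs': [...],        # one FPR array per fold
#         'aucs': [...],        # one AUC value per fold
#         'precisions': [...],  # one precision array per fold
#         'recalls': [...],     # one recall array per fold
#     }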
def save_session(shelf_path):
    # Save the current session: every picklable module-level global
    my_shelf = shelve.open(shelf_path, 'n')
    for key in globals().keys():
        try:
            my_shelf[key] = globals()[key]
        except Exception:
            # Skip unpicklable objects (modules, open handles, ...)
            pass
    my_shelf.close()
def load_session(shelf_path):
    # Load a previously saved session back into this module's globals
    my_shelf = shelve.open(shelf_path)
    for key in my_shelf:
        try:
            globals()[key] = my_shelf[key]
        except Exception:
            print('Not loaded:', key)
    my_shelf.close()
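# Usage sketch: save_session('run.shelf') snapshots all picklable globals of
# this module; load_session('run.shelf') restores them. Note that both operate
# on this module's globals(), not on the caller's namespace.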
import sys, os
sys.path.insert(0, os.path.abspath(os.getcwd()))
import numpy as np
import torch
import matplotlib.pyplot as plt
import pandas as pd
import torch_geometric.data as geom_data
from src.data import ProteinDataset, StratifiedSplit, kFoldStratifiedCrossValidation, reduce_split
from torch_geometric.data import Data, Dataset, Batch, DataLoader
from src.model import GCN, GAT
from src.utils import as_numpy, as_numpy_tuple, plot_roc_pr_curves
from datetime import datetime
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_fscore_support, precision_recall_curve
from sklearn.metrics import PrecisionRecallDisplay, RocCurveDisplay
from configparser import ConfigParser
from src.config import str_to_bool, str_to_model, read_config
import unittest
class TestStratifiedSplit(unittest.TestCase):
    def testIndependence(self):
data_dir = './data/intermediate_m5_e30_nonorm_split1/'
contains_pairs = True
skiptest = False
use_kfold = True
nb_folds = 5
seed = 123 # Seed to fix the simulation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Build dataset
dataset = ProteinDataset(data_dir, random_seed=seed, device=device, nb_features=22)
dataset.reduce(nb_negative=200, nb_positive=200)
dataset_train, dataset_test = StratifiedSplit(dataset, first_partition=0.85, second_partition=0.15, shuffle=True)
# dataset_test = reduce_split(dataset_test, nb_positive=100, nb_negative=100)
train_len = (len(dataset_train)*(nb_folds -1)) / nb_folds
val_len = (len(dataset_train)) / nb_folds
test_len = len(dataset_test)
print()
print('Training on {} proteins.'.format(train_len))
print('Validating on {} proteins.'.format(val_len))
print('Testing on {} proteins.'.format(test_len))
kFoldStrCV = kFoldStratifiedCrossValidation(dataset_train, nb_folds=nb_folds, nb_classes=2, shuffle=True)
iter_kFold = iter(kFoldStrCV)
range_fold = range(nb_folds)
# Set tests
all_train_pairs = [d.pair for d in dataset_train]
test_pairs = [d.pair for d in dataset_test]
for i_p in range(len(test_pairs)):
with self.subTest(msg='Pair {}'.format(test_pairs[i_p])):
self.assertNotIn(test_pairs[i_p], all_train_pairs)
for k in range_fold:
        train_set, val_set = next(iter_kFold)
train_pairs = [d.pair for d in train_set]
val_pairs = [d.pair for d in val_set]
for i_p in range(len(val_pairs)):
with self.subTest(msg='Fold {} - Pair {}'.format(k, val_pairs[i_p])):
self.assertNotIn(val_pairs[i_p], train_pairs)
if __name__ == '__main__':
unittest.main()