from __future__ import annotations
import copy
import glob
import json
import os
import shutil
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
)
if TYPE_CHECKING:
import subprocess
from datashuttle.utils.custom_types import (
ConnectionMethods,
DisplayMode,
OverwriteExistingFiles,
Prefix,
TopLevelFolder,
)
import yaml
from datashuttle.configs import (
canonical_configs,
canonical_folders,
load_configs,
)
from datashuttle.configs.config_class import Configs
from datashuttle.datashuttle_functions import _format_top_level_folder
from datashuttle.utils import (
aws,
ds_logger,
folders,
formatting,
gdrive,
getters,
rclone,
rclone_encryption,
ssh,
utils,
validation,
)
from datashuttle.utils.custom_exceptions import (
ConfigError,
NeuroBlueprintError,
)
from datashuttle.utils.data_transfer import TransferData
from datashuttle.utils.decorators import ( # noqa
check_configs_set,
check_is_not_local_project,
requires_aws_configs,
requires_ssh_configs,
)
from datashuttle.utils.transfer_output_class import TransferOutput
# -----------------------------------------------------------------------------
# Project Manager Class
# -----------------------------------------------------------------------------
[docs]
class DataShuttle:
"""DataShuttle is a tool for neuroscience project management and data transfer."""
def __init__(self, project_name: str, print_startup_message: bool = True):
"""Initialise ``DataShuttle``.
Parameters
----------
project_name
The project name.
print_startup_message
If `True`, a start-up message displaying the
current state of the program (e.g. persistent
settings such as the 'top-level folder') is shown.
"""
self._error_on_base_project_name(project_name)
self.project_name = project_name
(
self._datashuttle_path,
self._temp_log_path,
) = canonical_folders.get_project_datashuttle_path(self.project_name)
folders.create_folders([self._datashuttle_path, self._temp_log_path])
self._config_path = self._datashuttle_path / "config.yaml"
self._persistent_settings_path = (
self._datashuttle_path / "persistent_settings.yaml"
)
self.cfg: Any = None
self.cfg = load_configs.attempt_load_configs(
self.project_name, self._config_path, verbose=print_startup_message
)
if self.cfg:
self._set_attributes_after_config_load()
def _set_attributes_after_config_load(self) -> None:
"""Update all private attributes according to config contents."""
self.cfg.init_paths()
self._make_project_metadata_if_does_not_exist()
# -------------------------------------------------------------------------
# Public Folder Makers
# -------------------------------------------------------------------------
[docs]
@check_configs_set
def create_folders(
self,
top_level_folder: TopLevelFolder,
sub_names: Union[str, List[str]],
ses_names: Optional[Union[str, List[str]]] = None,
datatype: Union[str, List[str]] = "",
bypass_validation: bool = False,
allow_letters_in_sub_ses_values: bool = False,
log: bool = True,
) -> Dict[str, List[Path]]:
"""Create a folder tree in the project folder.
The passed names are initially formatted and validated,
then folders are created.
Parameters
----------
top_level_folder
Whether to make the folders within `rawdata` or
`derivatives`.
sub_names
subject name / list of subject names to make
within the top-level project folder
(if not already, these will be prefixed with
"sub-")
ses_names
session name / list of session names.
(if not already, these will be prefixed with
"ses-"). If no session is provided, no session-level
folders are made.
datatype
The datatype to make in the sub / ses folders.
(e.g. "ephys", "behav", "anat"). If "" is
passed no datatype will be created. Broad or
Narrow NeuroBlueprint datatypes are accepted.
bypass_validation
If `True`, folders will be created even if they are not
valid to NeuroBlueprint style.
allow_letters_in_sub_ses_values
If `True`, any alphanumeric character are allowed for the values associated
with sub- or ses- keys. Otherwise, values must be integer
and the following additional checks are performed:
- Labels must be the same length (e.g. sub-01 and sub-002 is invalid).
log
If `True`, details of folder creation will be logged.
Returns
-------
created_paths
A dictionary of the full filepaths made during folder creation,
where the keys are the type of folder made and the values are a
list of created folder paths (Path objects). If datatype were
created, the dict keys will separate created folders by datatype
name. Similarly, if only subject or session level folders were
created, these are separated by "sub" and "ses" keys.
Notes
-----
sub_names or ses_names may contain formatting tags
@TO@
used to make a range of subjects / sessions.
Boundaries of the range must be either side of the tag
e.g. sub-001@TO@003 will generate ["sub-001", "sub-002", "sub-003"]
@DATE@, @TIME@ @DATETIME@
will add date-<value>, time-<value> or
date-<value>_time-<value> keys respectively. Only one per-name
is permitted.
e.g. sub-001_@DATE@ will generate sub-001_date-20220101 (on the 1st january, 2022).
Examples
--------
project.create_folders("rawdata", "sub-001", datatype="behav")
project.create_folders("rawdata", "sub-002@TO@005", ["ses-001", "ses-002"], ["ephys", "behav"])
"""
if log:
self._start_log(
"create-folders",
local_vars={
"top_level_folder": top_level_folder,
"sub_names": sub_names,
"ses_names": ses_names,
"datatype": datatype,
"bypass_validation": bypass_validation,
},
)
self._check_top_level_folder(top_level_folder)
if ses_names is None and datatype != "":
datatype = ""
utils.log_and_message(
"`datatype` passed without `ses_names`, no datatype "
"folders will be created."
)
utils.log("\nFormatting Names...")
ds_logger.log_names(["sub_names", "ses_names"], [sub_names, ses_names])
validation_templates = self.get_validation_templates()
format_sub, format_ses = self._format_and_validate_names(
top_level_folder,
sub_names,
ses_names,
validation_templates,
bypass_validation,
allow_letters_in_sub_ses_values,
log=True,
)
ds_logger.log_names(
["formatted_sub_names", "formatted_ses_names"],
[format_sub, format_ses],
)
utils.log("\nMaking folders...")
created_paths = folders.create_folder_trees(
self.cfg,
top_level_folder,
format_sub,
format_ses,
datatype,
log=True,
)
utils.print_message_to_user("Finished making folders.")
if log:
utils.print_message_to_user(
f"For log of all created folders, "
f"please see {self.cfg.logging_path}"
)
ds_logger.close_log_filehandler()
return created_paths
def _format_and_validate_names(
self,
top_level_folder: TopLevelFolder,
sub_names: Union[str, List[str]],
ses_names: Optional[Union[str, List[str]]],
validation_templates: Dict,
bypass_validation: bool,
allow_letters_in_sub_ses_values: bool,
log: bool = True,
) -> Tuple[List[str], List[str]]:
"""Central method to format and validate subject and session names."""
format_sub = formatting.check_and_format_names(
sub_names,
"sub",
validation_templates,
bypass_validation,
allow_letters_in_sub_ses_values,
)
if ses_names is not None:
format_ses = formatting.check_and_format_names(
ses_names,
"ses",
validation_templates,
bypass_validation,
allow_letters_in_sub_ses_values,
)
else:
format_ses = []
if not bypass_validation:
validation.validate_names_against_project(
self.cfg,
top_level_folder,
format_sub,
format_ses,
include_central=False,
display_mode="error",
log=log,
validation_templates=validation_templates,
allow_letters_in_sub_ses_values=allow_letters_in_sub_ses_values,
)
return format_sub, format_ses
# -------------------------------------------------------------------------
# Public File Transfer
# -------------------------------------------------------------------------
[docs]
@check_configs_set
@check_is_not_local_project
def upload_custom(
self,
top_level_folder: TopLevelFolder,
sub_names: Union[str, list],
ses_names: Union[str, list],
datatype: Union[List[str], str] = "all",
overwrite_existing_files: OverwriteExistingFiles = "never",
dry_run: bool = False,
init_log: bool = True,
display_transfer_output: bool = True,
) -> TransferOutput:
"""Upload data from a local project to the central project folder.
Parameters
----------
top_level_folder
The top-level folder (e.g. `"rawdata"`, `"derivatives"`) to transfer within.
sub_names
A subject name / list of subject names. These must
be prefixed with ``"sub-"``, or the prefix will be
automatically added. ``"@*@"`` can be used as a wildcard.
"all" will search for all sub-folders in the
datatype folder to upload.
ses_names
A session name / list of session names, similar to
sub_names but requiring a ``"ses-"`` prefix.
datatype
The (broad or narrow) NeuroBlueprint datatypes to transfer.
If ``"all"``, any broad or narrow datatype folder will be transferred.
overwrite_existing_files
If ``"never"`` files on target will never be overwritten by source.
If ``"always"`` files on target will be overwritten by source if
there is any difference in date or size.
If ``"if_source_newer"`` files on target will only be overwritten
by files on source with newer creation / modification datetime.
dry_run
Perform a dry-run of transfer. This will output as if file
transfer was taking place, but no files will be moved.
init_log
Whether to handle logging. This should
always be ``True``, unless logger is handled elsewhere
(e.g. in a calling function).
display_transfer_output
If `True`, a summary of number of transferred files, and any
errors will be printed alongside the Rclone logs.
"""
if init_log:
self._start_log(
"upload-custom",
local_vars={
"top_level_folder": top_level_folder,
"sub_names": sub_names,
"ses_names": ses_names,
"datatype": datatype,
"overwrite_existing_files": overwrite_existing_files,
"dry_run": dry_run,
},
)
self._check_top_level_folder(top_level_folder)
transfer_output = TransferData(
self.cfg,
"upload",
top_level_folder,
sub_names,
ses_names,
datatype,
overwrite_existing_files,
dry_run,
).run()
if display_transfer_output:
rclone.log_rclone_transfer_output(transfer_output)
if init_log:
ds_logger.close_log_filehandler()
return transfer_output
[docs]
@check_configs_set
@check_is_not_local_project
def download_custom(
self,
top_level_folder: TopLevelFolder,
sub_names: Union[str, list],
ses_names: Union[str, list],
datatype: Union[List[str], str] = "all",
overwrite_existing_files: OverwriteExistingFiles = "never",
dry_run: bool = False,
init_log: bool = True,
display_transfer_output: bool = True,
) -> TransferOutput:
"""Download data from the central project to the local project folder.
Parameters
----------
top_level_folder
The top-level folder (e.g. `"rawdata"`, `"derivatives"`) to transfer within.
sub_names
A subject name / list of subject names. These must
be prefixed with ``"sub-"``, or the prefix will be
automatically added. ``"@*@"`` can be used as a wildcard.
"all" will search for all sub-folders in the
datatype folder to upload.
ses_names
A session name / list of session names, similar to
sub_names but requiring a ``"ses-"`` prefix.
datatype
The (broad or narrow) NeuroBlueprint datatypes to transfer.
If ``"all"``, any broad or narrow datatype folder will be transferred.
overwrite_existing_files
If ``"never"`` files on target will never be overwritten by source.
If ``"always"`` files on target will be overwritten by source if
there is any difference in date or size.
If ``"if_source_newer"`` files on target will only be overwritten
by files on source with newer creation / modification datetime.
dry_run
Perform a dry-run of transfer. This will output as if file
transfer was taking place, but no files will be moved.
init_log
Whether to handle logging. This should
always be ``True``, unless logger is handled elsewhere
(e.g. in a calling function).
display_transfer_output
If `True`, a summary of number of transferred files, and any
errors will be printed alongside the Rclone logs.
"""
if init_log:
self._start_log(
"download-custom",
local_vars={
"top_level_folder": top_level_folder,
"sub_names": sub_names,
"ses_names": ses_names,
"datatype": datatype,
"overwrite_existing_files": overwrite_existing_files,
"dry_run": dry_run,
},
)
self._check_top_level_folder(top_level_folder)
transfer_output = TransferData(
self.cfg,
"download",
top_level_folder,
sub_names,
ses_names,
datatype,
overwrite_existing_files,
dry_run,
).run()
if display_transfer_output:
rclone.log_rclone_transfer_output(transfer_output)
if init_log:
ds_logger.close_log_filehandler()
return transfer_output
# Specific top-level folder
# ----------------------------------------------------------------------------------
# A set of convenience functions are provided to abstract
# away the 'top_level_folder' concept.
[docs]
@check_configs_set
@check_is_not_local_project
def upload_rawdata(
self,
overwrite_existing_files: OverwriteExistingFiles = "never",
dry_run: bool = False,
) -> TransferOutput:
"""Upload all files in the `rawdata` top level folder.
Parameters
----------
overwrite_existing_files
If ``"never"`` files on target will never be overwritten by source.
If ``"always"`` files on target will be overwritten by source if
there is any difference in date or size.
If ``"if_source_newer"`` files on target will only be overwritten
by files on source with newer creation / modification datetime.
dry_run
Perform a dry-run of transfer. This will output as if file
transfer was taking place, but no files will be moved.
"""
return self._transfer_top_level_folder(
"upload",
"rawdata",
overwrite_existing_files=overwrite_existing_files,
dry_run=dry_run,
)
[docs]
@check_configs_set
@check_is_not_local_project
def upload_derivatives(
self,
overwrite_existing_files: OverwriteExistingFiles = "never",
dry_run: bool = False,
) -> TransferOutput:
"""Upload all files in the `derivatives` top level folder.
Parameters
----------
overwrite_existing_files
If ``"never"`` files on target will never be overwritten by source.
If ``"always"`` files on target will be overwritten by source if
there is any difference in date or size.
If ``"if_source_newer"`` files on target will only be overwritten
by files on source with newer creation / modification datetime.
dry_run
Perform a dry-run of transfer. This will output as if file
transfer was taking place, but no files will be moved.
"""
return self._transfer_top_level_folder(
"upload",
"derivatives",
overwrite_existing_files=overwrite_existing_files,
dry_run=dry_run,
)
[docs]
@check_configs_set
@check_is_not_local_project
def download_rawdata(
self,
overwrite_existing_files: OverwriteExistingFiles = "never",
dry_run: bool = False,
) -> TransferOutput:
"""Download all files in the `rawdata` top level folder.
Parameters
----------
overwrite_existing_files
If ``"never"`` files on target will never be overwritten by source.
If ``"always"`` files on target will be overwritten by source if
there is any difference in date or size.
If ``"if_source_newer"`` files on target will only be overwritten
by files on source with newer creation / modification datetime.
dry_run
Perform a dry-run of transfer. This will output as if file
transfer was taking place, but no files will be moved..
"""
return self._transfer_top_level_folder(
"download",
"rawdata",
overwrite_existing_files=overwrite_existing_files,
dry_run=dry_run,
)
[docs]
@check_configs_set
@check_is_not_local_project
def download_derivatives(
self,
overwrite_existing_files: OverwriteExistingFiles = "never",
dry_run: bool = False,
) -> TransferOutput:
"""Download all files in the `derivatives` top level folder.
Parameters
----------
overwrite_existing_files
If ``"never"`` files on target will never be overwritten by source.
If ``"always"`` files on target will be overwritten by source if
there is any difference in date or size.
If ``"if_source_newer"`` files on target will only be overwritten
by files on source with newer creation / modification datetime.
dry_run
Perform a dry-run of transfer. This will output as if file
transfer was taking place, but no files will be moved.
"""
return self._transfer_top_level_folder(
"download",
"derivatives",
overwrite_existing_files=overwrite_existing_files,
dry_run=dry_run,
)
[docs]
@check_configs_set
@check_is_not_local_project
def upload_entire_project(
self,
overwrite_existing_files: OverwriteExistingFiles = "never",
dry_run: bool = False,
) -> TransferOutput:
"""Upload the entire project.
Includes every top level folder (e.g. ``rawdata``, ``derivatives``).
Parameters
----------
overwrite_existing_files
If ``"never"`` files on target will never be overwritten by source.
If ``"always"`` files on target will be overwritten by source if
there is any difference in date or size.
If ``"if_source_newer"`` files on target will only be overwritten
by files on source with newer creation / modification datetime.
dry_run
Perform a dry-run of transfer. This will output as if file
transfer was taking place, but no files will be moved.
"""
self._start_log(
"upload-entire-project",
local_vars={
"overwrite_existing_files": overwrite_existing_files,
"dry_run": dry_run,
},
)
transfer_output = self._transfer_entire_project(
"upload", overwrite_existing_files, dry_run
)
ds_logger.close_log_filehandler()
return transfer_output
[docs]
@check_configs_set
@check_is_not_local_project
def download_entire_project(
self,
overwrite_existing_files: OverwriteExistingFiles = "never",
dry_run: bool = False,
) -> TransferOutput:
"""Download the entire project.
Includes every top level folder (e.g. ``rawdata``, ``derivatives``).
Parameters
----------
overwrite_existing_files
If ``"never"`` files on target will never be overwritten by source.
If ``"always"`` files on target will be overwritten by source if
there is any difference in date or size.
If ``"if_source_newer"`` files on target will only be overwritten
by files on source with newer creation / modification datetime.
dry_run
Perform a dry-run of transfer. This will output as if file
transfer was taking place, but no files will be moved.
"""
self._start_log(
"download-entire-project",
local_vars={
"overwrite_existing_files": overwrite_existing_files,
"dry_run": dry_run,
},
)
transfer_output = self._transfer_entire_project(
"download", overwrite_existing_files, dry_run
)
ds_logger.close_log_filehandler()
return transfer_output
[docs]
@check_configs_set
@check_is_not_local_project
def upload_specific_folder_or_file(
self,
filepath: Union[str, Path],
overwrite_existing_files: OverwriteExistingFiles = "never",
dry_run: bool = False,
) -> TransferOutput:
"""Upload a specific file or folder.
If transferring a single file, the path including the filename is
required (see 'filepath' input). If a folder, wildcards "*" or "**" must be used to transfer
all files in the folder ("*") or all files and sub-folders ("**").
Parameters
----------
filepath
a string containing the full filepath.
overwrite_existing_files
If ``"never"`` files on target will never be overwritten by source.
If ``"always"`` files on target will be overwritten by source if
there is any difference in date or size.
If ``"if_source_newer"`` files on target will only be overwritten
by files on source with newer creation / modification datetime.
dry_run
Perform a dry-run of transfer. This will output as if file
transfer was taking place, but no files will be moved.
"""
self._start_log(
"upload-specific-folder-or-file",
local_vars={
"filepath": filepath,
"overwrite_existing_files": overwrite_existing_files,
"dry_run": dry_run,
},
)
transfer_output = self._transfer_specific_file_or_folder(
"upload", filepath, overwrite_existing_files, dry_run
)
ds_logger.close_log_filehandler()
return transfer_output
[docs]
@check_configs_set
@check_is_not_local_project
def download_specific_folder_or_file(
self,
filepath: Union[str, Path],
overwrite_existing_files: OverwriteExistingFiles = "never",
dry_run: bool = False,
) -> TransferOutput:
"""Download a specific file or folder.
If transferring a single file, the path including the filename is
required (see 'filepath' input). If a folder, wildcards "*" or "**"
must be used to transfer all files in the folder ("*") or all files
and sub-folders ("**").
Parameters
----------
filepath
a string containing the full filepath.
overwrite_existing_files
If ``"never"`` files on target will never be overwritten by source.
If ``"always"`` files on target will be overwritten by source if
there is any difference in date or size.
If ``"if_source_newer"`` files on target will only be overwritten
by files on source with newer creation / modification datetime.
dry_run
Perform a dry-run of transfer. This will output as if file
transfer was taking place, but no files will be moved.
"""
self._start_log(
"download-specific-folder-or-file",
local_vars={
"filepath": filepath,
"overwrite_existing_files": overwrite_existing_files,
"dry_run": dry_run,
},
)
transfer_output = self._transfer_specific_file_or_folder(
"download", filepath, overwrite_existing_files, dry_run
)
ds_logger.close_log_filehandler()
return transfer_output
def _transfer_top_level_folder(
self,
upload_or_download: Literal["upload", "download"],
top_level_folder: TopLevelFolder,
overwrite_existing_files: OverwriteExistingFiles = "never",
dry_run: bool = False,
init_log: bool = True,
display_transfer_output: bool = True,
) -> TransferOutput:
"""Upload or download files within a particular top-level-folder.
A centralised function to upload or download data within
a particular top level folder (e.g. ``rawdata``, ``derivatives``).
"""
if init_log:
self._start_log(
f"{upload_or_download}-{top_level_folder}",
local_vars={
"upload_or_download": upload_or_download,
"top_level_folder": top_level_folder,
"overwrite_existing_files": overwrite_existing_files,
"dry_run": dry_run,
},
)
transfer_func = (
self.upload_custom
if upload_or_download == "upload"
else self.download_custom
)
transfer_output = transfer_func(
top_level_folder,
"all",
"all",
"all",
overwrite_existing_files=overwrite_existing_files,
dry_run=dry_run,
init_log=False,
display_transfer_output=display_transfer_output,
)
if init_log:
ds_logger.close_log_filehandler()
return transfer_output
def _transfer_specific_file_or_folder(
self, upload_or_download, filepath, overwrite_existing_files, dry_run
) -> TransferOutput:
"""Core function for upload/download_specific_folder_or_file()."""
if isinstance(filepath, str):
filepath = Path(filepath)
if upload_or_download == "upload":
base_path = self.cfg["local_path"]
else:
base_path = self.cfg["central_path"]
if base_path is not None:
if not utils.path_starts_with_base_folder(base_path, filepath):
utils.log_and_raise_error(
"Transfer failed. "
"Must pass the full filepath to file or folder to transfer.",
ValueError,
)
processed_filepath = filepath.relative_to(base_path)
top_level_folder = processed_filepath.parts[0]
processed_filepath = Path(*processed_filepath.parts[1:])
else:
assert self.cfg["connection_method"] == "gdrive", (
"`None` only permitted for gdrive or local only mode."
)
processed_filepath = filepath
include_list = [f"--include /{processed_filepath.as_posix()}"]
output = rclone.transfer_data(
self.cfg,
upload_or_download,
top_level_folder,
include_list,
rclone.make_rclone_transfer_options(
overwrite_existing_files, dry_run
),
)
stdout, stderr, transfer_output = rclone.parse_rclone_copy_output(
top_level_folder, output
)
rclone.log_stdout_stderr_python_api(stdout, stderr)
rclone.log_rclone_transfer_output(transfer_output)
return transfer_output
# -------------------------------------------------------------------------
# SSH
# -------------------------------------------------------------------------
[docs]
@requires_ssh_configs
@check_is_not_local_project
def setup_ssh_connection(self) -> None:
"""Set up a connection to the central server using SSH.
Assumes the central_host_id and central_host_username
are set in configs (see make_config_file() and update_config_file()).
First, the server key will be displayed, requiring
verification of the server ID. This will store the
hostkey for all future use.
Next, prompt to input their password for the central
cluster. Once input, SSH private / public key pair
will be setup.
Do not log this method, too high a risk of logging secrets.
"""
if self.cfg["connection_method"] != "ssh":
raise RuntimeError(
"configs `connection_method` must be 'ssh' to set up SSH connection."
)
verified = ssh.verify_ssh_central_host_api(
self.cfg["central_host_id"],
self.cfg.hostkeys_path,
log=True,
)
if verified:
private_key_str = ssh.setup_ssh_key_api(self.cfg, log=True)
self._setup_rclone_central_ssh_config(private_key_str, log=True)
utils.log_and_message(
f"Your SSH key will be stored in the rclone config at:\n "
f"{self.cfg.rclone.get_rclone_central_connection_config_filepath()}.\n"
)
if not self.cfg.rclone.rclone_file_is_encrypted():
if self._ask_user_rclone_encryption():
self._try_encrypt_rclone_config()
rclone.check_successful_connection_and_raise_error_on_fail(
self.cfg
)
utils.log_and_message(
"SSH key pair setup successfully. SSH key saved to the RClone config file."
)
# -------------------------------------------------------------------------
# Google Drive
# -------------------------------------------------------------------------
[docs]
@check_configs_set
def setup_gdrive_connection(self) -> None:
"""Set up a connection to Google Drive using the provided credentials.
Assumes `gdrive_root_folder_id` is set in configs.
First, the user will be prompted to enter their Google Drive client
secret if `gdrive_client_id` is set in the configs.
Next, the user will be asked if their machine has access to a browser.
If not, they will be prompted to input a config_token after running an
rclone command displayed to the user on a machine with access to a browser.
Next, with the provided credentials, the final setup will be done. This
opens up a browser if the user confirmed access to a browser.
Do not log this method, too high a risk of logging secrets.
"""
if self.cfg["connection_method"] != "gdrive":
raise RuntimeError(
"configs `connection_method` must be 'gdrive' to set up Google Drive connection."
)
if self.cfg["gdrive_client_id"]:
gdrive_client_secret = gdrive.get_client_secret()
else:
gdrive_client_secret = None
browser_available = gdrive.ask_user_for_browser(log=True)
if not browser_available:
config_token = gdrive.prompt_and_get_config_token(
self.cfg,
gdrive_client_secret,
self.cfg.rclone.get_rclone_config_name("gdrive"),
log=True,
)
else:
config_token = None
process = self._setup_rclone_gdrive_config(
gdrive_client_secret, config_token
)
rclone.await_call_rclone_with_popen_for_central_connection_raise_on_fail(
self.cfg, process, log=True
)
if not self.cfg.rclone.rclone_file_is_encrypted():
if self._ask_user_rclone_encryption():
self._try_encrypt_rclone_config()
rclone.check_successful_connection_and_raise_error_on_fail(self.cfg)
utils.log_and_message("Google Drive Connection Successful.")
# -------------------------------------------------------------------------
# AWS S3
# -------------------------------------------------------------------------
[docs]
@requires_aws_configs
@check_configs_set
def setup_aws_connection(self) -> None:
"""Set up a connection to AWS S3 buckets using the provided credentials.
Assumes `aws_access_key_id` and `aws_region` are set in configs.
First, the user will be prompted to input their AWS secret access key.
Next, with the provided credentials, the final connection setup will be done.
Do not log this method, too high a risk of logging secrets.
"""
if self.cfg["connection_method"] != "aws":
raise RuntimeError(
"configs `connection_method` must be 'aws' to "
"set up Amazon Web Services S3 Bucket connection."
)
aws_secret_access_key = aws.get_aws_secret_access_key()
self._setup_rclone_aws_config(aws_secret_access_key, log=True)
if not self.cfg.rclone.rclone_file_is_encrypted():
if self._ask_user_rclone_encryption():
self._try_encrypt_rclone_config()
rclone.check_successful_connection_and_raise_error_on_fail(self.cfg)
aws.raise_if_bucket_absent(self.cfg)
utils.log_and_message("AWS Connection Successful.")
# -------------------------------------------------------------------------
# Rclone config encryption
# -------------------------------------------------------------------------
def _ask_user_rclone_encryption(self) -> bool:
"""Get user input to determine if they want to encrypt the rclone config."""
input_ = utils.get_user_input(
f"{rclone_encryption.get_explanation_message(self.cfg)}\n"
f"Press 'y' to encrypt the Rclone config or leave blank to skip."
)
return input_ == "y"
def _try_encrypt_rclone_config(self, is_using_api=True) -> None:
"""Try to encrypt the rclone config file.
If it fails, error and let the user know the config file is unencrypted.
"""
try:
self.encrypt_rclone_config()
except Exception as e:
config_path = (
self.cfg.rclone.get_rclone_central_connection_config_filepath()
)
api_prompt = (
"Use `encrypt_rclone_config()` to attempt to encrypt the file again "
if is_using_api
else ""
)
# don't log during encryption
utils.raise_error(
f"Config encryption failed:\n"
f"{str(e)}\n"
f"{api_prompt}\n\n"
f"IMPORTANT: The config at {config_path} is not currently encrypted.\n",
RuntimeError,
)
utils.print_message_to_user(
f"Rclone config file for the central connection "
f"{self.cfg['connection_method']} was successfully encrypted."
)
[docs]
def encrypt_rclone_config(self) -> None:
"""Encrypt the rclone config file for the central connection."""
if self.cfg.rclone.rclone_file_is_encrypted():
self.remove_rclone_encryption()
rclone_encryption.run_rclone_config_encrypt(self.cfg)
self.cfg.rclone.set_rclone_config_encryption_state(True)
[docs]
def remove_rclone_encryption(self) -> None:
"""Unencrypt the rclone config file for the central connection."""
if not self.cfg.rclone.rclone_file_is_encrypted():
raise RuntimeError(
f"The config for the current connection method: "
f"{self.cfg['connection_method']} "
f"is not encrypted. Cannot unencrypt."
)
rclone_encryption.remove_rclone_encryption(self.cfg)
self.cfg.rclone.set_rclone_config_encryption_state(False)
# -------------------------------------------------------------------------
# Configs
# -------------------------------------------------------------------------
[docs]
def make_config_file(
self,
local_path: str,
central_path: Optional[str] = None,
connection_method: Optional[ConnectionMethods] = "local_only",
central_host_id: Optional[str] = None,
central_host_username: Optional[str] = None,
gdrive_client_id: Optional[str] = None,
gdrive_root_folder_id: Optional[str] = None,
aws_access_key_id: Optional[str] = None,
aws_region: Optional[str] = None,
) -> None:
"""Initialize the configurations for datashuttle on the local machine.
Once initialized, these settings will be used each
time the datashuttle is opened.
These settings are stored in a config file on the
datashuttle path (not in the project folder) on the
local machine. Use ``get_config_path()`` to get the
full path to the saved config file.
Use ``update_config_file()`` to selectively update settings.
Parameters
----------
local_path
path to project folder on local machine
central_path
Filepath to central project.
If this is local (i.e. ``connection_method = "local_filesystem"``),
this is the full path on the local filesystem
Otherwise, if this is via ssh (i.e. ``connection method = "ssh"``),
this is the path to the project folder on central machine.
This should be a full path to central folder i.e. this cannot
include ~ home folder syntax, must contain the full path
(e.g. ``/nfs/nhome/live/jziminski``)
connection_method
The method used to connect to the central project filesystem,
``None`` is an alias for ``"local_only"``.
e.g. ``"local_filesystem"`` (e.g. mounted drive) or ``"ssh"``
central_host_id
server address for central host for ssh connection
e.g. ``"ssh.swc.ucl.ac.uk"``
central_host_username
username for which to log in to central host.
e.g. ``"jziminski"``
gdrive_client_id
The client ID used to authenticate with the Google Drive API via OAuth 2.0.
This is obtained from the Google Cloud Console when setting up API credentials.
e.g. "1234567890-abc123def456.apps.googleusercontent.com"
gdrive_root_folder_id
The folder ID for the Google Drive folder to connect to. This can be copied
directly from your browser when on the folder in Google Drive.
e.g. 1eoAnopd2ZHOd87LgiPtgViFE7u3R9sSw
aws_access_key_id
The AWS access key ID used to authenticate requests to AWS services.
This is part of your AWS credentials and can be generated via the AWS IAM console.
e.g. "AKIAIOSFODNN7EXAMPLE"
aws_region
The AWS region in which your resources are located.
This determines the data center your requests are routed to.
e.g. "us-west-2"
"""
self._start_log(
"make-config-file",
store_in_temp_folder=True,
)
if connection_method is None:
# For backward compatibility
connection_method = "local_only"
if self._config_path.is_file():
utils.log_and_raise_error(
"A config file already exists for this project. "
"Use `update_config_file` to update settings.",
RuntimeError,
)
cfg = Configs(
self.project_name,
self._config_path,
{
"local_path": local_path,
"central_path": central_path,
"connection_method": connection_method,
"central_host_id": central_host_id,
"central_host_username": central_host_username,
"gdrive_client_id": gdrive_client_id,
"gdrive_root_folder_id": gdrive_root_folder_id,
"aws_access_key_id": aws_access_key_id,
"aws_region": aws_region,
},
)
cfg.setup_after_load() # will raise error if fails
self.cfg = cfg
self.cfg.dump_to_file()
self._set_attributes_after_config_load()
# This is just a placeholder rclone config that will suffice
# if central is a 'local filesystem'.
if connection_method != "local_only":
self._setup_rclone_central_local_filesystem_config()
utils.log_and_message(
"Configuration file has been saved and "
"options loaded into datashuttle."
)
self._log_successful_config_change()
self._move_logs_from_temp_folder()
ds_logger.close_log_filehandler()
[docs]
def update_config_file(self, **kwargs) -> None:
"""Update the configuration file.
Parameters
----------
**kwargs
A dictionary of key-value pairs containing the config
settings to update. For example,
``{"connection_method": "local_filesystem", "central_path": "/my/local/path"}``
will update the ``connection_method`` and ``central_path`` settings.
"""
if not self.cfg:
utils.log_and_raise_error(
"Must have a config loaded before updating configs.",
ConfigError,
)
self._start_log("update-config-file", local_vars=kwargs)
if "connection_method" in kwargs:
if kwargs["connection_method"] is None:
# For backward compatibility
kwargs["connection_method"] = "local_only"
if (
self.cfg["connection_method"] == "local_only"
and kwargs["connection_method"] != "local_only"
):
# We need to ensure this rclone config is created if it was not created during
# initial set up because the project is local only. It does not matter if the
# RClone config is ever overwritten, it's just a placeholder.
self._setup_rclone_central_local_filesystem_config()
new_cfg = copy.deepcopy(self.cfg)
new_cfg.update(**kwargs)
new_cfg.setup_after_load() # will raise on error
self.cfg = new_cfg
self._set_attributes_after_config_load()
self.cfg.dump_to_file()
self._log_successful_config_change(message=True)
ds_logger.close_log_filehandler()
# -------------------------------------------------------------------------
# Getters
# -------------------------------------------------------------------------
[docs]
@check_configs_set
def get_local_path(self) -> Path:
"""Return the projects local path."""
return self.cfg["local_path"]
[docs]
@check_configs_set
@check_is_not_local_project
def get_central_path(self) -> Path:
"""Return the project central path."""
return self.cfg["central_path"]
[docs]
def get_datashuttle_path(self) -> Path:
"""Return the path to the local datashuttle folder.
This is where configs and other datashuttle files are stored.
"""
return self._datashuttle_path
[docs]
@check_configs_set
def get_config_path(self) -> Path:
"""Return the full path to the DataShuttle config file."""
return self._config_path
[docs]
@check_configs_set
def get_rclone_central_config_path(self) -> Path:
"""Get the path to the Rclone config for the current `connection_method`."""
return rclone.get_rclone_config_filepath(self.cfg)
[docs]
@check_configs_set
def get_configs(self) -> Configs:
"""Return the datashuttle configs."""
return self.cfg
[docs]
@check_configs_set
def get_logging_path(self) -> Path:
"""Return the path where datashuttle logs are written."""
return self.cfg.logging_path
[docs]
@staticmethod
def get_existing_projects() -> List[Path]:
"""Return a list of existing project names found on the local machine.
This is based on project folders in the "home / .datashuttle" folder
that contain valid config.yaml files.
"""
return getters.get_existing_project_paths()
[docs]
@check_configs_set
def get_next_sub(
self,
top_level_folder: TopLevelFolder,
return_with_prefix: bool = True,
include_central: bool = False,
) -> str:
"""Return the next subject number.
Parameters
----------
top_level_folder
The top-level folder, "rawdata" or "derivatives".
return_with_prefix
If `True`, return the subject with the "sub-" prefix.
include_central
If `False, only get names from `local_path`, otherwise from
`local_path` and `central_path`. If in local-project mode,
this flag is ignored.
Returns
-------
The next subject ID.
"""
validation_template = self.get_validation_templates()
validation_template_regexp = (
validation_template["sub"] if validation_template["on"] else None
)
if self.is_local_project():
include_central = False
return getters.get_next_sub_or_ses(
self.cfg,
top_level_folder,
sub=None,
include_central=include_central,
return_with_prefix=return_with_prefix,
search_str="sub-*",
validation_template_regexp=validation_template_regexp,
)
[docs]
@check_configs_set
def get_next_ses(
self,
top_level_folder: TopLevelFolder,
sub: str,
return_with_prefix: bool = True,
include_central: bool = False,
) -> str:
"""Return the next session number.
Parameters
----------
top_level_folder
The top-level folder, "rawdata" or "derivatives".
sub
Name of the subject to find the next session of.
return_with_prefix
If `True`, return with the "ses-" prefix.
include_central
If ``False``, only get names from ``local_path``, otherwise from
``local_path`` and ``central_path``. If in local-project mode,
this flag is ignored.
Returns
-------
The next session ID.
"""
validation_template = self.get_validation_templates()
validation_template_regexp = (
validation_template["ses"] if validation_template["on"] else None
)
if self.is_local_project():
include_central = False
return getters.get_next_sub_or_ses(
self.cfg,
top_level_folder,
sub=sub,
include_central=include_central,
return_with_prefix=return_with_prefix,
search_str="ses-*",
validation_template_regexp=validation_template_regexp,
)
[docs]
@check_configs_set
def is_local_project(self) -> bool:
"""Return a bool indicating whether the project is 'local only'.
A project is 'local-only' if it has no ``central_path`` and ``connection_method``.
It can be used to make folders and validate, but not for transfer.
"""
return self.cfg.is_local_project()
# Name Templates
# -------------------------------------------------------------------------
[docs]
def get_validation_templates(self) -> Dict:
"""Return the regexp templates used for validation.
If the "on" key is set to `False`, template validation is not performed.
Returns
-------
validation_templates
e.g. {"validation_templates": {"on": False, "sub": None, "ses": None}}
"""
settings = self._load_persistent_settings()
return settings["validation_templates"]
[docs]
def set_validation_templates(self, new_validation_templates: Dict) -> None:
"""Update the persistent settings with new name templates.
Name templates are regexp for that, when ``validation_templates["on"]`` is
set to ``True``, ``"sub"`` and ``"ses"`` names are validated against
the regexp contained in the dict.
Parameters
----------
new_validation_templates
e.g. ``{"validation_templates": {"on": False, "sub": None, "ses": None}}``
where ``"sub"`` or ``"ses"`` can be a regexp that subject and session
names respectively are validated against.
"""
self._update_persistent_setting(
"validation_templates", new_validation_templates
)
# -------------------------------------------------------------------------
# Showers
# -------------------------------------------------------------------------
[docs]
@check_configs_set
def show_configs(self) -> None:
"""Print the current configs to the terminal."""
utils.print_message_to_user(self._get_json_dumps_config())
# -------------------------------------------------------------------------
# Validators
# -------------------------------------------------------------------------
[docs]
@check_configs_set
def validate_project(
self,
top_level_folder: Optional[TopLevelFolder],
display_mode: DisplayMode,
include_central: bool = False,
strict_mode: bool = False,
allow_letters_in_sub_ses_values: bool = False,
) -> List[str]:
"""Perform validation on the project.
This checks the subject and session level folders to
ensure there are no NeuroBlueprint formatting issues.
Parameters
----------
top_level_folder
Folder to check, either ``"rawdata"`` or ``"derivatives"``. If ``None``,
will check both folders.
display_mode
The validation issues are displayed as ``"error"`` (raise error)
``"warn"`` (show warning) or ``"print"``
include_central
If ``False``, only the local project is validated. Otherwise, both
local and central projects are validated. If in local-project mode,
this flag is ignored.
strict_mode
If ``True``, only allow NeuroBlueprint-formatted folders to exist in
the project. By default, non-NeuroBlueprint folders (e.g. a folder
called '`my_stuff'` in the '`rawdata'`) are allowed, and only folders
starting with sub- or ses- prefix are checked. In ``Strict Mode``,
any folder not prefixed with sub-, ses- or a valid datatype will
raise a validation issue.
allow_letters_in_sub_ses_values
If `True`, any alphanumeric character are allowed for the values associated
with sub- or ses- keys. Otherwise, values must be integer
and the following additional checks are performed:
- Labels must be the same length (e.g. sub-01 and sub-002 is invalid).
Returns
-------
error_messages
A list of validation errors found in the project.
"""
if include_central and strict_mode:
raise ValueError(
"`strict_mode` is currently only available for `include_central=False`. "
"Please raise a GitHub issue if you would like to use this feature."
)
utils.print_message_to_user(
f"Logs of the validation will be stored in: "
f"{self.cfg.make_and_get_logging_path()}\n\nValidation results:"
)
self._start_log(
"validate-project",
)
validation_templates = self.get_validation_templates()
if self.is_local_project():
include_central = False
top_level_folder_to_validate = _format_top_level_folder(
top_level_folder
)
error_messages = validation.validate_project(
self.cfg,
top_level_folder_to_validate,
include_central=include_central,
display_mode=display_mode,
validation_templates=validation_templates,
strict_mode=strict_mode,
allow_letters_in_sub_ses_values=allow_letters_in_sub_ses_values,
)
ds_logger.close_log_filehandler()
return error_messages
# -------------------------------------------------------------------------
# Private Functions
# -------------------------------------------------------------------------
def _transfer_entire_project(
self,
upload_or_download: Literal["upload", "download"],
overwrite_existing_files: OverwriteExistingFiles,
dry_run: bool,
) -> TransferOutput:
"""Transfer the entire project.
i.e. every 'top level folder' (e.g. 'rawdata', 'derivatives').
See ``upload_custom()`` or ``download_custom()`` for parameters.
"""
all_output = TransferOutput()
for top_level_folder in canonical_folders.get_top_level_folders():
utils.log_and_message(
f"\n\n*************************************\n"
f"Transferring `{top_level_folder}`\n"
f"*************************************\n"
)
transfer_output = self._transfer_top_level_folder(
upload_or_download,
top_level_folder,
overwrite_existing_files=overwrite_existing_files,
dry_run=dry_run,
init_log=False,
display_transfer_output=False,
)
all_output["errors"]["file_names"] += transfer_output["errors"][
"file_names"
]
all_output["errors"]["messages"] += transfer_output["errors"][
"messages"
]
all_output["num_transferred"][top_level_folder] = transfer_output[
"num_transferred"
][top_level_folder]
rclone.log_rclone_transfer_output(all_output)
return all_output
def _start_log(
self,
command_name: str,
local_vars: Optional[dict] = None,
store_in_temp_folder: bool = False,
verbose: bool = True,
) -> None:
"""Initialize the logger.
This is typically called at the start of public
methods to initialize logging for a specific function call.
Parameters
----------
command_name
Name of the command, for the log output files.
local_vars
local_vars are passed to fancylog variables argument.
see ds_logger.wrap_variables_for_fancylog for more info
store_in_temp_folder
If `False`, existing logging path will be used
(local project .datashuttle).
verbose
Print warnings and error messages.
"""
if local_vars is None:
variables = None
else:
variables = ds_logger.wrap_variables_for_fancylog(
local_vars, self.cfg
)
if store_in_temp_folder:
path_to_save = self._temp_log_path
self._clear_temp_log_path()
else:
path_to_save = self.cfg.logging_path
os.makedirs(path_to_save, exist_ok=True)
ds_logger.start(path_to_save, command_name, variables, verbose)
def _move_logs_from_temp_folder(self) -> None:
"""Create a temporary logging folder when the project folder is unknown.
Logs are stored within the project folder. Although
in some instances, when setting configs, we do not know what
the project folder is. In this case, make the logs
in a temp folder in the .datashuttle config folder,
and move them to the project folder once set.
"""
if not self.cfg or not self.cfg["local_path"].is_dir():
utils.log_and_raise_error(
"Project folder does not exist. Logs were not moved.",
FileNotFoundError,
)
ds_logger.close_log_filehandler()
log_files = glob.glob(str(self._temp_log_path / "*.log"))
for file_path in log_files:
file_name = os.path.basename(file_path)
shutil.move(
self._temp_log_path / file_name,
self.cfg.logging_path / file_name,
)
def _clear_temp_log_path(self) -> None:
"""Delete temporary log files."""
log_files = glob.glob(str(self._temp_log_path / "*.log"))
for file in log_files:
os.remove(file)
def _error_on_base_project_name(self, project_name):
if validation.name_has_special_character(project_name):
utils.log_and_raise_error(
"The project name must contain alphanumeric characters only.",
ValueError,
)
if project_name == "":
utils.log_and_raise_error(
"The project name cannot be empty.", NeuroBlueprintError
)
def _log_successful_config_change(self, message: bool = False) -> None:
"""Log the entire config at the time of config change.
If messaged, just message "update successful" rather than
print the entire configs as it becomes confusing.
"""
if message:
utils.print_message_to_user("Update successful.")
utils.log(
f"Update successful. New config file: "
f"\n {self._get_json_dumps_config()}"
)
def _get_json_dumps_config(self) -> str:
"""Return the config dictionary formatted as json.dumps() which allows well formatted printing."""
copy_dict = copy.deepcopy(self.cfg.data)
load_configs.convert_str_and_pathlib_paths(copy_dict, "path_to_str")
return json.dumps(copy_dict, indent=4)
def _make_project_metadata_if_does_not_exist(self) -> None:
"""Locate the .datashuttle folder within the project local_path.
Within the project local_path is also a .datashuttle
folder that contains additional information, e.g. logs.
"""
folders.create_folders(self.cfg.project_metadata_path, log=False)
def _setup_rclone_central_ssh_config(
self, private_key_str: str, log: bool
) -> None:
rclone.setup_rclone_config_for_ssh(
self.cfg,
self.cfg.rclone.get_rclone_config_name("ssh"),
private_key_str,
log=log,
)
def _setup_rclone_central_local_filesystem_config(self) -> None:
rclone.setup_rclone_config_for_local_filesystem(
self.cfg,
self.cfg.rclone.get_rclone_config_name("local_filesystem"),
)
def _setup_rclone_gdrive_config(
self,
gdrive_client_secret: str | None,
config_token: str | None,
) -> subprocess.Popen:
return rclone.setup_rclone_config_for_gdrive(
self.cfg,
self.cfg.rclone.get_rclone_config_name("gdrive"),
gdrive_client_secret,
config_token,
)
def _setup_rclone_aws_config(
self, aws_secret_access_key: str, log: bool
) -> None:
rclone.setup_rclone_config_for_aws(
self.cfg,
self.cfg.rclone.get_rclone_config_name("aws"),
aws_secret_access_key,
log=log,
)
# Persistent settings
# -------------------------------------------------------------------------
def _update_persistent_setting(
self, setting_name: str, setting_value: Any
) -> None:
"""Load settings that are stored persistently across datashuttle sessions.
These are stored in yaml dumped to dictionary.
Parameters
----------
setting_name
dictionary key of the persistent setting to change
setting_value
value to change the persistent setting to
"""
settings = self._load_persistent_settings()
if setting_name not in settings:
utils.log_and_raise_error(
f"Setting key {setting_name} not found in settings dictionary",
KeyError,
)
settings[setting_name] = setting_value
self._save_persistent_settings(settings)
def _init_persistent_settings(self) -> None:
"""Initialise the default persistent settings and save to file."""
settings = canonical_configs.get_persistent_settings_defaults()
self._save_persistent_settings(settings)
def _save_persistent_settings(self, settings: Dict) -> None:
"""Save the settings dict to file as ".yaml"."""
with open(self._persistent_settings_path, "w") as settings_file:
yaml.dump(settings, settings_file, sort_keys=False)
def _load_persistent_settings(self) -> Dict:
"""Return settings that are stored persistently across datashuttle sessions."""
if not self._persistent_settings_path.is_file():
self._init_persistent_settings()
with open(self._persistent_settings_path) as settings_file:
settings = yaml.full_load(settings_file)
self._update_settings_with_new_canonical_keys(settings)
return settings
def _update_settings_with_new_canonical_keys(self, settings: Dict) -> None:
"""Check and update keys within persistent settings if missing.
Perform a check on the keys within persistent settings.
If they do not exist, persistent settings is from an older version
and the new keys need adding.
If changing keys within the top level (e.g. a dict entry in
"tui") this method will need to be extended.
Added keys:
v0.4.0: tui "overwrite_existing_files" and "dry_run"
"""
if "validation_templates" not in settings:
if "name_templates" in settings:
settings["validation_templates"] = settings.pop(
"name_templates"
)
else:
settings.update(
canonical_configs.get_validation_templates_defaults()
)
canonical_tui_configs = canonical_configs.get_tui_config_defaults()
if "tui" not in settings:
settings.update(canonical_tui_configs)
for key in [
"overwrite_existing_files",
"dry_run",
"suggest_next_sub_ses_central",
"allow_letters_in_sub_ses_values",
]:
if key not in settings["tui"]:
settings["tui"][key] = canonical_tui_configs["tui"][key]
# Handle updating with narrow datatypes
canonical_configs.in_place_update_narrow_datatypes_if_required(
settings
)
def _check_top_level_folder(self, top_level_folder) -> None:
"""Raise an error if ``top_level_folder`` not correct."""
canonical_top_level_folders = canonical_folders.get_top_level_folders()
if top_level_folder not in canonical_top_level_folders:
utils.log_and_raise_error(
f"`top_level_folder` must be one of "
f"{canonical_top_level_folders}",
ValueError,
)