Copy project

CopyProject.py

import os
import shutil
import tempfile
from zipfile import ZipFile

from LOGS import LOGS
from LOGS.Auxiliary import LOGSException
from LOGS.Entities import (
    Dataset,
    DatasetRequestParameter,
    DocumentRequestParameter,
    EntitiesRequestParameter,
    ExperimentRequestParameter,
    InstrumentRequestParameter,
    MethodRequestParameter,
    PersonRequestParameter,
    ProjectRequestParameter,
    SampleRequestParameter,
)


def unique(l):
    """Return the distinct items of *l* as a list, keeping first-seen order.

    Items must be hashable. Unlike ``list(set(l))`` the result order is
    deterministic, so the id lists sent to the LOGS API are reproducible
    from run to run.
    """
    return list(dict.fromkeys(l))


class EntityCollection:
    """Container for LOGS entities, grouped by entity type.

    Every attribute is a dict keyed by entity id. CopyProject uses one
    instance for the entities gathered from the source and another one to
    record which target entity each source entity corresponds to.
    """

    def __init__(self):
        # One dict per LOGS entity type.
        self.origins = {}
        self.projects = {}
        self.samples = {}
        self.documents = {}
        self.datasets = {}
        self.instruments = {}
        self.methods = {}
        self.persons = {}
        self.experiments = {}

    def printSummary(self):
        """Print one line per entity type with the number of entities held."""
        # Documents and origins are intentionally not reported.
        report = (
            ("Found %d project(s) to copy.", self.projects),
            ("Found %d sample(s) to copy.", self.samples),
            ("Found %d experiment(s) to copy.", self.experiments),
            ("Found %d dataset(s) to copy.", self.datasets),
            ("Found %d instrument(s) to copy.", self.instruments),
            ("Found %d method(s) to copy.", self.methods),
            ("Found %d person(s) to copy.", self.persons),
        )
        for template, entities in report:
            print(template % len(entities))


class CopyProject:
    """Copies projects and their related entities from a source LOGS to a target LOGS."""

    def __init__(self, sourceLogs: LOGS, targetLogs: LOGS):
        """Constructor of CopyProject

        Args:
            sourceLogs: The LOGS object connected to the instance to read project data from
            targetLogs: The LOGS object connected to the instance to copy the project contents to
        """
        self.sourceLogs = sourceLogs
        self.targetLogs = targetLogs

        # Counter used to generate "Unknown <n>" names for unnamed experiments.
        self.unknownCount = 0

    @staticmethod
    def _fetchByIds(fetch, parameterClass, ids):
        """Fetch entities by id and return them as an ``{id: entity}`` dict.

        Args:
            fetch: A LOGS query method, e.g. ``logs.samples``.
            parameterClass: The matching request parameter class; it is
                constructed as ``parameterClass(ids=ids)``.
            ids: The ids to fetch.

        Returns:
            A dict mapping entity id to entity. It is empty, and no request is
            made, when ``ids`` is empty.
        """
        if not ids:
            # Never send an empty id filter. Whether the API treats ``ids=[]``
            # as "match nothing" or "no filter" is not visible here (TODO
            # confirm); skipping the request is correct in both cases.
            return {}
        return {entity.id: entity for entity in fetch(parameterClass(ids=ids))}

    @staticmethod
    def _printProgress(count, number, *items):
        """Print a ``"<percent>% Creating <items...>"`` progress line."""
        print("%.1f%% Creating" % (100.0 * count / number), *items)

    def collectEntitiesFromProjects(self, projectNames=None, projectTags=None):
        """Collects entities based on specified projects

        Projects are selected by name and/or by tag. For the selected projects
        the samples, documents and datasets are collected. Then the samples,
        instruments, experiments, methods and persons that those entities
        reference are collected as well.

        Args:
            projectNames: Name list for project selection. Defaults to None (no selection by name).
            projectTags: Tag list for project selection. Defaults to None (no selection by tag).

        Raises:
            ValueError: If neither project names nor project tags are given.
            Exception: If a requested project name is not found on the source,
                or if no source project matches the selection at all.

        Returns:
            An EntityCollection with all collected entities, each dict keyed
            by entity id.
        """
        projectNames = list(projectNames or [])
        projectTags = list(projectTags or [])
        if not projectNames and not projectTags:
            raise ValueError("Neither project names nor project tags were given.")

        # Only query with non-empty filters. An empty filter list might be
        # interpreted as "no filter" and select every project (TODO confirm).
        projectList = []
        if projectNames:
            projectList.extend(
                self.sourceLogs.projects(ProjectRequestParameter(names=projectNames))
            )
        if projectTags:
            projectList.extend(
                self.sourceLogs.projects(
                    ProjectRequestParameter(projectTags=projectTags)
                )
            )

        # A project can match both a name and a tag, so de-duplicate by id.
        projects = {project.id: project for project in projectList}
        print("+++ Collecting %d projects from source +++" % len(projects))

        foundNames = {project.name for project in projects.values()}
        for name in projectNames:
            if name not in foundNames:
                raise Exception("Source project %a not found." % name)
        if not projects:
            raise Exception("No source project matches the given selection.")

        origin = self.sourceLogs.instanceOrigin
        projectIds = list(projects)

        # Samples from source
        request = self.sourceLogs.samples(SampleRequestParameter(projectIds=projectIds))
        print("+++ Collecting %d samples from source +++" % request.count)
        samples = {sample.id: sample for sample in request}

        # Documents from source
        request = self.sourceLogs.documents(
            DocumentRequestParameter(projectIds=projectIds)
        )
        print("+++ Collecting %d documents from source +++" % request.count)
        documents = {document.id: document for document in request}

        # Datasets from source
        request = self.sourceLogs.datasets(
            DatasetRequestParameter(projectIds=projectIds)
        )
        print("+++ Collecting %d datasets from source +++" % request.count)
        datasets = {dataset.id: dataset for dataset in request}

        ## Uncomment the following blocks to also copy all datasets that belong
        ## to the collected samples or are referenced by the collected documents.
        # datasets.update(
        #     {
        #         dataset.id: dataset
        #         for dataset in self.sourceLogs.datasets(
        #             DatasetRequestParameter(sampleIds=list(samples.keys()))
        #         )
        #     }
        # )
        #
        # datasetIds = unique(
        #     dataset.id
        #     for document in documents.values()
        #     if document.datasets
        #     for dataset in document.datasets
        # )
        # datasets.update(
        #     self._fetchByIds(
        #         self.sourceLogs.datasets, DatasetRequestParameter, datasetIds
        #     )
        # )

        print("+++ Collecting related entities +++")

        # Samples referenced by datasets but not yet collected
        sampleIds = unique(
            dataset.sample.id
            for dataset in datasets.values()
            if dataset.sample and dataset.sample.id not in samples
        )
        samples.update(
            self._fetchByIds(self.sourceLogs.samples, SampleRequestParameter, sampleIds)
        )

        # Instruments from source
        instrumentIds = unique(
            dataset.instrumentId
            for dataset in datasets.values()
            if dataset.instrumentId
        )
        instruments = self._fetchByIds(
            self.sourceLogs.instruments, InstrumentRequestParameter, instrumentIds
        )

        # Experiments from source
        experimentIds = unique(
            dataset.experiment.id for dataset in datasets.values() if dataset.experiment
        )
        experiments = self._fetchByIds(
            self.sourceLogs.experiments, ExperimentRequestParameter, experimentIds
        )

        # Methods referenced by datasets, instruments and experiments
        methodIds = unique(
            [dataset.methodId for dataset in datasets.values() if dataset.methodId]
            + [
                instrument.methodId
                for instrument in instruments.values()
                if instrument.methodId
            ]
            + [
                experiment.method.id
                for experiment in experiments.values()
                if experiment.method
            ]
        )
        methods = self._fetchByIds(
            self.sourceLogs.methods, MethodRequestParameter, methodIds
        )

        # Persons referenced by samples and datasets
        personIds = [sample.owner.id for sample in samples.values() if sample.owner]
        personIds.extend(
            person.id
            for sample in samples.values()
            if sample.preparedBy
            for person in sample.preparedBy
        )
        personIds.extend(
            dataset.owner.id for dataset in datasets.values() if dataset.owner
        )
        personIds.extend(
            operatorId
            for dataset in datasets.values()
            if dataset.operatorIds
            for operatorId in dataset.operatorIds
        )
        persons = self._fetchByIds(
            self.sourceLogs.persons, PersonRequestParameter, unique(personIds)
        )

        collection = EntityCollection()
        collection.origins = {origin.id: origin}
        collection.projects = projects
        collection.samples = samples
        collection.documents = documents
        collection.datasets = datasets
        collection.instruments = instruments
        collection.methods = methods
        collection.persons = persons
        collection.experiments = experiments

        return collection

    @classmethod
    def _removeKnownByUid(cls, logs: LOGS, entities, targetMapping):
        """Move entities whose uid already exists on ``logs`` into ``targetMapping``.

        Args:
            logs: The LOGS instance to look the uids up on (the target).
            entities: ``{sourceId: entity}`` dict. Matched entries are removed
                in place.
            targetMapping: ``{sourceId: targetEntity}`` dict. Matched entries
                are added in place.
        """
        uidToSourceId = {
            str(entity.uid): sourceId
            for sourceId, entity in entities.items()
            if entity.uid
        }
        if not uidToSourceId:
            # Nothing to look up. Do not send an empty uid filter.
            return
        for targetEntity in logs.entities(
            EntitiesRequestParameter(uids=list(uidToSourceId))
        ):
            sourceId = uidToSourceId.get(str(targetEntity.uid))
            # Ignore unexpected results and repeated uids instead of
            # raising KeyError.
            if sourceId is None or sourceId not in entities:
                continue
            knownEntity = entities.pop(sourceId)
            print(
                "%s with name %a already exists and will be skipped."
                % (type(knownEntity).__name__, knownEntity.name)
            )
            # Key by the *source* id: _mapEntities resolves references by the
            # id they carry in the source instance.
            targetMapping[sourceId] = targetEntity

    @classmethod
    def _removeKnownByName(cls, logs: LOGS, entities, targetMapping):
        """Move entities already present on ``logs`` (by uid or name) into ``targetMapping``.

        A uid match is checked first. Remaining entities are then matched by
        name, but only when the target entity has the same type (ignoring a
        "Minimal" class-name prefix).

        Args:
            logs: The LOGS instance to look the entities up on (the target).
            entities: ``{sourceId: entity}`` dict. Matched entries are removed
                in place.
            targetMapping: ``{sourceId: targetEntity}`` dict. Matched entries
                are added in place.
        """
        # Resolve uid matches once, before building the name index, so the
        # index never refers to entities that were already removed.
        cls._removeKnownByUid(logs=logs, entities=entities, targetMapping=targetMapping)

        nameToSourceId = {
            entity.name: sourceId for sourceId, entity in entities.items() if entity.name
        }
        if not nameToSourceId:
            return
        for targetEntity in logs.entities(
            EntitiesRequestParameter(names=list(nameToSourceId))
        ):
            sourceId = nameToSourceId.get(targetEntity.name)
            if sourceId is None:
                continue
            knownEntity = entities[sourceId]
            if (
                type(targetEntity).__name__.replace("Minimal", "")
                != type(knownEntity).__name__
            ):
                continue
            print(
                "%s with name %a already exists and will be skipped."
                % (type(knownEntity).__name__, knownEntity.name)
            )
            targetMapping[sourceId] = targetEntity
            del entities[sourceId]
            del nameToSourceId[targetEntity.name]

    @classmethod
    def _applyMapping(cls, collection, nameMapping):
        """Rename entities in ``collection`` in place according to ``nameMapping``.

        Supported keys: ``"instrument"`` (list of ``{"original", "mapped"}``
        dicts) and ``"projectNamePrefix"`` (string prepended to project names).

        Raises:
            Exception: If ``nameMapping`` contains none of the supported keys.
        """
        hasMapping = False
        if "instrument" in nameMapping:
            hasMapping = True
            print("+++ Mapping for instruments found +++")
            nameMap = {m["original"]: m["mapped"] for m in nameMapping["instrument"]}
            for instrument in collection.instruments.values():
                if instrument.name in nameMap:
                    print(
                        "Renaming %a to %a"
                        % (instrument.name, nameMap[instrument.name])
                    )
                    instrument.name = nameMap[instrument.name]

        if "projectNamePrefix" in nameMapping:
            hasMapping = True
            print(
                "+++ Project name prefix %a found +++"
                % nameMapping["projectNamePrefix"]
            )
            for project in collection.projects.values():
                project.name = nameMapping["projectNamePrefix"] + (project.name or "")

        if not hasMapping:
            raise Exception("Mapping does not contain any entity name map.")

    @classmethod
    def _filterKnown(cls, logs, collection, targetMapping):
        """Remove entities from ``collection`` that already exist on ``logs``.

        Each removed entity is recorded in the matching dict of
        ``targetMapping`` under its source id. Origins, samples, datasets and
        projects are matched by uid. Instruments, methods and experiments are
        matched by uid or name. Persons are matched by first and last name.
        Documents are not filtered.
        """
        cls._removeKnownByUid(logs, collection.origins, targetMapping.origins)
        cls._removeKnownByUid(logs, collection.samples, targetMapping.samples)
        cls._removeKnownByUid(logs, collection.datasets, targetMapping.datasets)
        cls._removeKnownByUid(logs, collection.projects, targetMapping.projects)

        cls._removeKnownByName(logs, collection.instruments, targetMapping.instruments)
        cls._removeKnownByName(logs, collection.methods, targetMapping.methods)
        cls._removeKnownByName(logs, collection.experiments, targetMapping.experiments)

        # documents TODO: skip documents for now

        remainingPersons = {}
        for sourceId, knownPerson in collection.persons.items():
            matches = list(
                logs.persons(
                    PersonRequestParameter(
                        firstNames=[knownPerson.firstName],
                        lastNames=[knownPerson.lastName],
                    )
                )
            )
            if matches:
                print(
                    "Person with name %a already exists and will be skipped."
                    % knownPerson.name
                )
                targetMapping.persons[sourceId] = matches[0]
            else:
                remainingPersons[sourceId] = knownPerson

        collection.persons = remainingPersons

    @classmethod
    def _mapEntities(cls, orig, mapping):
        """Translate source entity reference(s) to their target counterparts.

        Args:
            orig: None, a single entity, or a list of entities (anything with
                an ``id``).
            mapping: ``{sourceId: targetEntity}`` dict.

        Returns:
            None for None. For a single entity, its mapped entity or None if
            it is unmapped. For a list, the mapped entities, with unmapped
            ones dropped.
        """
        if orig is None:
            return None
        if isinstance(orig, list):
            return [mapping[o.id] for o in orig if o.id in mapping]
        return mapping.get(orig.id)

    @classmethod
    def _mapOwner(cls, entities, mapping):
        """Re-map the ``owner`` of every entity that has one, in place."""
        for entity in entities:
            if hasattr(entity, "owner"):
                entity.owner = cls._mapEntities(entity.owner, mapping)

    def _createDatasets(self, datasets, targetMapping, origin):
        """Download, re-map and upload ``datasets`` to the target LOGS.

        The files of each dataset are downloaded as a zip into a temporary
        directory, extracted, and uploaded as a new ``Dataset``. Upload
        failures raised as ``LOGSException`` are printed and skipped.

        Args:
            datasets: ``{sourceId: dataset}`` dict of datasets to copy.
            targetMapping: ``EntityCollection`` used to re-map references.
                Successfully uploaded datasets are added to its ``datasets``.
            origin: The origin attached to every uploaded dataset.
        """
        number = len(datasets)
        # The context manager removes the directory even when an unexpected
        # exception aborts the copy.
        with tempfile.TemporaryDirectory() as tempdir:
            for count, (sourceId, dataset) in enumerate(datasets.items(), start=1):
                self._printProgress(count, number, dataset)
                zipName = "dataset_%s" % sourceId
                zipFileName = zipName + ".zip"
                zipDir = os.path.join(tempdir, zipName)
                zipFile = os.path.join(tempdir, zipFileName)
                if os.path.exists(zipFile):
                    os.remove(zipFile)

                dataset.download(
                    directory=tempdir, fileName=zipFileName, overwrite=True
                )
                if not os.path.exists(zipDir):
                    with ZipFile(zipFile, "r") as zObject:
                        os.mkdir(zipDir)
                        zObject.extractall(path=zipDir)

                uid = dataset.uid
                dataset = Dataset(dataset, files=zipDir)

                dataset.instrument = self._mapEntities(
                    dataset.instrument, targetMapping.instruments
                )
                dataset.method = self._mapEntities(
                    dataset.method, targetMapping.methods
                )
                dataset.operators = self._mapEntities(
                    dataset.operators, targetMapping.persons
                )
                dataset.projects = self._mapEntities(
                    dataset.projects, targetMapping.projects
                )
                dataset.sample = self._mapEntities(
                    dataset.sample, targetMapping.samples
                )
                dataset.experiment = self._mapEntities(
                    dataset.experiment, targetMapping.experiments
                )
                # A dataset without project, sample or operators is uploaded
                # as unclaimed.
                if not (dataset.projects and dataset.sample and dataset.operators):
                    dataset.claimed = False

                dataset.setOrigin(uid=uid, origin=origin)

                try:
                    self.targetLogs.create(dataset)
                except LOGSException as e:
                    print(" Dataset %a upload failed: '%s'" % (dataset.toString(), e))
                else:
                    # Only record datasets that actually exist on the target.
                    targetMapping.datasets[sourceId] = dataset

                os.remove(zipFile)
                shutil.rmtree(zipDir)

    def _createEntities(self, collection, targetMapping):
        """Create all entities of ``collection`` on the target LOGS.

        Entities are created in dependency order: origins, persons, projects,
        methods, instruments, experiments, samples and datasets. Each created
        entity is recorded in ``targetMapping`` under its source id, so that
        later entities can have their references re-mapped.

        Args:
            collection: The filtered ``EntityCollection`` to create.
            targetMapping: ``EntityCollection`` mapping source ids to target
                entities. It is extended in place.
        """
        numberSum = 0

        # Origins
        number = len(collection.origins)
        numberSum += number
        for count, (sourceId, origin) in enumerate(collection.origins.items(), start=1):
            self._printProgress(count, number, origin, origin.url)
            targetMapping.origins[sourceId] = origin
            self.targetLogs.create(origin)

        # We have only one origin for now so we use the first value from the dict
        origin = list(targetMapping.origins.values())[0]

        # Persons
        number = len(collection.persons)
        numberSum += number
        for count, (sourceId, person) in enumerate(collection.persons.items(), start=1):
            if not person.salutation:
                # Presumably the target requires a salutation (TODO confirm).
                person.salutation = "Mrs."
            self._printProgress(count, number, person)
            person.setOrigin(uid=person.uid, origin=origin)
            targetMapping.persons[sourceId] = person
            self.targetLogs.create(person)

        for entities in (
            collection.projects,
            collection.samples,
            collection.documents,
            collection.datasets,
            collection.instruments,
            collection.methods,
            collection.persons,
            collection.experiments,
        ):
            self._mapOwner(entities.values(), targetMapping.persons)

        # Projects
        number = len(collection.projects)
        numberSum += number
        for count, (sourceId, project) in enumerate(collection.projects.items(), start=1):
            self._printProgress(count, number, project)
            project.setOrigin(uid=project.uid, origin=origin)
            # Keep only permissions whose person exists on the target.
            # ``None`` permissions are treated as an empty list.
            permissions = []
            for perm in project.projectPersonPermissions or []:
                if not perm.person:
                    continue
                perm.person = self._mapEntities(perm.person, targetMapping.persons)
                if perm.person:
                    permissions.append(perm)
            project.projectPersonPermissions = permissions

            targetMapping.projects[sourceId] = project
            # Make sure at least one listed person can administer the project.
            if not any(perm.administer for perm in permissions):
                for perm in permissions:
                    perm.administer = True
            # Project tags are not copied (presumably not transferable; TODO confirm).
            project.projectTags = None

            self.targetLogs.create(project)

        # Methods
        number = len(collection.methods)
        numberSum += number
        for count, (sourceId, method) in enumerate(collection.methods.items(), start=1):
            self._printProgress(count, number, method)
            method.setOrigin(uid=method.uid, origin=origin)
            targetMapping.methods[sourceId] = method
            self.targetLogs.create(method)

        # Instruments
        number = len(collection.instruments)
        numberSum += number
        for count, (sourceId, instrument) in enumerate(
            collection.instruments.items(), start=1
        ):
            self._printProgress(count, number, instrument)
            instrument.method = self._mapEntities(
                instrument.method, targetMapping.methods
            )
            instrument.setOrigin(uid=instrument.uid, origin=origin)
            targetMapping.instruments[sourceId] = instrument
            self.targetLogs.create(instrument)

        # Experiments
        number = len(collection.experiments)
        numberSum += number
        for count, (sourceId, experiment) in enumerate(
            collection.experiments.items(), start=1
        ):
            self._printProgress(count, number, experiment)
            experiment.method = self._mapEntities(
                experiment.method, targetMapping.methods
            )
            if not experiment.name:
                self.unknownCount += 1
                experiment.name = "Unknown %d" % self.unknownCount
            experiment.setOrigin(uid=experiment.uid, origin=origin)
            targetMapping.experiments[sourceId] = experiment
            self.targetLogs.create(experiment)

        # Samples
        number = len(collection.samples)
        numberSum += number
        for count, (sourceId, sample) in enumerate(collection.samples.items(), start=1):
            self._printProgress(count, number, sample)
            sample.preparedBy = self._mapEntities(
                sample.preparedBy, targetMapping.persons
            )
            sample.projects = self._mapEntities(sample.projects, targetMapping.projects)
            if not sample.projects:
                # Fall back to the first copied project so the sample is not orphaned.
                sample.projects = list(targetMapping.projects.values())[:1]
            sample.customFields = None
            if sample.tags:
                # Tags are preserved as text in the notes.
                tagText = "\n".join(
                    "%s: %s" % (k, v.replace(":", "_")) for k, v in sample.tags.items()
                )
                prefix = sample.notes + "\n" if sample.notes else ""
                sample.notes = prefix + "Tags from origin:\n" + tagText

            sample.setOrigin(uid=sample.uid, origin=origin)
            targetMapping.samples[sourceId] = sample
            self.targetLogs.create(sample)

        # Datasets
        numberSum += len(collection.datasets)
        self._createDatasets(collection.datasets, targetMapping, origin)

        print("+++ Created %d entities on target +++" % numberSum)

    def copy(
        self, projectNames=None, projectTags=None, nameMapping=None, projectNamePrefix=None
    ):
        """Copy the selected projects and their related entities to the target LOGS.

        Args:
            projectNames: Names of the source projects to copy.
            projectTags: Tags that select additional source projects to copy.
            nameMapping: Optional renaming rules. Supported keys are
                ``"instrument"`` (list of ``{"original": ..., "mapped": ...}``
                dicts) and ``"projectNamePrefix"``. The caller's dict is not
                modified.
            projectNamePrefix: Optional prefix prepended to every copied
                project name. It overrides ``nameMapping["projectNamePrefix"]``.
        """
        # Work on a copy so the caller's dict is not mutated.
        nameMapping = dict(nameMapping) if nameMapping else {}
        if projectNamePrefix:
            nameMapping["projectNamePrefix"] = projectNamePrefix
        if not nameMapping:
            print("+++ No name mapping. +++")

        collection = self.collectEntitiesFromProjects(
            projectNames=projectNames, projectTags=projectTags
        )

        collection.printSummary()

        if nameMapping:
            self._applyMapping(collection, nameMapping)
        targetMapping = EntityCollection()
        self._filterKnown(self.targetLogs, collection, targetMapping)

        print()
        print(
            "+++ Final number of entities to copy (after removing duplicates and known). +++"
        )
        collection.printSummary()
        print()

        self._createEntities(collection, targetMapping)


if __name__ == "__main__":
    # Keys and URLs can be supplied via environment variables, so real API
    # keys do not have to be hard-coded in this file. The placeholder values
    # below are used when a variable is not set.
    source_key = os.environ.get(
        "LOGS_SOURCE_KEY",
        "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
    )
    source_url = os.environ.get(
        "LOGS_SOURCE_URL", "https://source.logs.com/sourceGroup"
    )

    target_key = os.environ.get(
        "LOGS_TARGET_KEY",
        "YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY",
    )
    target_url = os.environ.get(
        "LOGS_TARGET_URL", "https://target.logs.com/targetGroup"
    )

    # create two logs instances (source and target LOGS)
    copier = CopyProject(
        LOGS(source_url, source_key, verbose=False),
        LOGS(target_url, target_key, verbose=False),
    )

    # You can also introduce some name mapping for the instrument name
    # nameMapping = {
    #     "instrument": [
    #         {"original": "OriginalInstrumentName", "mapped": "TargetInstrumentName"}
    #     ]
    # }
    # copier.copy(
    #     projectNames=["Project to Copy 1", "Project to Copy 2"], nameMapping=nameMapping
    # )

    copier.copy(projectNames=["Project to Copy 1", "Project to Copy 2"])