Skip to content

dvc_pipeline_writer

DVCPipelineWriter

Bases: BasePipelineWriter

Class for pipeline file writer. Corresponds to "DVC" framework.

Source code in lineapy/plugins/dvc_pipeline_writer.py
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
class DVCPipelineWriter(BasePipelineWriter):
    """
    Class for pipeline file writer. Corresponds to "DVC" framework.
    """

    @property
    def docker_template_name(self) -> str:
        """Name of the Jinja template used for the DVC Dockerfile."""
        return "dvc_dockerfile.jinja"

    def _write_dag(self) -> None:
        """
        Render the DVC DAG for the configured flavor and write it to
        ``dvc.yaml`` inside ``self.output_dir``.

        The flavor is read from ``self.dag_config["dag_flavor"]`` and
        defaults to ``"SingleStageAllSessions"``.

        Raises
        ------
        ValueError
            If the configured flavor is not a member of ``DVCDagFlavor``.
        NotImplementedError
            If the flavor is a valid ``DVCDagFlavor`` member but this writer
            has no code generator for it.
        """
        dag_flavor = self.dag_config.get(
            "dag_flavor", "SingleStageAllSessions"
        )

        # Check if the given DAG flavor is a supported/valid one
        if dag_flavor not in DVCDagFlavor.__members__:
            raise ValueError(f'"{dag_flavor}" is an invalid dvc dag flavor.')

        # Construct DAG text for the given flavor. The branches are mutually
        # exclusive, and the final ``else`` guarantees ``dvc_yaml_code`` is
        # always bound before it is written out.
        flavor = DVCDagFlavor[dag_flavor]
        if flavor == DVCDagFlavor.SingleStageAllSessions:
            dvc_yaml_code = self._write_operator_run_all_sessions()
        elif flavor == DVCDagFlavor.StagePerArtifact:
            dvc_yaml_code = self._write_operator_run_per_artifact()
        else:
            raise NotImplementedError(
                f'dvc dag flavor "{dag_flavor}" has no code generator.'
            )

        # Write out file
        dvc_dag_file = self.output_dir / "dvc.yaml"
        dvc_dag_file.write_text(dvc_yaml_code)
        logger.info(f"Generated DAG file: {dvc_dag_file}")

    def _write_operator_run_all_sessions(self) -> str:
        """
        This hidden method implements DVC DAG code generation corresponding
        to the `SingleStageAllSessions` flavor. This DAG only has one stage,
        whose command runs the generated ``<pipeline_name>_module.py`` file
        with ``python``.

        Returns the rendered DAG text; nothing is written to disk here.
        """

        DAG_TEMPLATE = load_plugin_template(
            "dvc_dag_SingleStageAllSessions.jinja"
        )

        return DAG_TEMPLATE.render(
            MODULE_COMMAND=f"python {self.pipeline_name}_module.py",
        )

    def _write_operator_run_per_artifact(self) -> str:
        """
        This hidden method implements DVC DAG code generation corresponding
        to the `StagePerArtifact` flavor.

        Besides returning the rendered DAG text, this method also writes
        ``params.yaml`` and one ``task_<stage name>.py`` file per stage into
        ``self.output_dir``.
        """

        DAG_TEMPLATE = load_plugin_template("dvc_dag_StagePerArtifact.jinja")

        task_defs, _ = get_task_graph(
            self.artifact_collection,
            pipeline_name=self.pipeline_name,
            task_breakdown=DagTaskBreakdown.TaskPerArtifact,
        )

        # Get DAG parameters for the DVC pipeline: parameter name -> value
        input_parameters_dict: Dict[str, Any] = {
            parameter_name: input_spec.value
            for parameter_name, input_spec in super()
            .get_pipeline_args()
            .items()
        }

        # One Stage per task definition, converted with ``.dict()`` so the
        # templates receive plain mappings. Distinct names are used for the
        # task name and the per-task variable names to avoid shadowing.
        stages = [
            Stage(
                name=task_name,
                deps=task_def.loaded_input_variables,
                outs=task_def.return_vars,
                call_block=task_def.call_block,
                user_input_variables={
                    var_name: input_parameters_dict[var_name]
                    for var_name in task_def.user_input_variables
                },
            ).dict()
            for task_name, task_def in task_defs.items()
        ]

        full_code = DAG_TEMPLATE.render(
            MODULE_NAME=f"{self.pipeline_name}_module", STAGES=stages
        )

        self._write_params(stages)

        for stage in stages:
            self._write_python_operator_per_run_artifact(stage)

        return full_code

    def _write_params(self, stages: List[dict]) -> None:
        """
        This hidden method renders the ``dvc_dag_params.jinja`` template with
        the given stages and writes the result to ``params.yaml`` in
        ``self.output_dir``.
        """
        PARAMS_TEMPLATE = load_plugin_template("dvc_dag_params.jinja")

        params_code = PARAMS_TEMPLATE.render(STAGES=stages)
        filename = "params.yaml"
        params_file = self.output_dir / filename
        params_file.write_text(params_code)
        logger.info(f"Generated DAG file: {params_file}")

    def _write_python_operator_per_run_artifact(self, stage: dict) -> None:
        """
        This hidden method generates the python cmd file for one DVC stage.

        The file is written to ``self.output_dir`` as
        ``task_<stage["name"]>.py``.
        """
        TASK_TEMPLATE = load_plugin_template("dvc_dag_PythonOperator.jinja")

        python_operator_code = TASK_TEMPLATE.render(
            MODULE_NAME=f"{self.pipeline_name}_module", STAGE=stage
        )
        filename = f"task_{stage['name']}.py"
        python_operator_file = self.output_dir / filename
        python_operator_file.write_text(python_operator_code)
        logger.info(f"Generated DAG file: {python_operator_file}")

Was this helpful?

Help us improve docs with your feedback!