Bases: BasePipelineWriter
Pipeline file writer class. Corresponds to the "DVC" framework.
Source code in lineapy/plugins/dvc_pipeline_writer.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153 | class DVCPipelineWriter(BasePipelineWriter):
"""
Class for pipeline file writer. Corresponds to "DVC" framework.
"""
@property
def docker_template_name(self) -> str:
    """Name of the Jinja template file used for the DVC Dockerfile."""
    template_file = "dvc_dockerfile.jinja"
    return template_file
def _write_dag(self) -> None:
    """
    Generate the DVC DAG text for the configured flavor and write it to
    `dvc.yaml` in the output directory.

    The flavor name is read from the "dag_flavor" key of `self.dag_config`
    and defaults to "SingleStageAllSessions".

    Raises `ValueError` when the flavor name is not a member of
    `DVCDagFlavor`, or when it is a member that this writer has no
    generator for.
    """
    dag_flavor = self.dag_config.get(
        "dag_flavor", "SingleStageAllSessions"
    )

    # Check if the given DAG flavor is a supported/valid one
    if dag_flavor not in DVCDagFlavor.__members__:
        raise ValueError(f'"{dag_flavor}" is an invalid dvc dag flavor.')

    # Construct DAG text for the given flavor. The branches are mutually
    # exclusive, and the final `else` guarantees `dvc_yaml_code` is never
    # left unbound for an enum member that has no generator here.
    flavor = DVCDagFlavor[dag_flavor]
    if flavor == DVCDagFlavor.SingleStageAllSessions:
        dvc_yaml_code = self._write_operator_run_all_sessions()
    elif flavor == DVCDagFlavor.StagePerArtifact:
        dvc_yaml_code = self._write_operator_run_per_artifact()
    else:
        raise ValueError(
            f'"{dag_flavor}" dvc dag flavor is not implemented.'
        )

    # Write out file
    dvc_dag_file = self.output_dir / "dvc.yaml"
    dvc_dag_file.write_text(dvc_yaml_code)
    logger.info(f"Generated DAG file: {dvc_dag_file}")
def _write_operator_run_all_sessions(self) -> str:
    """
    Build the DVC DAG text for the `SingleStageAllSessions` flavor.

    Renders the "dvc_dag_SingleStageAllSessions.jinja" plugin template with
    one command that runs `<pipeline_name>_module.py`, and returns the
    rendered text.
    """
    template = load_plugin_template("dvc_dag_SingleStageAllSessions.jinja")
    module_command = f"python {self.pipeline_name}_module.py"
    return template.render(MODULE_COMMAND=module_command)
def _write_operator_run_per_artifact(self) -> str:
    """
    Build the DVC DAG text for the `StagePerArtifact` flavor.

    One stage is created per entry of the task graph. Besides returning the
    rendered DAG text, this also writes the params file and one Python
    command file per stage via the sibling `_write_*` helpers.
    """
    template = load_plugin_template("dvc_dag_StagePerArtifact.jinja")

    task_definitions, _ = get_task_graph(
        self.artifact_collection,
        pipeline_name=self.pipeline_name,
        task_breakdown=DagTaskBreakdown.TaskPerArtifact,
    )

    # Map each pipeline input parameter name to its value
    pipeline_args = super().get_pipeline_args()
    parameter_values: Dict[str, Any] = {
        name: spec.value for name, spec in pipeline_args.items()
    }

    # Describe every task as a plain dict the templates can consume
    stages = []
    for task_name, task_def in task_definitions.items():
        stage_inputs = {
            variable: parameter_values[variable]
            for variable in task_def.user_input_variables
        }
        stage = Stage(
            name=task_name,
            deps=task_def.loaded_input_variables,
            outs=task_def.return_vars,
            call_block=task_def.call_block,
            user_input_variables=stage_inputs,
        )
        stages.append(stage.dict())

    rendered_dag = template.render(
        MODULE_NAME=f"{self.pipeline_name}_module", STAGES=stages
    )

    # Emit the companion files: shared params first, then one file per stage
    self._write_params(stages)
    for stage_dict in stages:
        self._write_python_operator_per_run_artifact(stage_dict)

    return rendered_dag
def _write_params(self, stages: List[dict]):
    """
    Render the "dvc_dag_params.jinja" template for the given stages and
    write the result to `params.yaml` in the output directory.
    """
    template = load_plugin_template("dvc_dag_params.jinja")
    rendered_params = template.render(STAGES=stages)

    target = self.output_dir / "params.yaml"
    target.write_text(rendered_params)
    logger.info(f"Generated DAG file: {target}")
def _write_python_operator_per_run_artifact(self, stage: dict):
    """
    Write the Python command file for a single DVC stage.

    The "dvc_dag_PythonOperator.jinja" template is rendered with the
    pipeline module name and the stage dict, and saved in the output
    directory as `task_<stage name>.py`.
    """
    template = load_plugin_template("dvc_dag_PythonOperator.jinja")
    module_name = f"{self.pipeline_name}_module"
    rendered_task = template.render(MODULE_NAME=module_name, STAGE=stage)

    target = self.output_dir / f"task_{stage['name']}.py"
    target.write_text(rendered_task)
    logger.info(f"Generated DAG file: {target}")
|