39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
@dataclass
class SessionArtifacts:
    """
    Refactor a given session graph for use in a downstream task (e.g., pipeline building).

    On construction the session graph is loaded from the database, reduced to
    the union of the target artifacts' subgraphs, annotated with per-node
    variable information (``node_context``), divided into non-overlapping
    node collections and finally arranged into a dependency graph of those
    collections.

    The class-level annotations below declare attributes that are populated
    by ``__init__`` and the private helpers it calls.
    """

    # Session shared by the target artifacts (taken from the first one).
    _session_id: LineaID
    # Session graph restricted to the nodes of the target artifacts' subgraphs.
    graph: Graph
    # Graph of the whole session, as built by Graph.create_session_graph.
    session_graph: Graph
    # Database handle used for artifact, variable, node and library lookups.
    db: RelationalLineaDB
    # Node collections for user code, filled by _slice_session_artifacts.
    usercode_nodecollections: List[UserCodeNodeCollection]
    # Collection holding every node classified as import-related.
    import_nodecollection: ImportNodeCollection
    # Input parameter name -> id of the node selected as its assignment.
    input_parameters_node: Dict[str, LineaID]
    # Node id -> NodeInfo, filled in topological order by _update_node_context.
    node_context: Dict[LineaID, NodeInfo]
    # Artifacts requested by the caller.
    target_artifacts: List[LineaArtifact]
    # Artifact name -> artifact whose collection is flagged as pre-computed.
    reuse_pre_computed_artifacts: Dict[str, LineaArtifact]
    # Node id -> artifact (targeted or reused) found in this session.
    all_session_artifacts: Dict[LineaID, LineaArtifact]
    # Variable names declared as input parameters by the caller.
    input_parameters: List[str]
    # Dependency graph between the user-code node collections.
    nodecollection_dependencies: TaskGraph

    def __init__(
        self,
        db: RelationalLineaDB,
        target_artifacts: List[LineaArtifact],
        input_parameters: Optional[List[str]] = None,
        reuse_pre_computed_artifacts: Optional[List[LineaArtifact]] = None,
    ) -> None:
        """
        Load the session of ``target_artifacts`` and slice it into collections.

        Args:
            db: Database the session graph and its metadata are read from.
            target_artifacts: Artifacts to compute; the session id is taken
                from the first element.
            input_parameters: Names of variables to treat as input
                parameters. Defaults to no parameters.
            reuse_pre_computed_artifacts: Artifacts whose collections are
                marked as pre-computed. Defaults to none.

        Raises:
            IndexError: If ``target_artifacts`` is empty.
            ValueError: Propagated from the helper methods when the reuse
                list or the input parameters are inconsistent with the
                session.
        """
        self.db = db
        self.target_artifacts = target_artifacts
        self.target_artifacts_name = [
            art.name for art in self.target_artifacts
        ]
        self._session_id = self.target_artifacts[0]._session_id
        self.session_graph = Graph.create_session_graph(
            self.db, self._session_id
        )

        def _get_subgraph_from_node_list(
            session_graph: Graph, node_list: List[LineaID]
        ) -> Graph:
            """
            Return the subgraph as LineaPy Graph from list of node id
            """
            nodes: List[Node] = []
            for node_id in node_list:
                node = session_graph.get_node(node_id)
                if node is not None:
                    nodes.append(node)
            # Slice the graph that was passed in rather than reaching back
            # to ``self`` so the helper honours its own parameter.
            return session_graph.get_subgraph(nodes)

        # Only interested union of sliced graph of each artifacts
        target_node_ids = set.union(
            *[
                set(art._get_subgraph().nx_graph.nodes)
                for art in target_artifacts
            ]
        )
        self.graph = _get_subgraph_from_node_list(
            self.session_graph, list(target_node_ids)
        )

        self.usercode_nodecollections = []
        self.node_context = OrderedDict()
        # Copy so the instance never aliases a caller-owned (or shared
        # default) list.
        self.input_parameters = (
            [] if input_parameters is None else list(input_parameters)
        )
        self.reuse_pre_computed_artifacts = {
            art.name: art for art in (reuse_pre_computed_artifacts or [])
        }

        # Retrive all artifacts within the subgraph of target artifacts
        self._retrive_all_session_artifacts()
        # Add extra attributes(from predecessors) at session Linea nodes
        self._update_node_context()
        # Divide session graph into a set of non-overlapping NodeCollection
        self._slice_session_artifacts()
        # Determine dependencies of NodeCollections and check whether to
        # replace it with pre-calculated value from artifact store
        self._update_nodecollection_dependencies()

    @property
    def session_id(self) -> LineaID:
        """Id of the session this object was built from."""
        return self._session_id

    def _update_dependent_variables(
        self, nodeinfo: NodeInfo, variable_dict: Dict[LineaID, Set[str]]
    ) -> None:
        """
        Fill ``nodeinfo.dependent_variables`` from the node's predecessors.

        For every predecessor the variables assigned at that node
        (``variable_dict``) are used when present; otherwise the
        predecessor's tracked variables are used. Predecessors must already
        be present in ``self.node_context``.
        """
        for prev_node_id in nodeinfo.predecessors:
            prev_nodeinfo = self.node_context[prev_node_id]
            dep = variable_dict.get(
                prev_node_id, prev_nodeinfo.tracked_variables
            )
            # ``union`` builds a new set, so sets shared with other
            # NodeInfo objects are never mutated.
            nodeinfo.dependent_variables = nodeinfo.dependent_variables.union(
                dep
            )

    def _update_tracked_variables(
        self, nodeinfo: NodeInfo, node_id: LineaID
    ) -> None:
        """
        Determine the tracked variables of the node ``node_id``.

        Resolution order:

        1. variables assigned at the node itself;
        2. for mutate/global nodes, the tracked variables of ``call_id``
           when that node is a predecessor (otherwise left untouched);
        3. for call nodes, the tracked variables of the first global read,
           then of the called function (only if it has dependent
           variables), then of the first positional argument, each only
           when that node is a predecessor;
        4. an empty set in every other case.
        """
        # Fix me when you see return variables in refactor behaves strange
        node = self.graph.get_node(node_id=node_id)
        predecessors = nodeinfo.predecessors

        if len(nodeinfo.assigned_variables) > 0:
            nodeinfo.tracked_variables = nodeinfo.assigned_variables
        elif isinstance(node, (MutateNode, GlobalNode)):
            predecessor_call_id = node.call_id
            if predecessor_call_id in predecessors:
                nodeinfo.tracked_variables = self.node_context[
                    predecessor_call_id
                ].tracked_variables
        elif isinstance(node, CallNode):
            if (
                len(node.global_reads) > 0
                and next(iter(node.global_reads.values())) in predecessors
            ):
                first_global_read = next(iter(node.global_reads.values()))
                nodeinfo.tracked_variables = self.node_context[
                    first_global_read
                ].tracked_variables
            elif (
                node.function_id in predecessors
                and len(
                    self.node_context[node.function_id].dependent_variables
                )
                > 0
            ):
                nodeinfo.tracked_variables = self.node_context[
                    node.function_id
                ].tracked_variables
            elif (
                len(node.positional_args) > 0
                and node.positional_args[0].id in predecessors
            ):
                nodeinfo.tracked_variables = self.node_context[
                    node.positional_args[0].id
                ].tracked_variables
            else:
                nodeinfo.tracked_variables = set()
        else:
            nodeinfo.tracked_variables = set()

    def _retrive_all_session_artifacts(self):
        """
        Retrive all artifacts(targeted, reused) within the session.

        Note that, the version of reused artifacts within the session does not
        need to be the same as in `reuse_pre_computed_artifacts`; thus, we need
        to retrive the correct version in the session with the correct node id
        to enable graph manipulation.

        Raises:
            ValueError: If a name listed in ``reuse_pre_computed_artifacts``
                is attached to more than one node of the session.
        """
        # Map node id to artifact assignment within the session
        self.all_session_artifacts = {}
        for artifact in self.db.get_artifacts_for_session(self._session_id):
            if (
                artifact.name in self.target_artifacts_name
                or artifact.name in self.reuse_pre_computed_artifacts
            ):
                assert isinstance(artifact.name, str)
                self.all_session_artifacts[
                    artifact.node_id
                ] = LineaArtifact.get_artifact_from_name_and_version(
                    self.db, artifact.name, artifact.version
                )

        # Check only one artifact in the session with the name within
        # reuse_pre_computed_artifacts
        session_artifacts_name_count = Counter(
            art.name for art in self.all_session_artifacts.values()
        )
        for art_name in self.reuse_pre_computed_artifacts:
            if session_artifacts_name_count[art_name] > 1:
                raise ValueError(
                    f"More than one artifacts with the same name {art_name} in the session."
                    + "Please remove it from reuse_pre_computed_artifacts."
                )

    def _update_node_context(self):
        """
        Traverse every node within the session in topologically sorted order
        and update node_context with following information.

        - assigned_variables : variables assigned at this node
        - assigned_artifact : this node is pointing to some artifact
        - predecessors : predecessors of the node
        - dependent_variables : union of if any variable is assigned at
          predecessor node, use the assigned variables; otherwise, use the
          dependent_variables
        - tracked_variables : variables that this node is point to
        - module_import : module name/alias that this node is point to

        Note that, it is possible to add all these new attributes during the
        Linea graph creating phase. However, this might sacrifice the runtime
        performance since some of the information need to query the attributes
        from predecessors.

        Also selects, for every declared input parameter, the node that
        assigns it (``self.input_parameters_node``).

        Raises:
            ValueError: If an input parameter has more than one literal
                assignment, or if the selected assignment depends on other
                variables.
        """
        # Map each variable node ID to the corresponding variable name(when
        # variable assigned). Need to treat import different from regular
        # variable assignment
        variable_dict: Dict[LineaID, Set[str]] = OrderedDict()
        import_dict: Dict[LineaID, Set[str]] = OrderedDict()
        self.input_parameters_node = dict()
        input_parameters_assignment_nodes: Dict[str, List[LineaID]] = dict()
        # Built once instead of once per (node, variable) row.
        input_parameter_names = set(self.input_parameters)

        for node_id, variable_name in self.db.get_variables_for_session(
            self._session_id
        ):
            if is_import_node(self.graph, node_id):
                import_dict.setdefault(node_id, set()).add(variable_name)
            else:
                variable_dict.setdefault(node_id, set()).add(variable_name)
                if variable_name in input_parameter_names:
                    input_parameters_assignment_nodes.setdefault(
                        variable_name, []
                    ).append(node_id)

        # Identify variable dependencies of each node in topological order
        for node_id in nx.topological_sort(self.graph.nx_graph):
            nodeinfo = NodeInfo(
                assigned_variables=variable_dict.get(node_id, set()),
                assigned_artifact=self.all_session_artifacts[node_id].name
                if node_id in self.all_session_artifacts
                else None,
                dependent_variables=set(),
                predecessors=set(self.graph.nx_graph.predecessors(node_id)),
                tracked_variables=set(),
                module_import=import_dict.get(node_id, set()),
            )
            self._update_dependent_variables(nodeinfo, variable_dict)
            self._update_tracked_variables(nodeinfo, node_id)
            self.node_context[node_id] = nodeinfo

        # If a variable is declared as an input parameters, we only support
        # the literal assignment only happen once in the entire session at this
        # moment. If there is a way to specify which literal assignment to use
        # as an input parameter. We can relax this restriction.
        # We allow multiple assignments to non-literals to handle common cases
        # like the following:
        #   x = 1
        #   x = x + 1
        #   input_parameters = [x]
        # In this case, the original definition of x = 1 will be parametrized.
        for var, node_ids in input_parameters_assignment_nodes.items():
            for node_id in node_ids:
                # Assignments outside the sliced graph are irrelevant.
                if node_id not in self.node_context:
                    continue
                is_literal_assignment = (
                    len(self.node_context[node_id].dependent_variables) == 0
                )
                previous_assignment = self.input_parameters_node.get(
                    var, None
                )
                if previous_assignment is None:
                    self.input_parameters_node[var] = node_id
                    continue
                previous_assignment_is_literal = (
                    len(
                        self.node_context[
                            previous_assignment
                        ].dependent_variables
                    )
                    == 0
                )
                if is_literal_assignment and previous_assignment_is_literal:
                    raise ValueError(
                        f"Variable {var}, is defined more than once"
                    )
                elif not previous_assignment_is_literal:
                    # previous assignment is not literal, so we can
                    # reassign it
                    self.input_parameters_node[var] = node_id
                # else previous assignment is literal and we should not
                # override it

        for var, node_id in self.input_parameters_node.items():
            dependent_variables = self.node_context[
                node_id
            ].dependent_variables
            if len(dependent_variables) > 0:
                dep_vars = ", ".join(sorted(dependent_variables))
                raise ValueError(
                    f"LineaPy only supports input parameters without dependent variables for now. "
                    f"{var} has dependent variables: {dep_vars}."
                )

    def _get_sliced_nodes(
        self, node_id: LineaID
    ) -> Tuple[Set[LineaID], Set[LineaID]]:
        """
        Get sliced nodes from session graph and separate nodes for import
        and main calculation.

        Returns:
            ``(nodes, importnodes)`` where ``nodes`` is every node of the
            slice taken at ``node_id`` and ``importnodes`` is the subset of
            ``nodes`` that carries a module import or is an ancestor of
            such a node.
        """
        nodes = set(get_slice_graph(self.graph, [node_id]).nx_graph.nodes)
        # Identify import nodes
        importnodes = {
            _node
            for _node in nodes
            if len(self.node_context[_node].module_import) > 0
        }
        # Ancestors of import node should be also for import
        importnodes = importnodes.union(
            *[
                self.graph.get_ancestors(import_node_id)
                for import_node_id in importnodes
            ]
        )
        # Ancestors might not in nodes
        importnodes = importnodes.intersection(nodes)
        return nodes, importnodes

    def _get_predecessor_info(self, nodes, importnodes):
        """
        Figure out where each predecessor is coming from which artifact in
        both artifact name and nodeid

        Args:
            nodes: Set of node ids forming one collection.
            importnodes: Set of node ids already classified as imports.

        Returns:
            ``(predecessor_nodes, predecessor_artifact)``: the predecessors
            of ``nodes`` that are neither in ``nodes`` nor in
            ``importnodes``, and the set of ``artifact_name`` values
            recorded for those predecessors.
        """
        predecessor_nodes = set().union(
            *[self.node_context[n_id].predecessors for n_id in nodes]
        )
        predecessor_nodes = predecessor_nodes - nodes - importnodes
        predecessor_artifact = {
            self.node_context[n_id].artifact_name
            for n_id in predecessor_nodes
        }
        return predecessor_nodes, predecessor_artifact

    def _get_common_variables(
        self, curr_nc: UserCodeNodeCollection, pred_nc: UserCodeNodeCollection
    ) -> Tuple[List[str], Set[LineaID]]:
        """
        Identify common variables for two NodeCollections.

        Returns:
            ``(variables, nodes)``: the sorted names of variables that
            ``pred_nc`` holds without returning them and that ``curr_nc``
            reads from ``pred_nc``, and the nodes of the slice of
            ``pred_nc``'s graph segment taken at the predecessor nodes that
            ``curr_nc`` shares with ``pred_nc`` (empty when there are no
            such variables).
        """
        assert isinstance(pred_nc.name, str)
        common_inner_variables = (
            pred_nc.all_variables - set(pred_nc.return_variables)
        ).intersection(
            curr_nc.get_input_variable_sources(self.node_context)[pred_nc.name]
        )

        common_nodes: Set[LineaID] = set()
        if len(common_inner_variables) > 0:
            # Predecessors living in pred_nc that are not the artifact node
            # of the collection they were assigned to.
            slice_variable_nodes = [
                n
                for n in curr_nc.predecessor_nodes
                if n in pred_nc.node_list
                and self.node_context[n].assigned_artifact
                != self.node_context[n].artifact_name
            ]
            pred_graph_segment = self.graph.get_subgraph_from_id(
                list(pred_nc.node_list)
            )
            assert pred_graph_segment is not None
            source_art_slice_variable_graph = get_slice_graph(
                pred_graph_segment,
                slice_variable_nodes,
            )
            common_nodes = set(source_art_slice_variable_graph.nx_graph.nodes)

        # Want the return variales in consistent ordering
        return sorted(common_inner_variables), common_nodes

    def _slice_session_artifacts(self) -> None:
        """
        Divide all session nodes into a set of non-overlapping nodes. Each set
        responds to one artifact or common variables calculation. All the import
        nodes will belong to one set and all input variable nodes will belong to
        another set.

        Populates ``used_nodes``, ``import_nodes``,
        ``usercode_nodecollections``, ``import_nodecollection`` and
        ``input_parameters_nodecollection`` and labels every processed node
        in ``node_context`` with the collection it belongs to.
        """
        self.used_nodes: Set[LineaID] = set()  # Track nodes that get ever used
        self.import_nodes: Set[LineaID] = set()
        self.usercode_nodecollections = list()

        for node_id, n in self.node_context.items():
            # Only nodes that carry a known session artifact start a slice.
            if (
                n.assigned_artifact is None
                or node_id not in self.all_session_artifacts
            ):
                continue

            # Identify nodes to calculate the artifact from sliced graphs
            sliced_nodes, sliced_import_nodes = self._get_sliced_nodes(
                node_id
            )
            # Attach import nodes to import NodeCollection
            self.import_nodes.update(sliced_import_nodes - self.used_nodes)
            # New nodes to calculate this specifict artifact
            art_nodes = sliced_nodes - self.used_nodes - self.import_nodes
            # Update used nodes
            self.used_nodes = self.used_nodes | art_nodes | self.import_nodes
            # Identify precedent artifacts(id and name)
            pred_nodes, _ = self._get_predecessor_info(
                art_nodes, self.import_nodes
            )

            # Check whether this artifact should be replaced by
            # pre-computed value, None if no and (name, version) if yes
            if n.assigned_artifact in self.reuse_pre_computed_artifacts:
                reuse_def = get_lineaartifactdef(
                    (
                        n.assigned_artifact,
                        self.reuse_pre_computed_artifacts[
                            n.assigned_artifact
                        ].version,
                    )
                )
                nodecollectioninfo: ArtifactNodeCollection = (
                    ArtifactNodeCollection(
                        name=n.assigned_artifact,
                        node_list=art_nodes,
                        tracked_variables=n.tracked_variables,
                        return_variables=list(n.tracked_variables),
                        predecessor_nodes=pred_nodes,
                        is_pre_computed=True,
                        pre_computed_artifact=reuse_def,
                    )
                )
            else:
                nodecollectioninfo = ArtifactNodeCollection(
                    name=n.assigned_artifact,
                    node_list=art_nodes,
                    tracked_variables=n.tracked_variables,
                    return_variables=list(n.tracked_variables),
                    predecessor_nodes=pred_nodes,
                    is_pre_computed=False,
                )

            # Update node context to label the node is assigned to this
            # artifact
            for n_id in nodecollectioninfo.node_list:
                self.node_context[n_id].artifact_name = n.assigned_artifact
            for n_id in self.import_nodes:
                self.node_context[n_id].artifact_name = "module_import"

            # Check whether we need to breakdown existing artifact node
            # collection. If the precedent node in precedent collection
            # is not the artifact itself, this means we should split the
            # existing collection.
            input_sources = nodecollectioninfo.get_input_variable_sources(
                self.node_context
            )
            for source_artifact_name, _ in input_sources.items():
                # Position and object of the collection that provides the
                # inputs; raises IndexError when no collection has that name.
                source_id, source_info = [
                    (i, context)
                    for i, context in enumerate(self.usercode_nodecollections)
                    if context.name == source_artifact_name
                ][0]

                # Common variables between two artifacts
                (
                    common_inner_variables,
                    common_nodes,
                ) = self._get_common_variables(nodecollectioninfo, source_info)

                # If common inner variables detected, split the precedent
                # NodeCollection into two parts. One for calculation of
                # common variables and the other for rest of artifact
                # calculation.
                if not (common_inner_variables and common_nodes):
                    continue

                common_nodecollectioninfo = UserCodeNodeCollection(
                    name=f"{'_'.join(common_inner_variables)}_for_artifact_{source_info.name}_and_downstream",
                    node_list=common_nodes,
                    return_variables=common_inner_variables,
                )
                common_nodecollectioninfo.update_variable_info(
                    self.node_context, self.input_parameters_node
                )

                remaining_nodes = source_info.node_list - common_nodes
                if isinstance(source_info, ArtifactNodeCollection):
                    remaining_nodecollectioninfo: UserCodeNodeCollection = ArtifactNodeCollection(
                        name=source_info.name,
                        node_list=remaining_nodes,
                        return_variables=source_info.return_variables,
                        is_pre_computed=source_info.is_pre_computed,
                        pre_computed_artifact=source_info.pre_computed_artifact,
                    )
                else:
                    # outputs a common variable, use base
                    # UserCodeNodeCollection instead.
                    remaining_nodecollectioninfo = UserCodeNodeCollection(
                        name=source_info.name,
                        node_list=remaining_nodes,
                        return_variables=source_info.return_variables,
                    )
                remaining_nodecollectioninfo.update_variable_info(
                    self.node_context, self.input_parameters_node
                )

                # Replace the source collection, in place, by its two parts.
                self.usercode_nodecollections = (
                    self.usercode_nodecollections[:source_id]
                    + [
                        common_nodecollectioninfo,
                        remaining_nodecollectioninfo,
                    ]
                    + self.usercode_nodecollections[(source_id + 1) :]
                )

            # Remove input parameter node
            nodecollectioninfo.update_variable_info(
                self.node_context, self.input_parameters_node
            )
            nodecollectioninfo.node_list = nodecollectioninfo.node_list - set(
                self.input_parameters_node.values()
            )
            self.usercode_nodecollections.append(nodecollectioninfo)

        # NodeCollection for import
        self.import_nodecollection = ImportNodeCollection(
            name="", node_list=self.import_nodes
        )
        # NodeCollection for input parameters
        self.input_parameters_nodecollection = InputVarNodeCollection(
            name="",
            node_list=set(self.input_parameters_node.values()),
        )

    def _update_nodecollection_dependencies(self):
        """
        Identify the dependencies graph of each NodeCollections that compute
        an artifact or common variables for multiple artifacts. Remove useless
        NodeCollections that only use to calculate artifacts in
        reuse_pre_computed_artifacts list.
        """
        last_appearance_nc: Dict[str, str] = dict()
        dependencies: Dict[str, Set[str]] = dict()
        # Artifact nodes that are going to be replaced by cached value
        cache_nodes = [
            nc.name
            for nc in self.usercode_nodecollections
            if isinstance(nc, ArtifactNodeCollection) and nc.is_pre_computed
        ]

        # Determine input variables of a nodecollection are coming from output
        # variables of nodecollections to build the NodeCollection dependencies
        for nc in self.usercode_nodecollections:
            dependencies[nc.name] = {
                last_appearance_nc[var]
                for var in nc.input_variables
                if var in last_appearance_nc
            }
            # Later collections reading these variables depend on this one.
            for var in nc.return_variables:
                last_appearance_nc[var] = nc.name

        # Nodecollection dependencies
        self.nodecollection_dependencies = TaskGraph(
            nodes=[nc.name for nc in self.usercode_nodecollections],
            edges=dependencies,
        )
        self.nodecollection_dependencies = (
            self.nodecollection_dependencies.remap_nodes(
                mapping={
                    nc.name: nc.safename
                    for nc in self.usercode_nodecollections
                },
            )
        )

        # Graph with each nodecollection as node
        # NOTE(review): the graph was just remapped to ``safename`` while
        # ``cache_nodes`` and the membership tests below use ``name`` -
        # presumably both are equal for the collections involved; confirm.
        nc_graph = self.nodecollection_dependencies.graph
        # Edges point to cached nodes
        cache_nodes_edges = [
            (from_node, to_node)
            for from_node, to_node in nc_graph.edges
            if from_node in cache_nodes
        ]

        # Remove these edges and the nodecollection graph might split into
        # multiple components, only keep components with user required artifact
        nc_graph.remove_nodes_from(cache_nodes)
        if nc_graph is not None and len(cache_nodes) > 0:
            artifact_names = {art.name for art in self.target_artifacts}
            nc_graph = nx.union_all(
                [
                    nc_graph.subgraph(c).copy()
                    for c in nx.connected_components(nc_graph.to_undirected())
                    if len(set(c).intersection(artifact_names)) > 0
                ]
            )
            # Put the cached collections back together with the edges whose
            # target survived the pruning.
            nc_graph.add_nodes_from(cache_nodes)
            nc_graph.add_edges_from(
                [edge for edge in cache_nodes_edges if edge[1] in nc_graph]
            )
            self.usercode_nodecollections = [
                art
                for art in self.usercode_nodecollections
                if art.name in nc_graph.nodes
            ]
            self.nodecollection_dependencies.graph = nc_graph

    def _get_first_artifact_name(self) -> Optional[str]:
        """
        Return the name of first artifact(topologically sorted).

        Returns the ``safename`` of the first ``ArtifactNodeCollection`` in
        ``usercode_nodecollections``, or ``None`` when there is none.
        """
        for coll in self.usercode_nodecollections:
            if isinstance(coll, ArtifactNodeCollection):
                return coll.safename
        return None

    def get_libraries(self) -> List[ImportNodeORM]:
        """
        Return a list of ImportNodeORM's containing the libraries associated with this SessionArtifact.

        This function works by taking the imported library information from the whole session
        and checking if the library is used for this SessionArtifact by trying to match with the
        relevant nodes in the Session Artifacts import_nodes attribute.

        Specifically we look for CallNodes with function `l_import` and a single argument.
        The value of the argument Literal node will contain the base library name we want because
        CallNodes with a single argument are the ones importing without the base_module optional argument,
        which is only the case when we are importing the base library which we want the name of.
        """
        # All libraries in a session
        session_libs = self.db.get_libraries_for_session(self.session_id)

        # Libraries this SessionArtifact uses will be stored here by name
        session_artifact_lib_names: Set[str] = set()

        # Get all nodes in SessionArtifact "import_nodes" attribute.
        # Note that these nodes are not actually ImportNodes, but simply the
        # Call/Literal/Lookup Nodes associated with the captured lines
        # that are import statements.
        import_nodes = {
            import_node_id: self.db.get_node_by_id(import_node_id)
            for import_node_id in self.import_nodes
        }

        # Try to find library names through their associated CallNode.
        # This Node must call l_import with one argument which will be the
        # base library name we're interested in.
        for node_id, node in import_nodes.items():
            # check if node is CallNode doing module import
            if not is_import_node(self.graph, node_id):
                continue
            node = cast(CallNode, node)
            # Check function has a single argument
            if len(node.positional_args) != 1:
                continue
            # This single argument should be a literal node holding the
            # library name. An argument that is not among the import nodes
            # cannot be matched and is skipped like any other non-literal.
            argument_node = import_nodes.get(node.positional_args[0].id)
            if not isinstance(argument_node, LiteralNode):
                continue
            session_artifact_lib_names.add(argument_node.value)

        # Get only session libraries that are used in this SessionArtifact
        return [
            lib_info
            for lib_info in session_libs
            if lib_info.package_name in session_artifact_lib_names
        ]
|