Reference Guide  2.5.0
psyclone.psyad.transformations.assignment_trans.AssignmentTrans Class Reference
Inheritance diagram for psyclone.psyad.transformations.assignment_trans.AssignmentTrans:
Collaboration diagram for psyclone.psyad.transformations.assignment_trans.AssignmentTrans:

Public Member Functions

def apply (self, node, options=None)
 
def validate (self, node, options=None)
 
def __str__ (self)
 
def name (self)
 
- Public Member Functions inherited from psyclone.psyad.transformations.adjoint_trans.AdjointTransformation
def __init__ (self, active_variables)
 

Detailed Description

Implements a transformation to translate a Tangent-Linear
assignment to its Adjoint form.

Definition at line 56 of file assignment_trans.py.
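As a hand-worked illustration (plain Python, not the PSyclone API): for the tangent-linear assignment `x = x + a*y`, with `x` and `y` active and `a` inactive, the transformation produces the adjoint statement `y^ = y^ + a*x^`; the increment term contributes only `x^ = x^`, so nothing is output for it and `x^` is not zeroed. The correctness of such an adjoint can be checked with the standard inner-product (dot-product) test, sketched below with scalar stand-ins for the active variables:

```python
# Hand-worked illustration (plain Python, not the PSyclone API): the
# tangent-linear (TL) assignment  x = x + a*y  has active variables x, y
# and inactive coefficient a. Its adjoint is  y^ = y^ + a*x^  with x^
# unchanged (the increment term produces no statement).
a = 3.0

def tangent_linear(dx, dy):
    # TL operator L: dx = dx + a*dy; dy unchanged.
    return dx + a * dy, dy

def adjoint(xh, yh):
    # Adjoint operator L^T: yh = yh + a*xh; xh unchanged (increment case).
    return xh, yh + a * xh

# Inner-product test: <L v, w> == <v, L^T w> must hold for any v, w.
dx, dy = 2.0, -1.5
wx, wy = 0.5, 4.0
lx, ly = tangent_linear(dx, dy)
ax, ay = adjoint(wx, wy)
lhs = lx * wx + ly * wy
rhs = dx * ax + dy * ay
assert lhs == rhs == -7.25
```

If the adjoint were constructed incorrectly (e.g. `x^` zeroed despite the increment), the two inner products would differ.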

Member Function Documentation

◆ apply()

def psyclone.psyad.transformations.assignment_trans.AssignmentTrans.apply(self, node, options=None)
Apply the Assignment transformation to the specified node. The node
must be a valid tangent-linear assignment. The assignment is
replaced with its adjoint version.

:param node: an Assignment node.
:type node: :py:class:`psyclone.psyir.nodes.Assignment`
:param options: a dictionary with options for transformations.
:type options: Optional[Dict[str, Any]]

Reimplemented from psyclone.psyGen.Transformation.

Definition at line 61 of file assignment_trans.py.

 61      def apply(self, node, options=None):
 62          '''Apply the Assignment transformation to the specified node. The node
 63          must be a valid tangent-linear assignment. The assignment is
 64          replaced with its adjoint version.
 65  
 66          :param node: an Assignment node.
 67          :type node: :py:class:`psyclone.psyir.nodes.Assignment`
 68          :param options: a dictionary with options for transformations.
 69          :type options: Optional[Dict[str, Any]]
 70  
 71          '''
 72          self.validate(node, options)
 73  
 74          # Split the RHS of the assignment into [-]<term> +- <term> +- ...
 75          rhs_terms = self._split_nodes(
 76              node.rhs, [BinaryOperation.Operator.ADD,
 77                         BinaryOperation.Operator.SUB])
 78  
 79          deferred_inc = []
 80          sym_maths = SymbolicMaths.get()
 81          for rhs_term in rhs_terms:
 82  
 83              # Find the active var in rhs_term if one exists (we may
 84              # find 0.0), storing it in 'active_var' and if so replace
 85              # it with lhs_active_var storing the modified term in
 86              # 'new_rhs_term'. Also determine whether this is an
 87              # increment, storing the result in 'increment'.
 88              increment = False
 89              active_var = None
 90              new_rhs_term = rhs_term.copy()
 91              for ref in new_rhs_term.walk(Reference):
 92                  if ref.symbol not in self._active_variables:
 93                      continue
 94                  active_var = ref
 95                  # Identify whether this reference on the RHS matches the
 96                  # one on the LHS - if so we have an increment.
 97                  if node.is_array_assignment and isinstance(ref, ArrayMixin):
 98                      # TODO #1537 - we can't just do `sym_maths.equal` if we
 99                      # have an array range because the SymbolicMaths class does
100                      # not currently support them.
101                      # Since we have already checked (in validate) that any
102                      # references to the same symbol on the RHS have the same
103                      # range, this is an increment if the symbols match.
104                      if node.lhs.symbol is ref.symbol:
105                          increment = True
106                  else:
107                      if sym_maths.equal(ref, node.lhs):
108                          increment = True
109                  if ref.parent:
110                      ref.replace_with(node.lhs.copy())
111                  else:
112                      new_rhs_term = node.lhs.copy()
113                  break
114  
115              # Work out whether the binary operation for this term is a
116              # '+' or a '-' and return it in 'rhs_operator'.
117              rhs_operator = BinaryOperation.Operator.ADD
118              previous = rhs_term
119              candidate = rhs_term.parent
120              while not isinstance(candidate, Assignment):
121                  if (isinstance(candidate, BinaryOperation) and
122                          candidate.operator == BinaryOperation.Operator.SUB and
123                          candidate.children[1] is previous):
124                      # Rules: + + -> +; - - -> +; + - -> -; - + -> -
125                      # If the higher level op is + then there is no
126                      # change to the existing op. If it is - then
127                      # we flip the op i.e. - => + and + => -.
128                      if rhs_operator == BinaryOperation.Operator.SUB:
129                          rhs_operator = BinaryOperation.Operator.ADD
130                      else:
131                          rhs_operator = BinaryOperation.Operator.SUB
132                  previous = candidate
133                  candidate = candidate.parent
134  
135              if not active_var:
136                  # This is an expression without an active variable
137                  # (which must be 0.0, otherwise validation will have
138                  # rejected it). There is therefore nothing to output.
139                  continue
140  
141              if increment:
142                  # The output of an increment needs to be deferred as
143                  # other terms must be completed before the LHS TL
144                  # active variable is modified. Save the rhs term
145                  # and its associated operator.
146                  deferred_inc.append((new_rhs_term, rhs_operator))
147              else:
148                  # Output the adjoint for this term
149                  rhs = BinaryOperation.create(
150                      rhs_operator, active_var.copy(), new_rhs_term)
151                  assignment = Assignment.create(active_var.copy(), rhs)
152                  node.parent.children.insert(node.position, assignment)
153  
154          if (len(deferred_inc) == 1 and
155                  isinstance(deferred_inc[0][0], Reference)):
156              # No need to output anything as the adjoint is A = A.
157              pass
158          elif deferred_inc:
159              # Output the adjoint for all increment terms in a single line.
160              rhs, _ = deferred_inc.pop(0)
161              for term, operator in deferred_inc:
162                  rhs = BinaryOperation.create(operator, rhs, term)
163              assignment = Assignment.create(node.lhs.copy(), rhs)
164              node.parent.children.insert(node.position, assignment)
165          else:
166              # The assignment is not an increment. The LHS active
167              # variable needs to be zero'ed.
168              assignment = Assignment.create(
169                  node.lhs.copy(), Literal("0.0", REAL_TYPE))
170              node.parent.children.insert(node.position, assignment)
171  
172          # Remove the original node
173          node.detach()
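The sign-propagation loop above (file lines 117-133) walks from each term up to the Assignment node, flipping the effective operator each time the term lies inside the second child of a subtraction. The same rule can be sketched top-down over plain nested tuples standing in for PSyIR `BinaryOperation` nodes (this is an illustration, not PSyclone code):

```python
# Toy illustration of the sign rule from apply(): the sign of a term flips
# once for every subtraction in which it sits on the right-hand side.
# Rules: + + -> +; - - -> +; + - -> -; - + -> -
def term_signs(expr, sign="+"):
    """Collect (term, sign) pairs from a nested expression tree given as
    ('+' | '-', lhs, rhs) tuples, with strings as leaf terms."""
    if isinstance(expr, str):
        return [(expr, sign)]
    op, lhs, rhs = expr
    flip = {"+": "-", "-": "+"}
    # The left child keeps the current sign; the right child flips it
    # if and only if the operator is a subtraction.
    rhs_sign = flip[sign] if op == "-" else sign
    return term_signs(lhs, sign) + term_signs(rhs, rhs_sign)

# a - b - c + d, built left-associatively as ((a - b) - c) + d:
expr = ("+", ("-", ("-", "a", "b"), "c"), "d")
print(term_signs(expr))
# -> [('a', '+'), ('b', '-'), ('c', '-'), ('d', '+')]
```

The recursion is equivalent to the bottom-up parent walk in `apply()`: each SUB ancestor whose right subtree contains the term contributes one flip.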

References psyclone.psyad.transformations.adjoint_trans.AdjointTransformation._active_variables, psyclone.psyad.transformations.assignment_trans.AssignmentTrans._split_nodes(), and psyclone.psyad.transformations.assignment_trans.AssignmentTrans.validate().

Here is the call graph for this function:

◆ name()

def psyclone.psyad.transformations.assignment_trans.AssignmentTrans.name(self)
:returns: the name of the transformation as a string.
:rtype: str

Reimplemented from psyclone.psyGen.Transformation.

Definition at line 408 of file assignment_trans.py.

408      def name(self):
409          '''
410          :returns: the name of the transformation as a string.
411          :rtype: str
412  
413          '''
414          return type(self).__name__
Here is the caller graph for this function:

◆ validate()

def psyclone.psyad.transformations.assignment_trans.AssignmentTrans.validate(self, node, options=None)
Perform various checks to ensure that it is valid to apply the
AssignmentTrans transformation to the supplied PSyIR Node.

:param node: the node that is being checked.
:type node: :py:class:`psyclone.psyir.nodes.Assignment`
:param options: a dictionary with options for transformations.
:type options: Optional[Dict[str, Any]]

:raises TransformationError: if the node argument is not an
    Assignment.
:raises TangentLinearError: if the assignment does not conform
    to the required tangent-linear structure.

Reimplemented from psyclone.psyGen.Transformation.

Definition at line 228 of file assignment_trans.py.

228      def validate(self, node, options=None):
229          '''Perform various checks to ensure that it is valid to apply the
230          AssignmentTrans transformation to the supplied PSyIR Node.
231  
232          :param node: the node that is being checked.
233          :type node: :py:class:`psyclone.psyir.nodes.Assignment`
234          :param options: a dictionary with options for transformations.
235          :type options: Optional[Dict[str, Any]]
236  
237          :raises TransformationError: if the node argument is not an \
238              Assignment.
239          :raises TangentLinearError: if the assignment does not conform \
240              to the required tangent-linear structure.
241  
242          '''
243          # Check node argument is an assignment node
244          if not isinstance(node, Assignment):
245              raise TransformationError(
246                  f"Node argument in assignment transformation should be a "
247                  f"PSyIR Assignment, but found '{type(node).__name__}'.")
248          assign = node
249  
250          # If there are no active variables then return
251          assignment_active_var_names = [
252              var.name for var in assign.walk(Reference)
253              if var.symbol in self._active_variables]
254          if not assignment_active_var_names:
255              # No active variables in this assignment so the assignment
256              # remains unchanged.
257              return
258  
259          # The lhs of the assignment node should be an active variable
260          if assign.lhs.symbol not in self._active_variables:
261              # There are active vars on RHS but not on LHS
262              raise TangentLinearError(
263                  f"Assignment node '{assign.debug_string()}' has the following "
264                  f"active variables on its RHS '{assignment_active_var_names}' "
265                  f"but its LHS '{assign.lhs.name}' is not an active variable.")
266  
267          # Split the RHS of the assignment into <expr> +- <expr> +- <expr>
268          rhs_terms = self._split_nodes(
269              assign.rhs, [BinaryOperation.Operator.ADD,
270                           BinaryOperation.Operator.SUB])
271  
272          # Check for the special case where RHS=0.0. This is really a
273          # representation of multiplying an active variable by zero but
274          # this is obviously not visible in the code. Use 'float' to
275          # normalise different representations of 0.
276          if (len(rhs_terms) == 1 and isinstance(rhs_terms[0], Literal) and
277                  float(rhs_terms[0].value) == 0.0):
278              return
279  
280          # Check each expression term. It must be in the form
281          # A */ <expr> where A is an active variable.
282          for rhs_term in rhs_terms:
283  
284              # When searching for references to an active variable we must
285              # take care to exclude those cases where they are present as
286              # arguments to the L/UBOUND intrinsics (as they will be when
287              # array notation is used).
288              active_vars = []
289              lu_bound_ops = [IntrinsicCall.Intrinsic.LBOUND,
290                              IntrinsicCall.Intrinsic.UBOUND]
291              for ref in rhs_term.walk(Reference):
292                  if (ref.symbol in self._active_variables and
293                          not (isinstance(ref.parent, IntrinsicCall) and
294                               ref.parent.intrinsic in lu_bound_ops)):
295                      active_vars.append(ref)
296  
297              if not active_vars:
298                  # This term must contain an active variable
299                  raise TangentLinearError(
300                      f"Each non-zero term on the RHS of the assignment "
301                      f"'{assign.debug_string()}' must have an active variable "
302                      f"but '{rhs_term.debug_string()}' does not.")
303  
304              if len(active_vars) > 1:
305                  # This term can only contain one active variable
306                  raise TangentLinearError(
307                      f"Each term on the RHS of the assignment "
308                      f"'{assign.debug_string()}' must not have more than one "
309                      f"active variable but '{rhs_term.debug_string()}' has "
310                      f"{len(active_vars)}.")
311  
312              if (isinstance(rhs_term, Reference) and rhs_term.symbol
313                      in self._active_variables):
314                  self._array_ranges_match(assign, rhs_term)
315                  # This term consists of a single active variable (with
316                  # a multiplier of unity) and is therefore valid.
317                  continue
318  
319              # Ignore unary minus if it is the parent. unary minus does
320              # not cause a problem when applying the transformation but
321              # does cause a problem here in the validation when
322              # splitting the term into expressions.
323              if (isinstance(rhs_term, UnaryOperation) and
324                      rhs_term.operator ==
325                      UnaryOperation.Operator.MINUS):
326                  rhs_term = rhs_term.children[0]
327  
328              # Split the term into <expr> */ <expr> */ <expr>
329              expr_terms = self._split_nodes(
330                  rhs_term, [BinaryOperation.Operator.MUL,
331                             BinaryOperation.Operator.DIV])
332  
333              # One of the expression terms must be an active variable
334              # or an active variable with a preceding + or -.
335              for expr_term in expr_terms:
336                  check_term = expr_term
337                  if (isinstance(expr_term, UnaryOperation) and
338                          expr_term.operator in [UnaryOperation.Operator.PLUS,
339                                                 UnaryOperation.Operator.MINUS]):
340                      check_term = expr_term.children[0]
341                  if (isinstance(check_term, Reference) and
342                          check_term.symbol in self._active_variables):
343                      active_variable = check_term
344                      break
345              else:
346                  raise TangentLinearError(
347                      f"Each term on the RHS of the assignment "
348                      f"'{assign.debug_string()}' must be linear with respect "
349                      f"to the active variable, but found "
350                      f"'{rhs_term.debug_string()}'.")
351  
352              # The term must be a product of an active variable with an
353              # inactive expression. Check that the active variable does
354              # not appear in a denominator.
355              candidate = active_variable
356              parent = candidate.parent
357              while not isinstance(parent, Assignment):
358                  # Starting with the active variable reference, look up
359                  # the tree for an ancestor divide operation until
360                  # reaching the assignment node.
361                  if (isinstance(parent, BinaryOperation) and
362                          parent.operator == BinaryOperation.Operator.DIV and
363                          parent.children[1] is candidate):
364                      # Found a divide and the active variable is on its RHS
365                      raise TangentLinearError(
366                          f"In tangent-linear code an active variable cannot "
367                          f"appear as a denominator but "
368                          f"'{rhs_term.debug_string()}' was found in "
369                          f"'{assign.debug_string()}'.")
370                  # Continue up the PSyIR tree
371                  candidate = parent
372                  parent = candidate.parent
373  
374          # If the LHS of the assignment is an array range then we only
375          # support accesses of the same variable on the RHS if they have
376          # the same range.
377          self._array_ranges_match(assign, active_variable)

References psyclone.psyad.adjoint_visitor.AdjointVisitor._active_variables, psyclone.psyad.transformations.adjoint_trans.AdjointTransformation._active_variables, psyclone.psyad.transformations.assignment_trans.AssignmentTrans._array_ranges_match(), and psyclone.psyad.transformations.assignment_trans.AssignmentTrans._split_nodes().

Here is the call graph for this function:
Here is the caller graph for this function:

The documentation for this class was generated from the following file: