36 '''This module implements a class that encapsulates rules that map
37 LFRic kernel metadata to kernel arguments.
40 from collections
import OrderedDict
45 OperatorArgMetadata, ColumnwiseOperatorArgMetadata, FieldArgMetadata,
46 FieldVectorArgMetadata, InterGridArgMetadata, InterGridVectorArgMetadata,
53 '''This class encapsulates rules to map LFRic kernel metadata to
54 kernel arguments. It does this by calling class methods each, of
55 which represent a particular kernel argument or set of
56 arguments. It calls these in the order that the arguments should
57 be found in the kernel metadata. The particular methods called and
58 their ordering is determined by the supplied kernel metadata.
60 Kernel argument information from kernel metadata can be used for
61 more than one thing, e.g. to create or check arguments within a
62 kernel and their declarations (using PSyIR), create the arguments
63 to a kernel call from the generated PSy-layer code or to create
64 appropriate algorithm PSyIR for an Invoke of the kernel or the
65 resulting call to the PSy-layer routine. Subclasses of this class
66 can be implemented for these different requirements.
76 bc_kern_regex = re.compile(
r"enforce_bc_(\d+_)?code", flags=re.I)
80 '''Takes kernel metadata as input and returns whatever is added to the
81 _info variable. This class adds nothing to the _info variable,
82 it is up to the subclass to do this. The variable is treated
83 as a container. The optional info argument allows the subclass
84 to add to an existing object if required.
86 :param metadata: the kernel metadata.
88 py:class:`psyclone.domain.lfric.kernel.LFRicKernelMetadata`
89 :param info: optional object to initialise the _info \
90 variable. Defaults to None.
91 :type info: :py:class:`Object`
100 def _initialise(cls, info):
101 '''Initialise the _info variable. This is implemented as its own
102 method to allow for simple subclassing (i.e. the mapping
103 method should not need to be subclassed).
105 :param info: object to initialise the _info variable.
106 :type info: :py:class:`Object`
109 cls.
_info_info = info
112 def _cell_position(cls):
113 '''A cell position argument.'''
116 def _mesh_height(cls):
117 '''A mesh height argument.'''
120 def _mesh_ncell2d_no_halos(cls):
121 '''Argument providing the number of columns in the mesh ignoring
127 def _mesh_ncell2d(cls):
128 '''Argument providing the number of columns in the mesh including
135 '''Arguments providing a mapping from coarse to fine mesh for the
141 def _scalar(cls, meta_arg):
142 '''Argument providing an LFRic scalar value.
144 :param meta_arg: the metadata associated with this scalar argument.
146 :py:class:`psyclone.domain.lfric.kernel.ScalarArgMetadata`
151 def _field(cls, meta_arg):
152 '''Argument providing an LFRic field.
154 :param meta_arg: the metadata associated with this field argument.
156 :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata`
161 def _field_vector(cls, meta_arg):
162 '''Arguments providing the components of an LFRic field vector.
164 :param meta_arg: the metadata associated with this field \
167 :py:class:`psyclone.domain.lfric.kernel.FieldVectorArgMetadata`
172 def _operator(cls, meta_arg):
173 '''Arguments providing an LMA operator.
175 :param meta_arg: the metadata associated with the operator \
178 :py:class:`psyclone.domain.lfric.kernel.OperatorArgMetadata`
183 def _cma_operator(cls, meta_arg):
184 '''Arguments providing a columnwise operator.
186 :param meta_arg: the metadata associated with the CMA operator \
188 :type meta_arg: :py:class:`psyclone.domain.lfric.kernel.\
189 ColumnwiseOperatorArgMetadata`
194 def _ref_element_properties(cls, meta_ref_element):
195 '''Arguments required if there are reference-element properties
196 specified in the metadata.
198 :param meta_ref_element: the metadata capturing the \
199 reference-element properties required by the kernel.
200 :type meta_mesh: List[\
201 :py:class:`psyclone.domain.lfric.kernel.MetaRefElementArgMetadata`]
206 def _mesh_properties(cls, meta_mesh):
207 '''All arguments required for mesh properties specified in the kernel
210 :param meta_mesh: the metadata capturing the mesh properties \
211 required by the kernel.
212 :type meta_mesh: List[\
213 :py:class:`psyclone.domain.lfric.kernel.MetaMeshArgMetadata`]
218 def _fs_common(cls, function_space):
219 '''Arguments associated with a function space that are common to
220 fields and operators.
222 :param str function_space: the current function space.
227 def _fs_compulsory_field(cls, function_space):
228 '''Compulsory arguments for this function space.
230 :param str function_space: the current function space.
235 def _fs_intergrid(cls, meta_arg):
236 '''Function-space related arguments for an intergrid kernel.
238 :param meta_arg: the metadata capturing the InterGrid argument \
239 required by the kernel.
241 :py:class:`psyclone.domain.lfric.kernel.InterGridArgMetadata`
246 def _basis(cls, function_space):
247 '''Arguments associated with basis functions on the supplied function
250 :param str function_space: the current function space.
255 def _diff_basis(cls, function_space):
256 '''Arguments associated with differential basis functions on the
257 supplied function space.
259 :param str function_space: the current function space.
264 def _quad_rule(cls, shapes):
265 '''Quadrature information is required (gh_shape =
266 gh_quadrature_*). Shape information is provided for each shape
267 in the order specified in the gh_shape metadata.
269 :param shapes: the metadata capturing the quadrature shapes \
270 required by the kernel.
271 :type shapes: List[str]
276 def _field_bcs_kernel(cls):
277 '''Fix for the field boundary condition kernel.'''
280 def _operator_bcs_kernel(cls):
281 '''Fix for the operator boundary condition kernel.'''
284 def _stencil_cross2d_extent(cls, meta_arg):
285 '''The field has a stencil access of type 'cross2d' of unknown extent
286 and therefore requires the extent to be passed from the
289 :param meta_arg: the metadata associated with a field argument \
290 with a cross2d stencil access.
292 :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata`
297 def _stencil_cross2d_max_extent(cls, meta_arg):
298 '''The field has a stencil access of type 'cross2d' and requires the
299 maximum size of a stencil extent to be passed from the
302 :param meta_arg: the metadata associated with a field argument \
303 with a cross2d stencil access.
305 :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata`
310 def _stencil_extent(cls, meta_arg):
311 '''The field has a stencil access (that is not of type 'cross2d') of
312 unknown extent and therefore requires the extent to be passed
313 from the algorithm layer.
315 :param meta_arg: the metadata associated with a field argument \
316 with a stencil access.
318 :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata`
323 def _stencil_xory1d_direction(cls, meta_arg):
324 '''The field has a stencil access of type 'xory1d' and therefore
325 requires the stencil direction to be passed from the algorithm
328 :param meta_arg: the metadata associated with a field argument \
329 with a xory1d stencil access.
331 :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata`
336 def _stencil_cross2d(cls, meta_arg):
337 '''Stencil information that is always passed from the algorithm layer
338 if a field has a stencil access of type 'cross2d'.
340 :param meta_arg: the metadata associated with a field argument \
341 with a stencil access.
343 :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata`
348 def _stencil(cls, meta_arg):
349 '''Stencil information that is always passed from the algorithm layer
350 if a field has a stencil access that is not of type 'cross2d'.
352 :param meta_arg: the metadata associated with a field argument \
353 with a stencil access.
355 :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata`
360 def _banded_dofmap(cls, function_space, cma_operator):
361 '''Adds a banded dofmap for the provided function space and cma
362 operator when there is an cma assembly kernel.
364 :param str function_space: the function space for this banded \
366 :param cma_operator: the cma operator metadata associated with \
368 :type cma_operator: :py:class:`psyclone.domain.lfric.kernel.\
369 ColumnwiseOperatorArgMetadata`
374 def _indirection_dofmap(cls, function_space, cma_operator):
375 '''Adds an indirection dofmap for the provided function space and cma
376 operator when there is an apply cma kernel.
378 :param str function_space: the function space for this \
380 :param cma_operator: the cma operator metadata associated with \
381 this indirection dofmap.
382 :type cma_operator: :py:class:`psyclone.domain.lfric.kernel.\
383 ColumnwiseOperatorArgMetadata`
391 '''Specifies which arguments appear in an argument list and their
392 ordering. Calls methods for each type of argument. These
393 methods can be specialised by a subclass for its particular
396 :raises InternalError: if an unexpected mesh property is found.
401 if cls._metadata.meta_args_get(
402 [OperatorArgMetadata, ColumnwiseOperatorArgMetadata]):
407 if cls._metadata.kernel_type
not in [
"cma-apply",
"cma-matrix-matrix"]:
421 if cls._metadata.operates_on ==
"domain":
422 cls._mesh_ncell2d_no_halos()
423 if cls._metadata.meta_args_get(ColumnwiseOperatorArgMetadata):
426 if cls._metadata.kernel_type ==
"inter-grid":
438 const = LFRicConstants()
439 for meta_arg
in cls._metadata.meta_args:
441 if type(meta_arg)
in [
442 FieldArgMetadata, FieldVectorArgMetadata,
443 InterGridArgMetadata, InterGridVectorArgMetadata]:
444 if type(meta_arg)
in [FieldArgMetadata, InterGridArgMetadata]:
446 if type(meta_arg)
in [
447 FieldVectorArgMetadata, InterGridVectorArgMetadata]:
448 cls._field_vector(meta_arg)
450 if meta_arg.stencil ==
"cross2d":
454 cls._stencil_cross2d_extent(meta_arg)
458 cls._stencil_cross2d_max_extent(meta_arg)
463 cls._stencil_extent(meta_arg)
464 if meta_arg.stencil ==
"xory1d":
467 cls._stencil_xory1d_direction(meta_arg)
470 if meta_arg.stencil ==
"cross2d":
471 cls._stencil_cross2d(meta_arg)
473 cls._stencil(meta_arg)
474 elif type(meta_arg)
is OperatorArgMetadata:
475 cls._operator(meta_arg)
476 elif type(meta_arg)
is ColumnwiseOperatorArgMetadata:
477 cls._cma_operator(meta_arg)
478 elif type(meta_arg)
is ScalarArgMetadata:
479 cls._scalar(meta_arg)
482 f
"Unexpected meta_arg type '{type(meta_arg).__name__}' "
487 function_space_args = cls._metadata.meta_args_get(
488 [FieldArgMetadata, FieldVectorArgMetadata,
489 InterGridArgMetadata, InterGridVectorArgMetadata,
490 OperatorArgMetadata, ColumnwiseOperatorArgMetadata])
491 unique_function_spaces = OrderedDict()
492 for arg
in function_space_args:
494 OperatorArgMetadata, ColumnwiseOperatorArgMetadata]:
495 unique_function_spaces[arg.function_space_to] =
None
496 unique_function_spaces[arg.function_space_from] =
None
498 unique_function_spaces[arg.function_space] =
None
500 for function_space
in unique_function_spaces.keys():
504 if cls._metadata.kernel_type
not in [
505 "cma-matrix-matrix",
"inter-grid"]:
506 cls._fs_common(function_space)
510 if (cls._metadata.field_meta_args_on_fs(
511 [FieldArgMetadata, FieldVectorArgMetadata],
513 cls._fs_compulsory_field(function_space)
517 intergrid_field = cls._metadata.field_meta_args_on_fs(
518 [InterGridArgMetadata, InterGridVectorArgMetadata],
521 cls._fs_intergrid(intergrid_field[0])
523 cma_ops = cls._metadata.operator_meta_args_on_fs(
524 ColumnwiseOperatorArgMetadata, function_space)
526 if cls._metadata.kernel_type ==
"cma-assembly":
528 cls._banded_dofmap(function_space, cma_ops[0])
529 elif cls._metadata.kernel_type ==
"cma-apply":
531 cls._indirection_dofmap(
532 function_space, cma_ops[0])
537 meta_funcs = cls._metadata.meta_funcs \
538 if cls._metadata.meta_funcs
else []
539 if any(func
for func
in meta_funcs
if func.basis_function
and
540 func.function_space == function_space):
541 cls._basis(function_space)
542 if any(func
for func
in meta_funcs
if func.diff_basis_function
543 and func.function_space == function_space):
544 cls._diff_basis(function_space)
548 if (cls._metadata.procedure_name
and
549 cls.bc_kern_regex.match(cls._metadata.procedure_name)):
550 cls._field_bcs_kernel()
554 if (cls._metadata.procedure_name
and
555 cls._metadata.procedure_name.lower() ==
556 "enforce_operator_bc_code"):
557 cls._operator_bcs_kernel()
560 if cls._metadata.meta_ref_element:
561 cls._ref_element_properties(cls._metadata.meta_ref_element)
564 if cls._metadata.meta_mesh:
565 cls._mesh_properties(cls._metadata.meta_mesh)
570 if cls._metadata.meta_funcs
and cls._metadata.shapes
and \
571 any(shape
for shape
in cls._metadata.shapes
if shape
in
572 const.VALID_QUADRATURE_SHAPES):
573 cls._quad_rule(cls._metadata.shapes)
def mapping(cls, metadata, info=None)
def _initialise(cls, info)