Coverage for src/beamme/cosserat_curve/cosserat_curve.py: 98%
199 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-08 11:03 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-08 11:03 +0000
1# The MIT License (MIT)
2#
3# Copyright (c) 2018-2026 BeamMe Authors
4#
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
11#
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21# THE SOFTWARE.
22"""Define a Cosserat curve object that can be used to describe warping of
23curve-like objects."""
25from pathlib import Path as _Path
26from typing import Tuple as _Tuple
27from xml.etree import ElementTree as _ET # nosec B405
29import numpy as _np
30import pyvista as _pv
31import quaternion as _quaternion
32from numpy.typing import NDArray as _NDArray
33from scipy import integrate as _integrate
34from scipy import interpolate as _interpolate
35from scipy import optimize as _optimize
37from beamme.core.conf import bme as _bme
38from beamme.core.rotation import Rotation as _Rotation
39from beamme.core.rotation import rotate_coordinates as _rotate_coordinates
40from beamme.core.rotation import smallest_rotation as _smallest_rotation
def get_piecewise_linear_arc_length_along_points(
    coordinates: _np.ndarray,
) -> _np.ndarray:
    """Return the accumulated distance between the points.

    Args
    ----
    coordinates:
        Array (or array-like) containing the point coordinates, one point per row

    Return
    ----
    point_arc_length:
        Array with one entry per point. The first entry is zero, entry ``i`` is the
        sum of the straight-line distances between all consecutive points up to
        point ``i``.
    """

    # Accept nested lists / tuples as well as arrays.
    coordinates = _np.asarray(coordinates)

    # Straight-line length of every segment between two consecutive points.
    segment_lengths = _np.linalg.norm(_np.diff(coordinates, axis=0), axis=1)

    # The running sum is accumulated in double precision, the first point has an
    # arc length of zero.
    point_arc_length = _np.zeros(len(coordinates))
    point_arc_length[1:] = _np.cumsum(segment_lengths, dtype=float)
    return point_arc_length
def get_spline_interpolation(
    coordinates: _np.ndarray, point_arc_length: _np.ndarray
) -> _interpolate.BSpline:
    """Build a spline through the given points, parametrized by arc length.

    Args
    ----
    coordinates:
        Array containing the point coordinates (the interpolated values)
    point_arc_length:
        Arc length for each coordinate (the interpolation parameter)

    Return
    ----
    centerline_interpolation:
        The spline interpolation object
    """

    # Note: The numeric evaluation of the spline interpolation can depend on the
    # operating system, thus introducing slight numerical differences (~1e-12).
    return _interpolate.make_interp_spline(x=point_arc_length, y=coordinates)
def get_quaternions_along_curve(
    centerline: _interpolate.BSpline, point_arc_length: _np.ndarray
) -> _NDArray[_quaternion.quaternion]:
    """Get the quaternions along the curve based on smallest rotation mappings.

    The first triad is created with `_Rotation.from_basis` from the tangent at the
    first arc length value and the Cartesian basis vector that has the smallest
    absolute projection onto this tangent. All following triads are obtained by
    successively applying `_smallest_rotation` from the previous triad onto the
    tangent at the next arc length value.

    Args
    ----
    centerline:
        Spline describing the centerline position for a parameter coordinate t, its
        derivative is used as tangent
    point_arc_length:
        Array of parameter coordinates for which the quaternions should be calculated

    Return
    ----
    quaternions:
        Array with one quaternion for each entry in point_arc_length
    """

    tangent = centerline.derivative()

    # Get the reference rotation at the start of the curve
    start_tangent = tangent(point_arc_length[0])
    cartesian_basis = _np.eye(3)
    projections = [_np.dot(e_i, start_tangent) for e_i in cartesian_basis]
    i_min_projection = _np.argmin(_np.abs(projections))
    current_rotation = _Rotation.from_basis(
        start_tangent, cartesian_basis[i_min_projection].copy()
    )

    # March along the curve, each triad is the smallest rotation of the previous
    # triad onto the current tangent.
    quaternions = _np.zeros(len(point_arc_length), dtype=_quaternion.quaternion)
    quaternions[0] = current_rotation.q
    for i_point, arc_length in enumerate(point_arc_length[1:], start=1):
        current_rotation = _smallest_rotation(current_rotation, tangent(arc_length))
        quaternions[i_point] = current_rotation.q
    return quaternions
def get_relative_distance_and_rotations(
    coordinates: _np.ndarray, quaternions: _NDArray[_quaternion.quaternion]
) -> _Tuple[
    _np.ndarray, _NDArray[_quaternion.quaternion], _NDArray[_quaternion.quaternion]
]:
    """Get relative distances and rotations that can be used to evaluate
    "intermediate" states of the Cosserat curve.

    For each segment between two consecutive points the following is returned:
        - the length of the vector connecting the two points,
        - the smallest rotation from the identity onto that connecting vector,
          where the vector is expressed in the triad at the segment start,
        - the rotation from the triad at the segment start to the triad at the
          segment end, expressed in the triad at the segment start.
    """

    n_segments = len(coordinates) - 1
    relative_distances = _np.zeros(n_segments)
    relative_distances_rotation = _np.zeros(n_segments, dtype=_quaternion.quaternion)
    relative_rotations = _np.zeros(n_segments, dtype=_quaternion.quaternion)

    for i_segment in range(n_segments):
        q_start_inverse = quaternions[i_segment].conjugate()
        q_end = quaternions[i_segment + 1]

        # Connecting vector of the segment, expressed in the start triad
        segment_vector_local = _quaternion.rotate_vectors(
            q_start_inverse, coordinates[i_segment + 1] - coordinates[i_segment]
        )
        relative_distances[i_segment] = _np.linalg.norm(segment_vector_local)

        rotation_onto_segment = _smallest_rotation(_Rotation(), segment_vector_local)
        relative_distances_rotation[i_segment] = (
            rotation_onto_segment.get_numpy_quaternion()
        )

        relative_rotations[i_segment] = q_start_inverse * q_end

    return relative_distances, relative_distances_rotation, relative_rotations
class CosseratCurve:
    """Represent a Cosserat curve in space.

    The curve is stored in a discrete form: the point coordinates
    (``coordinates``), the arc length at each point (``point_arc_length``), a
    spline interpolation of the centerline (``centerline_interpolation``) and one
    quaternion per point (``quaternions``). Additionally, per-segment relative
    distances and rotations are stored, they are used to evaluate the curve with a
    scaled curvature.
    """

    def __init__(
        self,
        point_coordinates: _np.ndarray,
        *,
        starting_triad_guess: _Rotation | None = None,
    ):
        """Initialize the Cosserat curve based on points in 3D space.

        Args:
            point_coordinates: Array containing the point coordinates
            starting_triad_guess: Optional initial guess for the starting triad.
                If provided, this introduces a constant twist angle along the curve.
                The twist angle is computed between:
                - The given starting guess triad, and
                - The automatically calculated triad, rotated onto the first basis vector
                  of the starting guess triad using the smallest rotation.

        Raises:
            ValueError: If the first basis vector of the guess triad deviates too
                much from the automatically calculated one, or if the relative
                rotation between them is not a pure twist.
        """

        # Store a private floating point copy. Integer input would otherwise break
        # the in-place translation with float vectors and truncate the positions
        # computed for scaled curvatures.
        self.coordinates = _np.array(point_coordinates, dtype=float)
        self.n_points = len(self.coordinates)

        # Interpolate coordinates along piece wise linear arc length
        point_arc_length_piecewise_linear = (
            get_piecewise_linear_arc_length_along_points(self.coordinates)
        )
        centerline_interpolation_piecewise_linear_p = get_spline_interpolation(
            self.coordinates, point_arc_length_piecewise_linear
        ).derivative(1)

        def ds(t):
            """Arc length along interpolated spline."""
            return _np.linalg.norm(centerline_interpolation_piecewise_linear_p(t))

        # Integrate the arc length along the interpolated centerline, this will result
        # in a more accurate centerline arc length
        self.point_arc_length = _np.zeros(self.n_points)
        for i in range(self.n_points - 1):
            segment_arc_length = _integrate.quad(
                ds,
                point_arc_length_piecewise_linear[i],
                point_arc_length_piecewise_linear[i + 1],
            )[0]
            self.point_arc_length[i + 1] = (
                self.point_arc_length[i] + segment_arc_length
            )

        # Set the interpolation of the (positional) centerline
        self.set_centerline_interpolation()

        # Get the quaternions along the centerline based on smallest rotation mappings
        self.quaternions = get_quaternions_along_curve(
            self.centerline_interpolation, self.point_arc_length
        )

        # Get the relative quantities used to warp the curve
        (
            self.relative_distances,
            self.relative_distances_rotation,
            self.relative_rotations,
        ) = get_relative_distance_and_rotations(self.coordinates, self.quaternions)

        # Check if we have to apply a twist for the rotations
        if starting_triad_guess is not None:
            first_rotation = _Rotation.from_quaternion(self.quaternions[0])
            starting_triad_e1 = starting_triad_guess * [1, 0, 0]
            if _np.dot(first_rotation * [1, 0, 0], starting_triad_e1) < 0.5:
                raise ValueError(
                    "The angle between the first basis vectors of the guess triad you"
                    " provided and the automatically calculated one is too large,"
                    " please check your input data."
                )
            smallest_rotation_to_guess_tangent = _smallest_rotation(
                first_rotation, starting_triad_e1
            )
            relative_rotation = (
                smallest_rotation_to_guess_tangent.inv() * starting_triad_guess
            )
            psi = relative_rotation.get_rotation_vector()
            # Only the first component of the rotation vector may be non-zero,
            # otherwise the relative rotation is not a pure twist.
            if _np.linalg.norm(psi[1:]) > _bme.eps_quaternion:
                raise ValueError(
                    "The twist angle can not be extracted as the relative rotation is not plane!"
                )
            twist_angle = psi[0]
            self.twist(twist_angle)

    def set_centerline_interpolation(self):
        """Set the interpolation of the centerline based on the coordinates and
        arc length stored in this object."""
        self.centerline_interpolation = get_spline_interpolation(
            self.coordinates, self.point_arc_length
        )

    def translate(self, vector):
        """Translate the curve by the given vector.

        The coordinates are modified in-place and the centerline interpolation is
        rebuilt afterwards.
        """

        self.coordinates += vector
        self.set_centerline_interpolation()

    def rotate(self, rotation: _Rotation, *, origin=None):
        """Rotate the curve and the quaternions.

        Args:
            rotation: Rotation that is applied to the coordinates and multiplied
                from the left onto the stored quaternions.
            origin: Forwarded to `rotate_coordinates` as origin of the rotation.
        """

        self.quaternions = rotation.get_numpy_quaternion() * self.quaternions
        self.coordinates = _rotate_coordinates(
            self.coordinates, rotation, origin=origin
        )
        self.set_centerline_interpolation()

    def twist(self, twist_angle: float) -> None:
        """Apply a constant twist rotation along the Cosserat curve.

        Args:
            twist_angle: The rotation angle (in radians).
        """
        material_twist_rotation = _Rotation(
            [1, 0, 0], twist_angle
        ).get_numpy_quaternion()
        material_twist_rotation_inverse = material_twist_rotation.conjugate()

        # The triads are twisted around their own first basis vector, the relative
        # (segment-local) quantities are transformed accordingly.
        self.quaternions = self.quaternions * material_twist_rotation
        self.relative_distances_rotation = (
            material_twist_rotation_inverse
            * self.relative_distances_rotation
            * material_twist_rotation
        )
        self.relative_rotations = (
            material_twist_rotation_inverse
            * self.relative_rotations
            * material_twist_rotation
        )

    def get_centerline_position_and_rotation(
        self, arc_length: float, **kwargs
    ) -> _Tuple[_np.ndarray, _NDArray[_quaternion.quaternion]]:
        """Return the position and rotation at a given centerline arc
        length.

        All keyword arguments are forwarded to
        `get_centerline_positions_and_rotations`.
        """
        pos, rot = self.get_centerline_positions_and_rotations([arc_length], **kwargs)
        return pos[0], rot[0]

    def get_centerline_positions_and_rotations(
        self, points_on_arc_length, *, factor=1.0
    ) -> _Tuple[_np.ndarray, _NDArray[_quaternion.quaternion]]:
        """Return the position and rotation at given centerline arc lengths.

        If the points are outside of the valid interval, a linear extrapolation will be
        performed for the displacements and the rotations will be held constant.

        This function also allows to scale the curvature along the curve, allowing for a
        "natural" unwrapping of general curves in 3D. We achieve this by scaling the
        "final" curvature along the beam and then evaluating the curve that follows this
        curvature (this would actually require to solve an ODE, but we avoid this by
        using a piecewise constant approximation).

        Args
        ----
        points_on_arc_length: list(float)
            The arc lengths along the curve centerline. The results are returned in
            the order of this list.
        factor: float
            Factor to scale the curvature along the curve.
                factor == 1
                    Use the default positions and the triads obtained via a smallest rotation mapping
                0 <factor < 1
                    Integrate (piecewise constant as evaluated with get_relative_distance_and_rotations)
                    the scaled curvature of the curve to obtain a intuitive wrapping. (factor=0 gives
                    a straight line)

        Return
        ----
        positions, quaternions:
            Array with one position per requested arc length and array with one
            quaternion per requested arc length.
        """

        # Split the requested points into the ones strictly inside the arc length
        # range of the curve and the ones that have to be extrapolated.
        points_on_arc_length = _np.asarray(points_on_arc_length)
        points_in_bounds = _np.logical_and(
            points_on_arc_length > self.point_arc_length[0],
            points_on_arc_length < self.point_arc_length[-1],
        )
        index_in_bound = _np.flatnonzero(points_in_bounds)
        index_out_of_bound = _np.flatnonzero(~points_in_bounds)

        # The start and end point of the curve are always evaluated, they are the
        # anchors for the extrapolation.
        points_on_arc_length_in_bound = [
            self.point_arc_length[0],
            *points_on_arc_length[index_in_bound],
            self.point_arc_length[-1],
        ]

        if factor < (1.0 - _bme.eps_quaternion):
            coordinates = _np.zeros_like(self.coordinates)
            quaternions = _np.zeros_like(self.quaternions)
            coordinates[0] = self.coordinates[0]
            quaternions[0] = self.quaternions[0]
            for i_segment in range(self.n_points - 1):
                relative_distance_rotation = _quaternion.slerp_evaluate(
                    _quaternion.quaternion(1),
                    self.relative_distances_rotation[i_segment],
                    factor,
                )
                # In the initial configuration (factor=0) we get a straight curve, so we need
                # to use the arc length here. In the final configuration (factor=1) we want to
                # exactly recover the input points, so we need the piecewise linear distance.
                # Between them, we interpolate.
                relative_distance = (factor * self.relative_distances[i_segment]) + (
                    1.0 - factor
                ) * (
                    self.point_arc_length[i_segment + 1]
                    - self.point_arc_length[i_segment]
                )
                coordinates[i_segment + 1] = (
                    _quaternion.rotate_vectors(
                        quaternions[i_segment] * relative_distance_rotation,
                        [relative_distance, 0, 0],
                    )
                    + coordinates[i_segment]
                )
                quaternions[i_segment + 1] = quaternions[
                    i_segment
                ] * _quaternion.slerp_evaluate(
                    _quaternion.quaternion(1),
                    self.relative_rotations[i_segment],
                    factor,
                )
            arc_length_spline_interpolation = get_spline_interpolation(
                coordinates, self.point_arc_length
            )
        else:
            coordinates = self.coordinates
            quaternions = self.quaternions
            arc_length_spline_interpolation = self.centerline_interpolation

        n_evaluation_points = len(points_on_arc_length_in_bound)
        sol_r = _np.zeros([n_evaluation_points, 3])
        sol_q = _np.zeros(n_evaluation_points, dtype=_quaternion.quaternion)
        for i_point, centerline_arc_length in enumerate(points_on_arc_length_in_bound):
            if not (
                self.point_arc_length[0]
                <= centerline_arc_length
                <= self.point_arc_length[-1]
            ):
                raise ValueError("Centerline value out of bounds")

            # Find the segment containing the point, i.e., the segment in front of
            # the first curve point with a larger arc length (the last segment if
            # there is none). This relies on the arc lengths being non-decreasing.
            centerline_index = min(
                int(
                    _np.searchsorted(
                        self.point_arc_length, centerline_arc_length, side="right"
                    )
                )
                - 1,
                self.n_points - 2,
            )

            # Get the two rotation vectors and arc length values
            arc_lengths = self.point_arc_length[centerline_index : centerline_index + 2]
            q1 = quaternions[centerline_index]
            q2 = quaternions[centerline_index + 1]

            # Linear interpolate the arc length
            xi = (centerline_arc_length - arc_lengths[0]) / (
                arc_lengths[1] - arc_lengths[0]
            )

            # Perform a spline interpolation for the positions and a slerp
            # interpolation for the rotations
            sol_r[i_point] = arc_length_spline_interpolation(centerline_arc_length)
            sol_q[i_point] = _quaternion.slerp_evaluate(q1, q2, xi)

        # Set the already computed results in the final data structures. The inner
        # entries of the evaluated arrays (everything except the two curve end
        # points) are ordered exactly like index_in_bound, so this mapping is also
        # correct if in-bound and out-of-bound points are interleaved.
        sol_r_final = _np.zeros([len(points_on_arc_length), 3])
        sol_q_final = _np.zeros(len(points_on_arc_length), dtype=_quaternion.quaternion)
        sol_r_final[index_in_bound] = sol_r[1:-1]
        sol_q_final[index_in_bound] = sol_q[1:-1]

        # Perform the extrapolation at both ends of the curve
        for i in index_out_of_bound:
            arc_length = points_on_arc_length[i]
            if arc_length <= self.point_arc_length[0]:
                index = 0
            elif arc_length >= self.point_arc_length[-1]:
                index = -1
            else:
                raise ValueError("Should not happen")

            # Extrapolate linearly along the first basis vector of the triad at the
            # respective curve end, the rotation is held constant.
            length = arc_length - self.point_arc_length[index]
            r = sol_r[index]
            q = sol_q[index]
            sol_r_final[i] = r + _Rotation.from_quaternion(q) * [length, 0, 0]
            sol_q_final[i] = q

        return sol_r_final, sol_q_final

    def project_point(self, p, t0=None) -> float:
        """Project a point to the curve, return the parameter coordinate for
        the projection point.

        The parameter coordinate is the root of the orthogonality condition
        ``(r(t) - p) . r'(t) = 0``, which is solved with Newton's method.

        Args:
            p: The point to project.
            t0: Start value for the Newton iteration, defaults to 0.0.
        """

        centerline_interpolation_p = self.centerline_interpolation.derivative(1)
        centerline_interpolation_pp = self.centerline_interpolation.derivative(2)

        def f(t):
            """Function to find the root of."""
            r = self.centerline_interpolation(t)
            rp = centerline_interpolation_p(t)
            return _np.dot(r - p, rp)

        def fp(t):
            """Derivative of the Function to find the root of."""
            r = self.centerline_interpolation(t)
            rp = centerline_interpolation_p(t)
            rpp = centerline_interpolation_pp(t)
            return _np.dot(rp, rp) + _np.dot(r - p, rpp)

        if t0 is None:
            t0 = 0.0

        return _optimize.newton(f, t0, fprime=fp)

    def get_pyvista_polyline(self, *, factor: float = 1.0) -> _pv.PolyData:
        """Create a pyvista (vtk) representation of the curve with the
        evaluated triad basis vectors.

        Args:
            factor: Factor to scale the curvature along the curve (see
                `get_centerline_positions_and_rotations` for details).

        Returns:
            A pyvista PolyData object representing the curve.
        """

        positions, rotations = self.get_centerline_positions_and_rotations(
            self.point_arc_length, factor=factor
        )

        # A single poly line cell connecting all points, the cell array starts
        # with the number of points in the cell.
        poly_line = _pv.PolyData()
        poly_line.points = positions
        cell = _np.arange(0, self.n_points, dtype=int)
        cell = _np.insert(cell, 0, self.n_points)
        poly_line.lines = cell

        # The columns of the rotation matrices are stored as point data.
        rotation_matrices = _quaternion.as_rotation_matrix(rotations)
        for i_dir in range(3):
            poly_line.point_data.set_array(
                rotation_matrices[:, :, i_dir], f"base_vector_{i_dir + 1}"
            )

        return poly_line

    def write_vtk(self, path) -> None:
        """Save a vtk representation of the curve."""
        self.get_pyvista_polyline().save(path)

    def write_pvd_series(
        self,
        pvd_path: _Path | str,
        *,
        factors: list[float] | None = None,
        n_steps: int | None = None,
        binary: bool = True,
    ) -> None:
        """Save a pvd series representing the curve at different states.

        One vtu file per factor is written next to the pvd file, the pvd file
        references them with the factor as timestep.

        Args:
            pvd_path: Path where to save the pvd file.
            factors: List of factors to scale the curvature along the curve. Mutually exclusive with 'n_steps'.
            n_steps: Number of steps to create a uniform series of factors. Mutually exclusive with 'factors'.
            binary: If True, save the vtk files in binary format.

        Raises:
            ValueError: If the path does not have a .pvd suffix or if not exactly
                one of 'factors' and 'n_steps' is given.
        """

        pvd_path = _Path(pvd_path)
        if pvd_path.suffix != ".pvd":
            raise ValueError(
                f"The output path must have a .pvd suffix, got {pvd_path.suffix}"
            )

        if factors is not None and n_steps is not None:
            raise ValueError(
                "The keyword arguments 'factors' and 'n_steps' are mutually exclusive."
            )
        if factors is None and n_steps is None:
            raise ValueError(
                "One of the keyword arguments 'factors' or 'n_steps' must be provided."
            )
        if factors is None:
            factors = _np.linspace(0.0, 1.0, num=n_steps)

        pvd_file = _ET.Element("VTKFile", type="Collection", version="0.1")
        collection = _ET.SubElement(pvd_file, "Collection")
        # Zero-pad the step index so that all file names have the same length.
        width = max(1, len(str(len(factors) - 1)))
        for i_step, factor in enumerate(factors):
            # TODO: Check if we can use vtp here instead of vtu. Currently this does
            # not work with how we compare files in testing. Since vtu and vtp are
            # basically the same in this case, this solution is fine at the moment.
            factor_file = pvd_path.parent / f"{pvd_path.stem}.{i_step:0{width}d}.vtu"
            _pv.UnstructuredGrid(self.get_pyvista_polyline(factor=factor)).save(
                factor_file, binary=binary
            )
            _ET.SubElement(
                collection,
                "DataSet",
                timestep=str(factor),
                group="",
                part="0",
                file=str(factor_file.relative_to(pvd_path.parent)),
            )

        tree = _ET.ElementTree(pvd_file)
        # NOTE(review): the indentation string is taken verbatim from the reviewed
        # text, where whitespace may have been collapsed - confirm the intended
        # width against the repository.
        _ET.indent(tree, space=" ", level=0)
        tree.write(pvd_path, encoding="utf-8", xml_declaration=True)