Coverage for src / beamme / cosserat_curve / cosserat_curve.py: 98%
199 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-06 06:24 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-06 06:24 +0000
1# The MIT License (MIT)
2#
3# Copyright (c) 2018-2025 BeamMe Authors
4#
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
11#
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21# THE SOFTWARE.
22"""Define a Cosserat curve object that can be used to describe warping of
23curve-like objects."""
25from pathlib import Path as _Path
26from typing import Optional as _Optional
27from typing import Tuple as _Tuple
28from xml.etree import ElementTree as _ET # nosec B405
30import numpy as _np
31import pyvista as _pv
32import quaternion as _quaternion
33from numpy.typing import NDArray as _NDArray
34from scipy import integrate as _integrate
35from scipy import interpolate as _interpolate
36from scipy import optimize as _optimize
38from beamme.core.conf import bme as _bme
39from beamme.core.rotation import Rotation as _Rotation
40from beamme.core.rotation import rotate_coordinates as _rotate_coordinates
41from beamme.core.rotation import smallest_rotation as _smallest_rotation
def get_piecewise_linear_arc_length_along_points(
    coordinates: _np.ndarray,
) -> _np.ndarray:
    """Return the accumulated polyline length at each of the given points.

    Args
    ----
    coordinates:
        Array containing the point coordinates, one point per row.

    Return
    ----
    arc_length:
        Array with one entry per point. The first entry is 0 and entry i is
        the sum of the Euclidean lengths of the first i polyline segments.
    """

    segment_lengths = _np.linalg.norm(coordinates[1:] - coordinates[:-1], axis=1)
    arc_length = _np.zeros(len(coordinates))
    # cumsum is a sequential running sum, i.e. the same as
    # arc_length[i] = arc_length[i - 1] + segment_lengths[i - 1].
    arc_length[1:] = _np.cumsum(segment_lengths)
    return arc_length
def get_spline_interpolation(
    coordinates: _np.ndarray, point_arc_length: _np.ndarray
) -> _interpolate.BSpline:
    """Build an interpolating spline of the points over their arc length.

    Args
    ----
    coordinates:
        Array containing the point coordinates
    point_arc_length:
        Arc length (spline parameter) assigned to each coordinate

    Return
    ----
    centerline_interpolation:
        The spline interpolation object (scipy defaults of
        `make_interp_spline`, i.e. cubic)
    """

    # The numeric evaluation of the returned spline can slightly depend on the
    # operating system (differences in the order of ~1e-12).
    return _interpolate.make_interp_spline(x=point_arc_length, y=coordinates)
def get_quaternions_along_curve(
    centerline: _interpolate.BSpline, point_arc_length: _np.ndarray
) -> _NDArray[_quaternion.quaternion]:
    """Get the triad quaternions along the curve via smallest rotation mappings.

    The first triad is built by `Rotation.from_basis` from the initial tangent
    and the Cartesian basis vector onto which this tangent has the SMALLEST
    absolute projection, i.e. the basis vector that is closest to being
    perpendicular to the tangent. Each following triad is obtained by applying
    `smallest_rotation` to the previous triad and the local tangent.

    Args
    ----
    centerline:
        Spline that returns the centerline position for a parameter coordinate t
    point_arc_length:
        Array of parameter coordinates for which the quaternions should be calculated

    Return
    ----
    quaternions:
        Array with one quaternion per entry of point_arc_length.
    """

    tangent = centerline.derivative()
    identity = _np.eye(3)

    # Row i of (identity @ t) is the projection of t onto the i-th Cartesian
    # basis vector.
    start_tangent = tangent(point_arc_length[0])
    least_aligned = _np.argmin(_np.abs(identity @ start_tangent))
    previous = _Rotation.from_basis(start_tangent, identity[least_aligned])

    quaternions = _np.zeros(len(point_arc_length), dtype=_quaternion.quaternion)
    quaternions[0] = previous.q
    for i_point, arc_length in enumerate(point_arc_length[1:], start=1):
        # Propagate the triad along the curve with a smallest rotation mapping.
        previous = _smallest_rotation(previous, tangent(arc_length))
        quaternions[i_point] = previous.q
    return quaternions
def get_relative_distance_and_rotations(
    coordinates: _np.ndarray, quaternions: _NDArray[_quaternion.quaternion]
) -> _Tuple[
    _np.ndarray, _NDArray[_quaternion.quaternion], _NDArray[_quaternion.quaternion]
]:
    """Get segment-wise relative quantities between consecutive points.

    These quantities can be used to evaluate "intermediate" states of the
    Cosserat curve.

    Args
    ----
    coordinates:
        Array containing the point coordinates
    quaternions:
        Triad quaternion at each point

    Return
    ----
    relative_distances:
        Norm of each segment vector, expressed in the local frame of the
        segment's start triad.
    relative_distances_rotation:
        For each segment, the result of `smallest_rotation` from the identity
        rotation onto that local segment vector.
    relative_rotations:
        For each segment, conjugate(q_start) * q_end of the two triads.
    """

    n_segments = len(coordinates) - 1
    distances = _np.zeros(n_segments)
    distance_rotations = _np.zeros(n_segments, dtype=_quaternion.quaternion)
    rotations = _np.zeros(n_segments, dtype=_quaternion.quaternion)

    for i_segment in range(n_segments):
        q_start_conjugate = quaternions[i_segment].conjugate()

        # Segment vector pulled back into the frame of its start triad.
        segment_local = _quaternion.rotate_vectors(
            q_start_conjugate, coordinates[i_segment + 1] - coordinates[i_segment]
        )
        distances[i_segment] = _np.linalg.norm(segment_local)

        distance_rotations[i_segment] = _smallest_rotation(
            _Rotation(), segment_local
        ).get_numpy_quaternion()

        rotations[i_segment] = q_start_conjugate * quaternions[i_segment + 1]

    return distances, distance_rotations, rotations
class CosseratCurve(object):
    """Represent a Cosserat curve in space.

    The curve stores the point coordinates, the integrated arc length at each
    point, a spline interpolation of the centerline, a triad (quaternion) at
    each point and segment-wise relative quantities that are used to evaluate
    "intermediate" (partially unwrapped) configurations of the curve.
    """

    def __init__(
        self,
        point_coordinates: _np.ndarray,
        *,
        starting_triad_guess: _Optional[_Rotation] = None,
    ):
        """Initialize the Cosserat curve based on points in 3D space.

        Args:
            point_coordinates: Array containing the point coordinates. The data
                is copied; non-floating-point input (e.g. an integer array or
                nested lists) is converted to float.
            starting_triad_guess: Optional initial guess for the starting triad.
                If provided, this introduces a constant twist angle along the curve.
                The twist angle is computed between:
                - The given starting guess triad, and
                - The automatically calculated triad, rotated onto the first basis vector
                  of the starting guess triad using the smallest rotation.

        Raises:
            ValueError: If the starting triad guess is too far from the
                automatically calculated triad, or the relative rotation between
                them is not a pure twist.
        """

        # Work on a private floating point copy. With an integer array the
        # coordinates computed for factor < 1 in
        # get_centerline_positions_and_rotations (written into a zeros_like
        # copy) would be silently truncated, and in-place translations by a
        # float vector would fail.
        coordinates = _np.array(point_coordinates)
        if not _np.issubdtype(coordinates.dtype, _np.floating):
            coordinates = coordinates.astype(float)
        self.coordinates = coordinates
        self.n_points = len(self.coordinates)

        # Interpolate coordinates along piece wise linear arc length
        point_arc_length_piecewise_linear = (
            get_piecewise_linear_arc_length_along_points(self.coordinates)
        )
        centerline_interpolation_piecewise_linear = get_spline_interpolation(
            self.coordinates, point_arc_length_piecewise_linear
        )
        centerline_interpolation_piecewise_linear_p = (
            centerline_interpolation_piecewise_linear.derivative(1)
        )

        def ds(t):
            """Arc length along interpolated spline."""
            return _np.linalg.norm(centerline_interpolation_piecewise_linear_p(t))

        # Integrate the arc length along the interpolated centerline, this will result
        # in a more accurate centerline arc length
        self.point_arc_length = _np.zeros(self.n_points)
        for i in range(len(point_arc_length_piecewise_linear) - 1):
            self.point_arc_length[i + 1] = (
                self.point_arc_length[i]
                + _integrate.quad(
                    ds,
                    point_arc_length_piecewise_linear[i],
                    point_arc_length_piecewise_linear[i + 1],
                )[0]
            )

        # Set the interpolation of the (positional) centerline
        self.set_centerline_interpolation()

        # Get the quaternions along the centerline based on smallest rotation mappings
        self.quaternions = get_quaternions_along_curve(
            self.centerline_interpolation, self.point_arc_length
        )

        # Get the relative quantities used to warp the curve
        (
            self.relative_distances,
            self.relative_distances_rotation,
            self.relative_rotations,
        ) = get_relative_distance_and_rotations(self.coordinates, self.quaternions)

        # Check if we have to apply a twist for the rotations
        if starting_triad_guess is not None:
            first_rotation = _Rotation.from_quaternion(self.quaternions[0])
            starting_triad_e1 = starting_triad_guess * [1, 0, 0]
            if _np.dot(first_rotation * [1, 0, 0], starting_triad_e1) < 0.5:
                raise ValueError(
                    "The angle between the first basis vectors of the guess triad you"
                    " provided and the automatically calculated one is too large,"
                    " please check your input data."
                )
            smallest_rotation_to_guess_tangent = _smallest_rotation(
                first_rotation, starting_triad_e1
            )
            relative_rotation = (
                smallest_rotation_to_guess_tangent.inv() * starting_triad_guess
            )
            psi = relative_rotation.get_rotation_vector()
            # Only a rotation about the first (tangent) axis can be expressed
            # as a constant twist.
            if _np.linalg.norm(psi[1:]) > _bme.eps_quaternion:
                raise ValueError(
                    "The twist angle can not be extracted as the relative rotation is not plane!"
                )
            twist_angle = psi[0]
            self.twist(twist_angle)

    def set_centerline_interpolation(self):
        """Set the interpolation of the centerline based on the coordinates and
        arc length stored in this object."""
        self.centerline_interpolation = get_spline_interpolation(
            self.coordinates, self.point_arc_length
        )

    def translate(self, vector):
        """Translate the curve by the given vector."""

        self.coordinates += vector
        self.set_centerline_interpolation()

    def rotate(self, rotation: _Rotation, *, origin=None):
        """Rotate the curve and the quaternions."""

        self.quaternions = rotation.get_numpy_quaternion() * self.quaternions
        self.coordinates = _rotate_coordinates(
            self.coordinates, rotation, origin=origin
        )
        self.set_centerline_interpolation()

    def twist(self, twist_angle: float) -> None:
        """Apply a constant twist rotation along the Cosserat curve.

        Args:
            twist_angle: The rotation angle (in radians) about the first local
                basis vector.
        """
        material_twist_rotation = _Rotation(
            [1, 0, 0], twist_angle
        ).get_numpy_quaternion()

        self.quaternions = self.quaternions * material_twist_rotation
        # The relative quantities are expressed in the local frames, so they
        # have to be transformed into the twisted frames.
        self.relative_distances_rotation = (
            material_twist_rotation.conjugate()
            * self.relative_distances_rotation
            * material_twist_rotation
        )
        self.relative_rotations = (
            material_twist_rotation.conjugate()
            * self.relative_rotations
            * material_twist_rotation
        )

    def get_centerline_position_and_rotation(
        self, arc_length: float, **kwargs
    ) -> _Tuple[_np.ndarray, _NDArray[_quaternion.quaternion]]:
        """Return the position and rotation at a given centerline arc
        length."""
        pos, rot = self.get_centerline_positions_and_rotations([arc_length], **kwargs)
        return pos[0], rot[0]

    def get_centerline_positions_and_rotations(
        self, points_on_arc_length, *, factor=1.0
    ) -> _Tuple[_np.ndarray, _NDArray[_quaternion.quaternion]]:
        """Return the position and rotation at given centerline arc lengths.

        If the points are outside of the valid interval, a linear extrapolation will be
        performed for the displacements and the rotations will be held constant.

        This function also allows to scale the curvature along the curve, allowing for a
        "natural" unwrapping of general curves in 3D. We achieve this by scaling the
        "final" curvature along the beam and then evaluating the curve that follows this
        curvature (this would actually require to solve an ODE, but we avoid this by
        using a piecewise constant approximation).

        Args
        ----
        points_on_arc_length: list(float)
            The arc lengths along the curve centerline at which the curve is
            evaluated. The values do not have to be sorted.
        factor: float
            Factor to scale the curvature along the curve.
            factor == 1
                Use the default positions and the triads obtained via a smallest rotation mapping
            0 <factor < 1
                Integrate (piecewise constant as evaluated with get_relative_distance_and_rotations)
                the scaled curvature of the curve to obtain an intuitive unwrapping. (factor=0 gives
                a straight line)

        Return
        ----
        (positions, quaternions):
            One position and one quaternion per entry of points_on_arc_length.
        """

        # Split the requested points into the ones strictly inside the arc
        # length interval of the curve (interpolated) and the ones on or
        # outside its boundaries (extrapolated).
        points_on_arc_length = _np.asarray(points_on_arc_length)
        points_in_bounds = _np.logical_and(
            points_on_arc_length > self.point_arc_length[0],
            points_on_arc_length < self.point_arc_length[-1],
        )
        index_in_bound = _np.flatnonzero(points_in_bounds)
        index_out_of_bound = _np.flatnonzero(~points_in_bounds)
        # The curve boundaries are always evaluated, as they are the anchors
        # for the extrapolation.
        points_on_arc_length_in_bound = [
            self.point_arc_length[0],
            *points_on_arc_length[index_in_bound],
            self.point_arc_length[-1],
        ]

        if factor < (1.0 - _bme.eps_quaternion):
            coordinates = _np.zeros_like(self.coordinates)
            quaternions = _np.zeros_like(self.quaternions)
            coordinates[0] = self.coordinates[0]
            quaternions[0] = self.quaternions[0]
            for i_segment in range(self.n_points - 1):
                relative_distance_rotation = _quaternion.slerp_evaluate(
                    _quaternion.quaternion(1),
                    self.relative_distances_rotation[i_segment],
                    factor,
                )
                # In the initial configuration (factor=0) we get a straight curve, so we need
                # to use the arc length here. In the final configuration (factor=1) we want to
                # exactly recover the input points, so we need the piecewise linear distance.
                # Between them, we interpolate.
                relative_distance = (factor * self.relative_distances[i_segment]) + (
                    1.0 - factor
                ) * (
                    self.point_arc_length[i_segment + 1]
                    - self.point_arc_length[i_segment]
                )
                coordinates[i_segment + 1] = (
                    _quaternion.rotate_vectors(
                        quaternions[i_segment] * relative_distance_rotation,
                        [relative_distance, 0, 0],
                    )
                    + coordinates[i_segment]
                )
                quaternions[i_segment + 1] = quaternions[
                    i_segment
                ] * _quaternion.slerp_evaluate(
                    _quaternion.quaternion(1),
                    self.relative_rotations[i_segment],
                    factor,
                )
        else:
            coordinates = self.coordinates
            quaternions = self.quaternions

        sol_r = _np.zeros([len(points_on_arc_length_in_bound), 3])
        sol_q = _np.zeros(
            len(points_on_arc_length_in_bound), dtype=_quaternion.quaternion
        )
        arc_length_spline_interpolation = get_spline_interpolation(
            coordinates, self.point_arc_length
        )
        for i_point, centerline_arc_length in enumerate(points_on_arc_length_in_bound):
            if not (
                self.point_arc_length[0]
                <= centerline_arc_length
                <= self.point_arc_length[-1]
            ):
                raise ValueError("Centerline value out of bounds")

            # Index of the segment that contains the arc length, i.e. the last
            # point with an arc length not larger than the given one. The end
            # point of the curve is assigned to the last segment.
            n_points_not_greater = int(
                _np.searchsorted(
                    self.point_arc_length, centerline_arc_length, side="right"
                )
            )
            centerline_index = min(n_points_not_greater - 1, self.n_points - 2)

            # Get the two rotation vectors and arc length values
            arc_lengths = self.point_arc_length[
                centerline_index : centerline_index + 2
            ]
            q1 = quaternions[centerline_index]
            q2 = quaternions[centerline_index + 1]

            # Linear interpolate the arc length
            xi = (centerline_arc_length - arc_lengths[0]) / (
                arc_lengths[1] - arc_lengths[0]
            )

            # Perform a spline interpolation for the positions and a slerp
            # interpolation for the rotations
            sol_r[i_point] = arc_length_spline_interpolation(centerline_arc_length)
            sol_q[i_point] = _quaternion.slerp_evaluate(q1, q2, xi)

        # Set the already computed results in the final data structures. Entry
        # k + 1 of sol_r / sol_q belongs to index_in_bound[k] (the first and
        # last entries are the curve boundaries). The mapping is one-to-one,
        # so it also holds if the in-bound points are not contiguous in the
        # input (e.g. for unsorted input).
        sol_r_final = _np.zeros([len(points_on_arc_length), 3])
        sol_q_final = _np.zeros(len(points_on_arc_length), dtype=_quaternion.quaternion)
        n_in_bound = len(index_in_bound)
        if n_in_bound > 0:
            sol_r_final[index_in_bound] = sol_r[1 : n_in_bound + 1]
            sol_q_final[index_in_bound] = sol_q[1 : n_in_bound + 1]

        # Perform the extrapolation at both ends of the curve
        for i in index_out_of_bound:
            arc_length = points_on_arc_length[i]
            if arc_length <= self.point_arc_length[0]:
                index = 0
            elif arc_length >= self.point_arc_length[-1]:
                index = -1
            else:
                raise ValueError("Should not happen")

            # Linear extrapolation along the first local basis vector of the
            # boundary triad; the rotation is held constant.
            length = arc_length - self.point_arc_length[index]
            r = sol_r[index]
            q = sol_q[index]
            sol_r_final[i] = r + _Rotation.from_quaternion(q) * [length, 0, 0]
            sol_q_final[i] = q

        return sol_r_final, sol_q_final

    def project_point(self, p, t0=None) -> float:
        """Project a point to the curve, return the parameter coordinate for
        the projection point.

        A Newton iteration is used on f(t) = (r(t) - p) . r'(t), starting
        from t0 (default 0.0).
        """

        centerline_interpolation_p = self.centerline_interpolation.derivative(1)
        centerline_interpolation_pp = self.centerline_interpolation.derivative(2)

        def f(t):
            """Function to find the root of."""
            r = self.centerline_interpolation(t)
            rp = centerline_interpolation_p(t)
            return _np.dot(r - p, rp)

        def fp(t):
            """Derivative of the Function to find the root of."""
            r = self.centerline_interpolation(t)
            rp = centerline_interpolation_p(t)
            rpp = centerline_interpolation_pp(t)
            return _np.dot(rp, rp) + _np.dot(r - p, rpp)

        if t0 is None:
            t0 = 0.0

        return _optimize.newton(f, t0, fprime=fp)

    def get_pyvista_polyline(self, *, factor: float = 1.0) -> _pv.PolyData:
        """Create a pyvista (vtk) representation of the curve with the
        evaluated triad basis vectors.

        Args:
            factor: Factor to scale the curvature along the curve (see
                `get_centerline_positions_and_rotations` for details).

        Returns:
            A pyvista PolyData object representing the curve.
        """

        positions, rotations = self.get_centerline_positions_and_rotations(
            self.point_arc_length, factor=factor
        )

        poly_line = _pv.PolyData()
        poly_line.points = positions
        # VTK cell layout: number of points followed by the point indices.
        cell = _np.arange(0, self.n_points, dtype=int)
        cell = _np.insert(cell, 0, self.n_points)
        poly_line.lines = cell

        rotation_matrices = _quaternion.as_rotation_matrix(rotations)
        for i_dir in range(3):
            poly_line.point_data.set_array(
                rotation_matrices[:, :, i_dir], f"base_vector_{i_dir + 1}"
            )

        return poly_line

    def write_vtk(self, path) -> None:
        """Save a vtk representation of the curve."""
        self.get_pyvista_polyline().save(path)

    def write_pvd_series(
        self,
        pvd_path: _Path | str,
        *,
        factors: list[float] | None = None,
        n_steps: int | None = None,
        binary: bool = True,
    ) -> None:
        """Save a pvd series representing the curve at different states.

        Args:
            pvd_path: Path where to save the pvd file.
            factors: List of factors to scale the curvature along the curve. Mutually exclusive with 'n_steps'.
            n_steps: Number of steps to create a uniform series of factors. Mutually exclusive with 'factors'.
            binary: If True, save the vtk files in binary format.

        Raises:
            ValueError: If the path has no .pvd suffix, or not exactly one of
                'factors' and 'n_steps' is given.
        """

        pvd_path = _Path(pvd_path)
        if pvd_path.suffix != ".pvd":
            raise ValueError(
                f"The output path must have a .pvd suffix, got {pvd_path.suffix}"
            )

        if factors is not None and n_steps is not None:
            raise ValueError(
                "The keyword arguments 'factors' and 'n_steps' are mutually exclusive."
            )
        if factors is None and n_steps is None:
            raise ValueError(
                "One of the keyword arguments 'factors' or 'n_steps' must be provided."
            )
        if factors is None:
            factors = _np.linspace(0.0, 1.0, num=n_steps)

        pvd_file = _ET.Element("VTKFile", type="Collection", version="0.1")
        collection = _ET.SubElement(pvd_file, "Collection")
        # Zero-pad the step index so the files sort correctly.
        width = max(1, len(str(len(factors) - 1)))
        for i_step, factor in enumerate(factors):
            # TODO: Check if we can use vtp here instead of vtu. Currently this does
            # not work with how we compare files in testing. Since vtu and vtp are
            # basically the same in this case, this solution is fine at the moment.
            factor_file = pvd_path.parent / f"{pvd_path.stem}.{i_step:0{width}d}.vtu"
            _pv.UnstructuredGrid(self.get_pyvista_polyline(factor=factor)).save(
                factor_file, binary=binary
            )
            _ET.SubElement(
                collection,
                "DataSet",
                timestep=str(factor),
                group="",
                part="0",
                file=str(factor_file.relative_to(pvd_path.parent)),
            )

        tree = _ET.ElementTree(pvd_file)
        _ET.indent(tree, space="  ", level=0)
        tree.write(pvd_path, encoding="utf-8", xml_declaration=True)