PyTorch compatibility (experimental)
The `Wavefront` and `WavefrontK` classes generally use NumPy arrays. This notebook demonstrates that they also work with PyTorch tensors, both on the CPU and on a GPU.
In [ ]:
Copied!
import torch
import numpy as np
from pmd_beamphysics import Wavefront
from pmd_beamphysics.wavefront.propagators import drift_wavefront
import torch
import numpy as np
from pmd_beamphysics import Wavefront
from pmd_beamphysics.wavefront.propagators import drift_wavefront
In [ ]:
Copied!
# Pick an accelerator for the "GPU" example: CUDA if present, otherwise Apple's
# Metal backend ("mps"). Hard-coding "mps" would make this cell fail on every
# non-Apple machine. If neither is available we fall back to the CPU so the
# notebook still runs; the GPU-vs-CPU checks below then compare CPU to CPU.
if torch.cuda.is_available():
    gpu_device = "cuda"
elif torch.backends.mps.is_available():
    gpu_device = "mps"
else:
    gpu_device = "cpu"
    print("No CUDA or MPS device found; the 'GPU' wavefront will live on the CPU.")

# The same uniform complex64 field (255**3 elements, ~133 MB each) held three
# ways: as a NumPy array, as a torch tensor on the CPU, and as a torch tensor on
# the accelerator. complex64 is used throughout so all three are comparable.
shape = (255, 255, 255)
W_numpy = Wavefront(Ex=np.ones(shape, dtype=np.complex64))
W_torch_cpu = Wavefront(
    Ex=torch.ones(shape, device="cpu", dtype=torch.complex64),
)
W_torch_gpu = Wavefront(
    Ex=torch.ones(shape, device=gpu_device, dtype=torch.complex64),
)
In [ ]:
Copied!
%%time
W_numpy2 = drift_wavefront(W_numpy, 1)
%%time
W_numpy2 = drift_wavefront(W_numpy, 1)
In [ ]:
Copied!
%%time
W_torch_cpu2 = drift_wavefront(W_torch_cpu, 1, backend=torch)
%%time
W_torch_cpu2 = drift_wavefront(W_torch_cpu, 1, backend=torch)
In [ ]:
Copied!
%%time
W_torch_gpu2 = drift_wavefront(W_torch_gpu, 1, device="mps", backend=torch)
%%time
W_torch_gpu2 = drift_wavefront(W_torch_gpu, 1, device="mps", backend=torch)
In [ ]:
Copied!
# Accelerator result vs. CPU result, both produced by the torch backend.
# torch.allclose needs both tensors on the same device, so the accelerator
# tensor is first copied to host memory with .cpu().
torch.allclose(W_torch_gpu2.Ex.cpu(), W_torch_cpu2.Ex)
In [ ]:
Copied!
# Accelerator (torch) result vs. the pure-NumPy result. Copy the tensor to host
# memory and convert it to an ndarray explicitly, rather than relying on
# NumPy's implicit __array__ conversion of a torch tensor.
np.allclose(W_torch_gpu2.Ex.cpu().numpy(), W_numpy2.Ex)