|
import torch |
|
import data_utils as du |
|
|
|
def run_inference_two_channels(coil_complex_image, autoencoder, device="cuda"):
    """Encode and decode a complex coil image through a two-channel autoencoder.

    The image is normalized with ``du.normalize_complex_coil_image``, turned
    into a two-channel array with ``du.complex_to_two_channel_image``, given a
    leading batch axis, cast to float32 and moved to ``device``. The
    autoencoder encodes it, the ``mean`` of the returned ``latent_dist`` is
    used as the latent (nothing is sampled here), and the decoder's ``sample``
    output is converted back with ``du.two_channel_to_complex_image``.

    Args:
        coil_complex_image: Complex coil image accepted by
            ``du.normalize_complex_coil_image``.
        autoencoder: Model exposing ``to``, ``encode`` (result has
            ``latent_dist.mean``) and ``decode`` (result has ``sample``).
        device: Device the input tensor and the model are moved to.

    Returns:
        A ``(normalized_input, recon)`` tuple: the normalized complex image
        and the reconstruction returned by ``du.two_channel_to_complex_image``.
    """
    normalized = du.normalize_complex_coil_image(coil_complex_image)
    channels = du.complex_to_two_channel_image(normalized)

    # Add the batch axis, cast to float32 on the CPU, then transfer.
    batch = torch.from_numpy(channels).unsqueeze(0).float().to(device)

    # NOTE(review): the model is moved to `device` but never switched to eval
    # mode here -- confirm callers do that when it matters.
    autoencoder = autoencoder.to(device)

    with torch.no_grad():
        latent_mean = autoencoder.encode(batch).latent_dist.mean
        decoded = autoencoder.decode(latent_mean).sample

    # NOTE(review): the full decoded batch (leading batch axis included) is
    # handed to the helper, unlike the three-channel variant which takes
    # element 0 -- confirm du.two_channel_to_complex_image expects that.
    decoded_array = decoded.detach().cpu().numpy()
    recon = du.two_channel_to_complex_image(decoded_array)

    return normalized, recon
|
|
|
def run_inference_three_channels(coil_complex_image, autoencoder, device="cuda"):
    """Encode and decode a complex coil image through a three-channel autoencoder.

    The image is normalized with ``du.normalize_complex_coil_image`` and
    converted with ``du.create_three_channel_image``. The resulting array is
    given a leading batch axis, cast to float32 and moved to ``device``. The
    autoencoder encodes it, the ``mean`` of the returned ``latent_dist`` is
    used as the latent (nothing is sampled here), and the first element of the
    decoder's ``sample`` output is returned as a NumPy array.

    Args:
        coil_complex_image: Complex coil image accepted by
            ``du.normalize_complex_coil_image``.
        autoencoder: Model exposing ``to``, ``encode`` (result has
            ``latent_dist.mean``) and ``decode`` (result has ``sample``).
        device: Device the input tensor and the model are moved to.

    Returns:
        A ``(three_channel_image, recon)`` tuple: the array produced by
        ``du.create_three_channel_image`` (no batch axis) and the decoded
        first batch element as a NumPy array.
    """
    normalized = du.normalize_complex_coil_image(coil_complex_image)
    channels = du.create_three_channel_image(normalized)

    # Add the batch axis, cast to float32 on the CPU, then transfer.
    batch = torch.from_numpy(channels).unsqueeze(0).float().to(device)

    # NOTE(review): the model is moved to `device` but never switched to eval
    # mode here -- confirm callers do that when it matters.
    autoencoder = autoencoder.to(device)

    with torch.no_grad():
        latent_mean = autoencoder.encode(batch).latent_dist.mean
        decoded = autoencoder.decode(latent_mean).sample

    # Drop the batch axis again so the output lines up with the returned input.
    first_decoded = decoded[0]
    recon = first_decoded.detach().cpu().numpy()

    return channels, recon
|
|