Search quantum information protocols with LOCCNet
[1]:
import time
import torch
import math
import numpy as np
import matplotlib.pyplot as plt
import quairkit as qkit
from quairkit import to_state
from quairkit.application import OneWayLOCCNet
from quairkit.database import *
from quairkit.loss import *
from quairkit.operator import *
from quairkit.qinfo import *
qkit.set_dtype('complex128')
plt.rcParams['mathtext.fontset'] = 'cm'
plt.rcParams['font.family'] = 'STIXGeneral'
Introduction
This tutorial explores the application of LOCCNet (Local Operations and Classical Communication Network) in quantum information protocols, particularly in the context of distributed quantum information processing. We will focus on two fundamental tasks [1]:
Quantum State Discrimination: Distinguishing between different quantum states under the restriction of LOCC, specifically discriminating between a perfect Bell state and its noisy counterpart after passing through amplitude damping channels.
Quantum State Distillation: Purifying entangled states to obtain higher-fidelity entanglement using LOCC operations.
Throughout this tutorial, we will demonstrate the practical steps for implementing these protocols. This includes setting up a OneWayLOCCNet for multiple parties, configuring the initial quantum states, and building the parameterized quantum circuits. We will also cover the implementation of parameterized LOCC operations, the visualization of the resulting circuits, and the methods for training and optimizing the network.
Quantum State Discrimination
Problem Description
As demonstrated in [1], distinguishing multipartite quantum states under the restriction of local operations and classical communication (LOCC) is both practically important and theoretically challenging. In particular, the authors focus on discriminating between a perfect Bell state and its noisy counterpart after each qubit has passed through an amplitude damping channel.
Here, to explore the power of LOCCNet in state discrimination, we consider the two ideal Bell states:

\[ |\Phi^{\pm}\rangle = \frac{1}{\sqrt{2}}\left(|00\rangle \pm |11\rangle\right), \]

and an amplitude damping channel \(\mathcal{A}\) with noise parameter \(\gamma\), defined by:

\[ \mathcal{A}(\rho) = E_0 \rho E_0^{\dagger} + E_1 \rho E_1^{\dagger}, \]

where the Kraus operators are:

\[ E_0 = \begin{bmatrix} 1 & 0 \\ 0 & \sqrt{1-\gamma} \end{bmatrix}, \qquad E_1 = \begin{bmatrix} 0 & \sqrt{\gamma} \\ 0 & 0 \end{bmatrix}. \]

Sending each qubit of \(|\Phi^{-}\rangle\) through \(\mathcal{A}\) yields the mixed state:

\[ \Phi_1 = (\mathcal{A} \otimes \mathcal{A})\left(|\Phi^{-}\rangle\langle\Phi^{-}|\right), \]

while:

\[ \Phi_0 = |\Phi^{+}\rangle\langle\Phi^{+}| \]

remains noiseless. The goal is to distinguish between \(\Phi_0\) and \(\Phi_1\) via LOCC.
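To make the channel action concrete, here is a minimal sketch (plain torch and math only, independent of the QuAIRKit helper amplitude_damping_kraus used later) that builds these Kraus operators by hand, checks the completeness relation \(E_0^{\dagger}E_0 + E_1^{\dagger}E_1 = I\), and applies \(\mathcal{A}\otimes\mathcal{A}\) to \(|\Phi^{-}\rangle\) to obtain \(\Phi_1\); the noise strength below is only an example value:
[ ]:
import math
import torch

gamma = 0.3  # example noise strength
E0 = torch.tensor([[1, 0], [0, math.sqrt(1 - gamma)]], dtype=torch.cdouble)
E1 = torch.tensor([[0, math.sqrt(gamma)], [0, 0]], dtype=torch.cdouble)

# Completeness relation of a quantum channel: sum_k E_k^dag E_k = I
assert torch.allclose(E0.conj().T @ E0 + E1.conj().T @ E1, torch.eye(2, dtype=torch.cdouble))

# |Phi^-> = (|00> - |11>) / sqrt(2) and its density matrix
phi_minus = torch.tensor([1, 0, 0, -1], dtype=torch.cdouble) / math.sqrt(2)
rho_phi_minus = torch.outer(phi_minus, phi_minus.conj())

# Phi_1 = (A ⊗ A)(|Phi^-><Phi^-|), built from the two-qubit Kraus operators E_i ⊗ E_j
Phi_1 = sum(
    torch.kron(Ei, Ej) @ rho_phi_minus @ torch.kron(Ei, Ej).conj().T
    for Ei in (E0, E1) for Ej in (E0, E1)
)
print(Phi_1.diagonal().real)  # populations of the damped state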
The ansatz used for finding QSD protocols with LOCCNet works as follows: the input state is randomly chosen from the set \(\{ \Phi_0, \Phi_1 \}\) with equal probability. Alice applies a unitary gate to her qubit and measures it. Bob then applies to his qubit a unitary gate chosen according to Alice's measurement result, and Bob's measurement outcome indicates which of the two states was prepared.
Implementation with LOCCNet
Now we create the OneWayLOCCNet structure. We have two spatially separated parties:
Alice: Has 1 qubit, implements local operations, and performs the first measurement
Bob: Has 1 qubit and performs adaptive operations based on Alice’s measurement
The protocol works as follows:
Alice applies a parameterized unitary gate on her qubit
Alice measures her qubit
Bob receives the classical measurement result
Bob applies a conditional parameterized unitary based on Alice’s outcome
[2]:
def create_locc_network() -> OneWayLOCCNet:
    r"""
    Create a one-way LOCC qubit-network for quantum state discrimination.

    Returns:
        An initialized LOCC network
    """
    party_info = {'Alice': 1, 'Bob': 1}
    net = OneWayLOCCNet(party_info)

    # Alice applies a parameterized single-qubit (u3) gate on her qubit
    net['Alice'].u3([0])

    # Alice's qubit is then measured; Bob applies a u3 gate (3 parameters)
    # conditioned on Alice's classical outcome
    net.param_locc(u3, 3, [('Alice', 0), ('Bob', 0)], label='M', support_batch=False)
    return net
An example circuit diagram is shown below:
[3]:
create_locc_network().physical_circuit.plot()
Important Note: In OneWayLOCC, once Alice has measured her first qubit, it cannot be measured again. This restriction ensures we maintain the one-way nature of the classical communication. The following circuit configuration would be forbidden:

Forbidden Circuit.
Next, we define the quantum states to be discriminated using batch processing:
[4]:
zero2 = torch.tensor([1, 0], dtype=torch.cdouble)
one2 = torch.tensor([0, 1], dtype=torch.cdouble)
basis00 = torch.kron(zero2, zero2)
basis11 = torch.kron(one2, one2)
psi_minus = (basis00 - basis11) / math.sqrt(2)
rho_minus = psi_minus.unsqueeze(1) @ psi_minus.conj().unsqueeze(0)
def generate_single_pair(gamma: float) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Generate a pair of quantum states: Bell state and amplitude-damped state.

    Args:
        gamma: Amplitude damping parameter (0 <= gamma <= 1)

    Returns:
        tuple: (Bell state density matrix, amplitude-damped state density matrix)
    """
    state_minus = to_state(rho_minus)
    kraus_ops = amplitude_damping_kraus(gamma)
    state_minus = state_minus.transform(kraus_ops, sys_idx=[0], repr_type="kraus")
    state_minus = state_minus.transform(kraus_ops, sys_idx=[1], repr_type="kraus")
    rho_bell = bell_state(2).density_matrix
    rho_damp = state_minus.density_matrix
    return rho_bell, rho_damp


def generate_initial_states_locc(gamma: float) -> qkit.State:
    r"""Generate batch of initial states for LOCC network training.

    Args:
        gamma: Amplitude damping parameter

    Returns:
        Batch state containing Bell and damped states
    """
    state_bell, state_damp = generate_single_pair(gamma)
    batch = torch.stack([state_bell, state_damp])
    return to_state(batch)
The loss function of this state discrimination protocol is defined as:

\[ L = \frac{1}{2}\,P(1 \mid \Phi_0) + \frac{1}{2}\,P(0 \mid \Phi_1), \]

where \(P(j\mid \Phi_k)\) is the probability that Bob’s final measurement yields outcome \(j\) given input state \(\Phi_k\). This loss function represents the total error probability in the discrimination task:
\(P(1 \mid \Phi_0)\): Probability of incorrectly identifying the Bell state \(\Phi_0\) as the noisy state
\(P(0 \mid \Phi_1)\): Probability of incorrectly identifying the noisy state \(\Phi_1\) as the Bell state
The implementation works as follows:
Set the initial shared quantum states (\(\Phi_0\) and \(\Phi_1\)) using net.set_init_state. States that are not set will default to the zero state.
Execute the LOCC protocol where Alice measures first
Bob performs conditional operations based on Alice’s outcome
Measure Bob’s qubit to determine the state
Calculate the weighted measurement results across all branches (Alice’s measurement outcomes)
Minimize the loss function to maximize discrimination success probability
We first marginalize the measurement probabilities over all possible branches from Alice’s measurement:
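In other words, for each input state \(\Phi_k\), Bob's outcome distribution is obtained by weighting his conditional outcome probabilities with the probabilities of Alice's branches. This is what the einsum in the helper below computes, assuming the index convention stated in its docstring:

\[ P(j \mid \Phi_k) = \sum_{i} P_{\mathrm{Alice}}(i \mid \Phi_k)\, P_{\mathrm{Bob}}(j \mid i, \Phi_k). \]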
[5]:
def compute_marginalized_probabilities(
    prob_tensor: torch.Tensor, probs_batch: torch.Tensor
) -> torch.Tensor:
    r"""Compute marginalized and normalized measurement probabilities.

    Args:
        prob_tensor: Probability tensor with shape (batch_size, ..., 2)
        probs_batch: Conditional probabilities with same shape structure

    Returns:
        Normalized probabilities with shape (batch_size, 2)
    """
    branch_weights = prob_tensor.sum(dim=2)
    marginalized_probs = torch.einsum('bi,bij->bj', branch_weights, probs_batch)
    norm_sum = marginalized_probs.sum(dim=1, keepdim=True)
    return marginalized_probs / norm_sum
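As a quick sanity check of the tensor shapes, the helper can be called with random tensors. The shapes below are assumptions following the docstring (2 input states, 2 branches from Alice's measurement, 2 outcomes for Bob), not values produced by the protocol itself:
[ ]:
dummy_prob_tensor = torch.rand(2, 2, 2)   # (batch, branch, outcome) weights
dummy_probs_batch = torch.rand(2, 2, 2)   # conditional outcome probabilities per branch
toy = compute_marginalized_probabilities(dummy_prob_tensor, dummy_probs_batch)
print(toy.shape)       # torch.Size([2, 2])
print(toy.sum(dim=1))  # each row is normalized to 1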
To discriminate between the two quantum states, we classify them based on Bob’s measurement outcome: an outcome of 0 corresponds to the \(\Phi_0\) state (Bell state), and an outcome of 1 corresponds to the \(\Phi_1\) state (damped state).
Our objective is to minimize the loss, defined above as the average probability of misidentifying the state.
[6]:
def compute_discrimination_loss(normalized_probs: torch.Tensor) -> torch.Tensor:
    r"""Compute the average classification probability for state discrimination.

    Based on Bob's measurement outcome: 0 → Bell state, 1 → damped state.
    The function computes the average correct classification probability weighted by
    the prior distribution (0.5 each).

    Args:
        normalized_probs: Shape (2, 2) tensor where normalized_probs[i, j] = P(outcome j | state i)

    Returns:
        Weighted classification score
    """
    p0 = normalized_probs[0].real
    p1 = normalized_probs[1].real
    F0_bell = p0[0]
    T0_bell = p0[1]
    eps = 1e-10
    p0_prob = F0_bell / (F0_bell + T0_bell + eps)
    F1_damped = p1[1]
    T1_damped = p1[0]
    p1_prob = F1_damped / (F1_damped + T1_damped + eps)
    return 0.5 * p0_prob + 0.5 * p1_prob
The main loss function combines the above components:
[7]:
meas = Measure('z')


def loss_func_locc(net: OneWayLOCCNet) -> torch.Tensor:
    r"""Compute loss function for LOCC network training.

    Args:
        net: The LOCC network (its initial states are set via net.set_init_state)

    Returns:
        Loss value used for optimization
    """
    output_state = net()
    probs_batch, output_state = meas(output_state, qubits_idx=[0], keep_state=True)
    prob_tensor = output_state.probability
    normalized_probs = compute_marginalized_probabilities(prob_tensor, probs_batch)
    return compute_discrimination_loss(normalized_probs)
Now we train the network and analyze how the discrimination success probability varies with the amplitude damping parameter \(\gamma\):
[8]:
def train_locc_network(num_itr: int, lr: float, gamma: float) -> tuple[float, OneWayLOCCNet]:
    r"""
    Train the LOCC network for state discrimination.

    Args:
        num_itr: Number of training iterations
        lr: Learning rate
        gamma: Amplitude damping parameter

    Returns:
        tuple: (best loss value, trained network)
    """
    net = create_locc_network()
    inputs = generate_initial_states_locc(gamma)
    net.set_init_state([('Alice', 0), ('Bob', 0)], inputs)

    time_list = []
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5)
    best_loss = float('inf')
    best_state = None

    for itr in range(num_itr):
        start_time = time.time()
        optimizer.zero_grad()
        loss = loss_func_locc(net)
        loss.backward()
        optimizer.step()
        loss = loss.item()
        scheduler.step(loss)
        time_list.append(time.time() - start_time)

        if loss < best_loss:
            best_loss = loss
            best_state = {k: v.cpu().clone() for k, v in net.state_dict().items()}

        if itr % 100 == 0 or itr == num_itr - 1:
            print(f"iter: {str(itr).zfill(len(str(num_itr)))}, " +
                  f"loss: {loss:.4f}, best = {best_loss:.4f}, " +
                  f"lr: {scheduler.get_last_lr()[0]:.2E}, avg_time: {np.mean(time_list):.4f}s")

    if best_state:
        net.load_state_dict(best_state)
    return best_loss, net
[9]:
NUM_ITR = 200
LR = 0.15
START_POINT = 0
gamma_vals_full = np.linspace(0.0, 1.0, 11)
gamma_vals = gamma_vals_full[START_POINT:]
results = []
print(f"Training LOCC network for gamma values: {gamma_vals}\n")
for idx, gamma in enumerate(gamma_vals):
    print(f"\nTraining gamma = {gamma:.2f} ({idx+1}/{len(gamma_vals)})")
    best_loss, trained_net = train_locc_network(NUM_ITR, LR, gamma)
    sp = 1 - best_loss
    results.append(sp)
    print(f"Result: success_prob = {sp:.6f}\n")

print("\nFinal Results:")
for gamma, sp in zip(gamma_vals, results):
    print(f"gamma = {gamma:.2f}, success_prob = {sp:.6f}")
Training LOCC network for gamma values: [0. 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1. ]
Training gamma = 0.00 (1/11)
iter: 000, loss: 0.7192, best = 0.7192, lr: 1.50E-01, avg_time: 0.0256s
iter: 100, loss: 0.0000, best = 0.0000, lr: 3.75E-02, avg_time: 0.0085s
iter: 199, loss: 0.0000, best = 0.0000, lr: 9.37E-03, avg_time: 0.0100s
Result: success_prob = 1.000000
Training gamma = 0.10 (2/11)
iter: 000, loss: 0.5003, best = 0.5003, lr: 1.50E-01, avg_time: 0.0117s
iter: 100, loss: 0.0243, best = 0.0243, lr: 4.69E-03, avg_time: 0.0117s
iter: 199, loss: 0.0243, best = 0.0243, lr: 9.16E-06, avg_time: 0.0119s
Result: success_prob = 0.975654
Training gamma = 0.20 (3/11)
iter: 000, loss: 0.5029, best = 0.5029, lr: 1.50E-01, avg_time: 0.0126s
iter: 100, loss: 0.0472, best = 0.0472, lr: 9.37E-03, avg_time: 0.0105s
iter: 199, loss: 0.0472, best = 0.0472, lr: 3.66E-05, avg_time: 0.0106s
Result: success_prob = 0.952769
Training gamma = 0.30 (4/11)
iter: 000, loss: 0.6282, best = 0.6282, lr: 1.50E-01, avg_time: 0.0119s
iter: 100, loss: 0.0684, best = 0.0684, lr: 1.87E-02, avg_time: 0.0112s
iter: 199, loss: 0.0684, best = 0.0684, lr: 7.32E-05, avg_time: 0.0113s
Result: success_prob = 0.931567
Training gamma = 0.40 (5/11)
iter: 000, loss: 0.6896, best = 0.6896, lr: 1.50E-01, avg_time: 0.0133s
iter: 100, loss: 0.0877, best = 0.0877, lr: 7.50E-02, avg_time: 0.0120s
iter: 199, loss: 0.0877, best = 0.0877, lr: 1.46E-04, avg_time: 0.0116s
Result: success_prob = 0.912311
Training gamma = 0.50 (6/11)
iter: 000, loss: 0.3447, best = 0.3447, lr: 1.50E-01, avg_time: 0.0121s
iter: 100, loss: 0.1047, best = 0.1047, lr: 4.69E-03, avg_time: 0.0108s
iter: 199, loss: 0.1047, best = 0.1047, lr: 9.16E-06, avg_time: 0.0104s
Result: success_prob = 0.895285
Training gamma = 0.60 (7/11)
iter: 000, loss: 0.6612, best = 0.6612, lr: 1.50E-01, avg_time: 0.0147s
iter: 100, loss: 0.1192, best = 0.1192, lr: 1.87E-02, avg_time: 0.0100s
iter: 199, loss: 0.1192, best = 0.1192, lr: 3.66E-05, avg_time: 0.0105s
Result: success_prob = 0.880788
Training gamma = 0.70 (8/11)
iter: 000, loss: 0.6391, best = 0.6391, lr: 1.50E-01, avg_time: 0.0128s
iter: 100, loss: 0.1309, best = 0.1309, lr: 9.37E-03, avg_time: 0.0109s
iter: 199, loss: 0.1309, best = 0.1309, lr: 1.83E-05, avg_time: 0.0104s
Result: success_prob = 0.869120
Training gamma = 0.80 (9/11)
iter: 000, loss: 0.6432, best = 0.6432, lr: 1.50E-01, avg_time: 0.0098s
iter: 100, loss: 0.1395, best = 0.1395, lr: 1.50E-01, avg_time: 0.0100s
iter: 199, loss: 0.1394, best = 0.1394, lr: 2.93E-04, avg_time: 0.0098s
Result: success_prob = 0.860555
Training gamma = 0.90 (10/11)
iter: 000, loss: 0.5454, best = 0.5454, lr: 1.50E-01, avg_time: 0.0114s
iter: 100, loss: 0.1447, best = 0.1447, lr: 9.37E-03, avg_time: 0.0096s
iter: 199, loss: 0.1447, best = 0.1447, lr: 1.83E-05, avg_time: 0.0096s
Result: success_prob = 0.855316
Training gamma = 1.00 (11/11)
iter: 000, loss: 0.4941, best = 0.4941, lr: 1.50E-01, avg_time: 0.0120s
iter: 100, loss: 0.1464, best = 0.1464, lr: 1.87E-02, avg_time: 0.0097s
iter: 199, loss: 0.1464, best = 0.1464, lr: 3.66E-05, avg_time: 0.0104s
Result: success_prob = 0.853553
Final Results:
gamma = 0.00, success_prob = 1.000000
gamma = 0.10, success_prob = 0.975654
gamma = 0.20, success_prob = 0.952769
gamma = 0.30, success_prob = 0.931567
gamma = 0.40, success_prob = 0.912311
gamma = 0.50, success_prob = 0.895285
gamma = 0.60, success_prob = 0.880788
gamma = 0.70, success_prob = 0.869120
gamma = 0.80, success_prob = 0.860555
gamma = 0.90, success_prob = 0.855316
gamma = 1.00, success_prob = 0.853553
[10]:
plt.figure(figsize=(6, 4))
plt.plot(gamma_vals, results, marker='o')
plt.xlabel(r'$\gamma$')
plt.ylabel('Success Probability (1 - loss)')
plt.title(f'Success Probability vs AD channel parameter $\\gamma$ (LOCCNet)')
plt.grid(True)
plt.show()
As shown in the figure above, as the noise parameter \(\gamma\) increases, the success probability of distinguishing between the two states \(\Phi_0\) and \(\Phi_1\) decreases, indicating that they become harder to distinguish.
Quantum State Distillation
Problem Setup
Bell state distillation (also known as entanglement distillation) is a crucial task in quantum information processing. The goal is to extract high-fidelity entangled pairs from a larger number of noisy entangled pairs using only LOCC operations.
In this task, Alice and Bob share 4 noisy Bell states and they apply local operations and measurements to distill fewer but higher-quality entangled pairs, using a combination of single-qubit variational gates and CNOT gates.
Implementation
Our implementation uses OneWayLOCCNet with the following structure:
Alice and Bob: Each has 4 qubits
Circuit architecture: Two layers of single-qubit variational gates sandwich a layer of cycled CNOT gates
Optimization goal: Maximize the fidelity of the distilled entangled pairs
[11]:
net_purify = OneWayLOCCNet({'Alice': 4, 'Bob': 4})
net_purify['Alice'].u3([0, 1, 2, 3])
net_purify['Alice'].cnot('cycle')
net_purify['Alice'].u3([0, 1, 2, 3])
net_purify['Bob'].u3([0, 1, 2, 3])
net_purify['Bob'].cnot('cycle')
net_purify['Bob'].u3([0, 1, 2, 3])
Purification Protocol Details
The states we want to purify are S states, which are probabilistic mixtures of Bell states and \(|00\rangle\) states:

\[ \rho_{S} = p\,|\Phi^{+}\rangle\langle\Phi^{+}| + (1-p)\,|00\rangle\langle 00|, \qquad 0 \le p \le 1. \]
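As an illustration (a plain-torch sketch, not part of the protocol code below), the S state for a given \(p\) can be written down explicitly; its overlap with \(|\Phi^{+}\rangle\) already gives the baseline fidelity discussed at the end of this section:
[ ]:
import math
import torch

p = 0.7
phi_plus = torch.tensor([1, 0, 0, 1], dtype=torch.cdouble) / math.sqrt(2)
ket_00 = torch.tensor([1, 0, 0, 0], dtype=torch.cdouble)

# rho_S = p |Phi^+><Phi^+| + (1 - p) |00><00|
rho_S = p * torch.outer(phi_plus, phi_plus.conj()) + (1 - p) * torch.outer(ket_00, ket_00.conj())

# Overlap <Phi^+| rho_S |Phi^+> = (1 + p) / 2 = 0.85 for p = 0.7
print((phi_plus.conj() @ rho_S @ phi_plus).real)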
The protocol works as follows:
Use net.set_init_state to set the initial state for each pair of qubits shared between Alice and Bob.
Measure all qubits except the first qubit of Alice and Bob.
Compare the measurement results for the three measured qubit pairs (pairs 1, 2, and 3). We consider the distillation successful if, for each of these three pairs, Alice’s measurement outcome is the same as Bob’s (i.e., \(m_{A_1} = m_{B_1}\), \(m_{A_2} = m_{B_2}\), and \(m_{A_3} = m_{B_3}\)).
If the measurement outcomes for any pair do not match, the protocol fails for this round, and the pair is discarded.
If successful, use partial trace to discard the measured states (pairs 1, 2, and 3).
The remaining state (on the first pair of qubits, pair 0) is the distilled state.
Our loss function is \(1 - F\), where \(F\) is the state fidelity between the distilled state (after partial trace) and the target EPR state \(|\Phi^{+}\rangle = \frac{|00\rangle + |11\rangle}{\sqrt{2}}\).
[12]:
meas = Measure('zzzzzz')
target_state = bell_state(2)
measurement_qubits = [1, 5, 2, 6, 3, 7]
successful_outcomes = [
'000000', '000011', '001100', '110000',
'001111', '111100', '110011', '111111'
]
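# The successful outcomes above are the 6-bit strings in which each of the three
# measured pairs -- (qubit 1, qubit 5), (qubit 2, qubit 6), (qubit 3, qubit 7),
# read in the order given by `measurement_qubits` (an assumption about the
# bit-string ordering) -- gives matching results, i.e. '00' or '11'.
# A quick illustrative cross-check:
from itertools import product
_generated = [''.join(triple) for triple in product(['00', '11'], repeat=3)]
assert sorted(_generated) == sorted(successful_outcomes)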
def purify(net: OneWayLOCCNet, share_entangle: bool) -> torch.Tensor:
    r"""Perform the entanglement purification protocol.

    Args:
        net: The purification LOCC network
        share_entangle: Whether to share entangled pairs initially; otherwise
            the qubits default to the zero state

    Returns:
        Loss value 1 - F, where F is the mean fidelity between the distilled
        state and the target Bell state
    """
    net._state_map.clear()
    if share_entangle:
        for i in range(4):
            net.set_init_state([('Alice', i), ('Bob', i)])

    status = net()
    _, post_state = meas(status, qubits_idx=measurement_qubits,
                         desired_result=successful_outcomes, keep_state=True)
    reduced_state = partial_trace(post_state, measurement_qubits)
    return 1 - state_fidelity(reduced_state, target_state).mean()
Now we define the S state parameter \(p = 0.7\) and begin training:
[13]:
NUM_ITR = 100
LR = 0.2
p = 0.7
[14]:
time_list = []
opt = torch.optim.Adam(net_purify.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', factor=0.5)
for itr in range(NUM_ITR):
    start_time = time.time()
    opt.zero_grad()

    # Weight the two branches by the S-state mixing probabilities p and 1 - p
    loss_bell = purify(net_purify, share_entangle=True)
    loss_00 = purify(net_purify, share_entangle=False)
    loss = p * loss_bell + (1 - p) * loss_00

    loss.backward()
    opt.step()
    loss = loss.item()
    scheduler.step(loss)
    time_list.append(time.time() - start_time)

    if itr % 20 == 0 or itr == NUM_ITR - 1:
        print(f"iter: {str(itr).zfill(len(str(NUM_ITR)))}, " +
              f"loss: {loss:.4f}, " +
              f"lr: {scheduler.get_last_lr()[0]:.2E}, avg_time: {np.mean(time_list):.4f}s")
final_loss_bell = purify(net_purify, share_entangle=True)
final_loss_00 = purify(net_purify, share_entangle=False)
final_loss = p * final_loss_bell + (1-p) * final_loss_00
final_fidelity = 1 - final_loss.item()
print(f"\nTraining Complete! Final Fidelity: {final_fidelity:.5f}")
iter: 000, loss: 0.5949, lr: 2.00E-01, avg_time: 0.0775s
iter: 020, loss: 0.0977, lr: 2.00E-01, avg_time: 0.0547s
iter: 040, loss: 0.0893, lr: 2.00E-01, avg_time: 0.0544s
iter: 060, loss: 0.0880, lr: 2.00E-01, avg_time: 0.0536s
iter: 080, loss: 0.0879, lr: 2.00E-01, avg_time: 0.0538s
iter: 099, loss: 0.0879, lr: 2.00E-01, avg_time: 0.0533s
Training Complete! Final Fidelity: 0.91213
For comparison, we establish the baseline fidelity by calculating the fidelity between the initial S state \(\rho_s\) and the target Bell state \(\rho_{target} = |\Phi^+\rangle\langle\Phi^+|\).
For a mixed state \(\rho_s\) and a pure target state \(\rho_{target}\), the fidelity simplifies to the overlap:

\[ F(\rho_s, \rho_{target}) = \langle \Phi^{+} | \rho_s | \Phi^{+} \rangle. \]

Substituting the S state yields the simplified formula:

\[ F = p\,|\langle \Phi^{+} | \Phi^{+} \rangle|^{2} + (1-p)\,|\langle \Phi^{+} | 00 \rangle|^{2} = p + \frac{1-p}{2} = \frac{1+p}{2}. \]

For our initial state with \(p = 0.7\), the baseline fidelity is:

\[ F_{\text{baseline}} = \frac{1 + 0.7}{2} = 0.85. \]

Our resulting distilled fidelity of about 0.912 is higher than this 0.85 baseline, demonstrating that the LOCCNet protocol successfully purified the entanglement.
The experimental circuit diagram is shown below:
[15]:
net_purify.physical_circuit.plot()
References
[1] Zhao, Xuanqiang, et al. “Practical distributed quantum information processing with LOCCNet.” npj Quantum Information 7.1 (2021): 159.
Notation Reference
Table: A reference of notation conventions in this tutorial.

| Symbol | Description |
|---|---|
| \(\|\Phi^{+}\rangle\) | Bell state (maximally entangled) |
| \(\|\Phi^{-}\rangle\) | Bell state with phase flip |
| \(\|00\rangle\) | Product state of two qubits in zero state |
| \(\mathcal{A}\) | Amplitude damping channel |
| \(\gamma\) | Noise parameter for amplitude damping |
| \(E_0, E_1\) | Kraus operators for the amplitude damping channel |
| \(\rho\) | Density matrix |
| \(\Phi_0\) | Noiseless Bell state (to be discriminated) |
| \(\Phi_1\) | Noisy state after amplitude damping (to be discriminated) |
| \(L\) | Loss function |
| \(P(j\|\Phi_k)\) | Conditional probability of outcome \(j\) given state \(\Phi_k\) |
| \(\rho_s\) | S state (mixed state for purification) |
| \(p\) | Mixing parameter in S state |
| \(F\) | State fidelity |
[16]:
qkit.print_info()
---------VERSION---------
quairkit: 0.4.4
torch: 2.8.0+cpu
numpy: 2.2.6
scipy: 1.15.3
matplotlib: 3.10.6
---------SYSTEM---------
Python version: 3.10.18
OS: Windows
OS version: 10.0.26100
---------DEVICE---------
CPU: ARMv8 (64-bit) Family 8 Model 1 Revision 201, Qualcomm Technologies Inc