SDP Verification
The sdp_verify directory contains a largely self-contained implementation of
the SDP-FO (first-order SDP verification) algorithm described in Dathathri et al.
2020. We encourage projects building off this code to fork this directory,
though contributions are also welcome!
The core solver is contained in sdp_verify.py. The main function is
dual_fun(verif_instance, dual_vars), which defines the dual upper bound from
Equation (5). For any feasible dual_vars this provides a valid bound. It is
written to be amenable to autodiff, so that jax.grad with respect to dual_vars
yields a valid subgradient.
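To illustrate why any feasible dual variable gives a valid bound, here is a minimal sketch on a toy one-variable QCQP (maximize x subject to x**2 <= 1). It is not the SDP-FO dual from Equation (5); the problem, the function name dual_bound, and the closed form are assumptions chosen so the bound can be checked by hand.

```python
def dual_bound(lam):
    """Closed-form dual bound for: maximize x subject to x**2 <= 1.

    The Lagrangian is L(x, lam) = x - lam * (x**2 - 1). For lam > 0 it is
    maximized at x = 1 / (2 * lam), which gives 1 / (4 * lam) + lam.
    """
    if lam <= 0:
        raise ValueError("lam must be positive for the bound to be finite")
    return 1.0 / (4.0 * lam) + lam


PRIMAL_OPTIMUM = 1.0  # the constrained maximum, attained at x = 1

# Every positive multiplier yields an upper bound on the primal optimum;
# only the best multiplier (lam = 0.5) makes the bound tight.
for lam in (0.1, 0.5, 2.0):
    bound = dual_bound(lam)
    print(f"lam={lam}: bound={bound:.3f}, valid={bound >= PRIMAL_OPTIMUM}")
```

Minimizing such a bound over the feasible dual variables is what tightens it, which is the role of the optimization loop.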
We also provide solve_sdp_dual_simple(verif_instance), which implements the
optimization loop (SDP-FO). It initializes the dual variables using our
proposed scheme and performs projected subgradient steps.
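The projected subgradient idea can be sketched on a toy problem. The following is not the library's loop (which uses an optax optimizer and the proposed initialization); the objective, the step-size schedule, and all function names are assumptions made for illustration. One coordinate is constrained to be non-negative, mimicking dual variables with sign constraints.

```python
import math


def objective(lam):
    """Nonsmooth convex toy objective standing in for a dual bound."""
    return abs(lam[0] - 1.0) + abs(lam[1] + 2.0)


def subgradient(lam):
    """Returns one valid subgradient of `objective` (0 is used at a kink)."""
    sign = lambda v: (v > 0) - (v < 0)
    return [sign(lam[0] - 1.0), sign(lam[1] + 2.0)]


def project(lam):
    """Projects onto the feasible set: lam[0] is free, lam[1] must be >= 0."""
    return [lam[0], max(lam[1], 0.0)]


def projected_subgradient(lam, num_steps=200, lr_init=0.5):
    """Runs subgradient steps, projecting after each one; tracks the best value."""
    best = objective(lam)
    for step in range(1, num_steps + 1):
        lr = lr_init / math.sqrt(step)  # diminishing step size
        grad = subgradient(lam)
        lam = project([l - lr * g for l, g in zip(lam, grad)])
        best = min(best, objective(lam))
    return lam, best


final_lam, best_bound = projected_subgradient([0.0, 1.0])
print(final_lam, best_bound)
```

Because every iterate is feasible after projection, every value visited is a valid bound, so keeping the best value seen is safe even though subgradient steps are not monotone.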
Both methods accept an SdpDualVerifInstance, which specifies (1) the
Lagrangian, (2) interval bounds on the primal variables, and (3) dual variable
shapes.
As described in the paper, the solver can be applied to other input/output
specifications or network architectures, provided the verification problem can
be written as a QCQP. This involves defining the corresponding QCQP Lagrangian
and creating a SdpDualVerifInstance. In examples/run_sdp_verify.py we include an
example for certifying adversarial L_inf robustness of a ReLU convolutional
network image classifier.
API Reference
-
jax_verify.sdp_verify.dual_fun(verif_instance, dual_vars, key=None, n_iter=30, scl=-1, exact=False, dynamic_unroll=True, include_info=False)
Returns the dual objective value.
- Parameters
verif_instance – a utils.SdpDualVerifInstance, the verification problem
dual_vars – A list of dual variables at each layer
key – PRNGKey passed to Lanczos
n_iter – Number of Lanczos iterations to use
scl – Inverse temperature in softmax over eigenvalues to smooth optimization problem (if negative treat as hardmax)
exact – Whether to use exact eigendecomposition instead of Lanczos
dynamic_unroll – bool. Whether to use jax.lax.fori_loop for Lanczos for faster JIT compilation. Default is True.
include_info – if True, also return an info dict of various other values computed for the objective
- Returns
Either a single float, the dual upper bound, or, if include_info=True, a pair: the dual bound and a dict containing debugging info
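The role of an inverse temperature such as scl can be illustrated with a generic softmax-weighted average, which approaches the hard max as the temperature parameter grows. How dual_fun applies scl to the eigenvalues internally is not stated here, so the function below (smoothed_max, a hypothetical name) is only a sketch of the general smoothing idea, with a negative value falling back to the hard max as the parameter description says.

```python
import math


def smoothed_max(values, scl):
    """Softmax-weighted average of `values`; hard max if `scl` is negative."""
    if scl < 0:
        return max(values)
    shift = max(values)  # subtract the max before exponentiating, for stability
    weights = [math.exp(scl * (v - shift)) for v in values]
    total = sum(weights)
    return sum(w * v for w, v in zip(weights, values)) / total


eigenvalues = [1.0, 2.0, 3.0]
hard = smoothed_max(eigenvalues, scl=-1)      # exact maximum
uniform = smoothed_max(eigenvalues, scl=0.0)  # plain average, fully smoothed
sharp = smoothed_max(eigenvalues, scl=50.0)   # close to the maximum
print(hard, uniform, sharp)
```

A smaller inverse temperature gives a smoother function of the eigenvalues at the cost of a looser approximation to the extreme value.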
-
jax_verify.sdp_verify.solve_sdp_dual(verif_instance, key=None, opt=None, num_steps=10000, verbose=False, eval_every=1000, use_exact_eig_eval=True, use_exact_eig_train=False, n_iter_lanczos=30, scl=-1.0, lr_init=0.001, steps_per_anneal=100, anneal_factor=1.0, num_anneals=3, opt_name='adam', gd_momentum=0.9, add_diagnostic_stats=False, opt_multiplier_fn=None, init_dual_vars=None, init_opt_state=None, opt_dual_vars=None, kappa_reg_weight=None, kappa_zero_after=None, device_type=None, save_best_k=1)
Compute verified lower bound via dual of SDP relaxation.
NOTE: This method exposes many hyperparameter options, and the method signature is subject to change. We suggest using solve_sdp_dual_simple instead if you need a stable interface.
-
jax_verify.sdp_verify.solve_sdp_dual_simple(verif_instance, key=None, opt=None, num_steps=10000, eval_every=1000, verbose=False, use_exact_eig_eval=True, use_exact_eig_train=False, n_iter_lanczos=100, kappa_reg_weight=None, kappa_zero_after=None, device_type=None)
Compute verified lower bound via dual of SDP relaxation.
- Parameters
verif_instance – a utils.SdpDualVerifInstance
key – jax.random.PRNGKey, used for Lanczos
opt – an optax.GradientTransformation instance, the optimizer. If None, defaults to Adam with learning rate 1e-3.
num_steps – int, the number of outer loop optimization steps
eval_every – int, frequency of running evaluation step
verbose – bool, enables verbose logging
use_exact_eig_eval – bool, whether to use exact eigendecomposition instead of Lanczos when computing evaluation loss
use_exact_eig_train – bool, whether to use exact eigendecomposition instead of Lanczos during training
n_iter_lanczos – int, number of Lanczos iterations
kappa_reg_weight – float, adds a penalty of sum(abs(kappa_{1:N})) to loss, which regularizes kappa_{1:N} towards zero. Default None is disabled.
kappa_zero_after – int, clamps kappa_{1:N} to zero after kappa_zero_after steps. Default None is disabled.
device_type – string, used to clamp to a particular hardware device. Default None uses JAX default device placement
- Returns
A pair. The first element is a float, the final dual loss, which forms a valid upper bound on the objective specified by verif_instance. The second element is a dict containing various debug info.
-
class jax_verify.sdp_verify.SdpDualVerifInstance(bounds, make_inner_lagrangian, dual_shapes, dual_types)
A namedtuple specifying a verification instance for the dual SDP solver.
- Fields:
bounds: A list of bounds on post-activations at each layer
make_inner_lagrangian: A function which takes dual_vars as input, and returns another function, the inner lagrangian, which evaluates Lagrangian(x, dual_vars) for any value x (the set of activations).
dual_types: A pytree matching dual_vars specifying which dual_vars should be non-negative.
dual_shapes: A pytree matching dual_vars specifying shape of each var.
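The way the four fields fit together can be sketched with a stand-in namedtuple that mirrors the documented field names. ToyVerifInstance is not the library class, and the concrete encodings used for bounds, dual_shapes, and dual_types (a list of interval pairs, a list of shape tuples, and a "nonneg" marker) are hypothetical, chosen only to show the closure pattern expected of make_inner_lagrangian.

```python
import collections

# Stand-in with the same field names as the documented namedtuple. This is
# NOT jax_verify's SdpDualVerifInstance, only an illustration of its fields.
ToyVerifInstance = collections.namedtuple(
    "ToyVerifInstance",
    ["bounds", "make_inner_lagrangian", "dual_shapes", "dual_types"],
)


def make_inner_lagrangian(dual_vars):
    """Takes dual_vars and returns a function that evaluates the Lagrangian at x."""
    (lam,) = dual_vars

    def inner_lagrangian(x):
        # Toy problem: objective x, constraint x**2 <= 1 with multiplier lam.
        return x - lam * (x ** 2 - 1.0)

    return inner_lagrangian


instance = ToyVerifInstance(
    bounds=[(-1.0, 1.0)],    # hypothetical encoding: one interval for x
    make_inner_lagrangian=make_inner_lagrangian,
    dual_shapes=[()],        # hypothetical encoding: one scalar dual variable
    dual_types=["nonneg"],   # hypothetical marker for a non-negative dual
)

# Fix the dual variables first, then evaluate the Lagrangian at a primal point.
lagrangian = instance.make_inner_lagrangian([0.5])
value = lagrangian(1.0)
print(value)
```

Separating the two steps lets a solver fix the dual variables once and then evaluate the Lagrangian at many primal points.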