Not even close. jax.jit lets you compute almost anything using lax.fori_loop, lax.cond, and other lax and jax constructs. PyTorch's JIT doesn't allow that; it's just an extra optimization for static PyTorch functions.
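For a concrete picture, here's a minimal sketch of data-dependent control flow under jax.jit. The function name `collatz_steps` and the Collatz update are just illustrative choices, not anything from the thread:

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def collatz_steps(n, max_iters):
    # Hypothetical example: apply `max_iters` Collatz updates to `n`.
    # Both arguments are traced, so the branch and the trip count
    # are only known at run time.
    def body(_, val):
        # lax.cond picks a branch based on a traced boolean.
        return lax.cond(
            val % 2 == 0,
            lambda v: v // 2,      # even: halve
            lambda v: 3 * v + 1,   # odd: 3n + 1
            val,
        )

    # lax.fori_loop runs `body` for i in [0, max_iters).
    return lax.fori_loop(0, max_iters, body, n)


result = collatz_steps(jnp.int32(6), 3)  # 6 -> 3 -> 10 -> 5
```

Because the loop bound is a traced value here, the same compiled function can be reused for different `max_iters` without retracing.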
JAX autograd will work on almost any jitted fn. The control-flow limitation is that reverse-mode autograd isn't supported through lax.while_loop, or through lax.fori_loop when the bounds aren't static, since there's a statically unknowable trip count through the loop body (forward-mode still works there). Much looping code can be handled differentiably using lax.scan, though.
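As a rough sketch of the scan route, assuming a fixed number of steps is acceptable: the function `power_sum` below is a made-up example that accumulates x + x^2 + ... + x^n with lax.scan, and jax.grad goes straight through it.

```python
import jax
import jax.numpy as jnp
from jax import lax


def power_sum(x, n_steps=5):
    # Hypothetical example: x + x**2 + ... + x**n_steps via lax.scan.
    # n_steps is a plain Python int, so the trip count is static.
    def step(carry, _):
        term, total = carry
        term = term * x          # next power of x
        total = total + term     # running sum
        return (term, total), None

    init = (jnp.ones_like(x), jnp.zeros_like(x))
    (_, total), _ = lax.scan(step, init, xs=None, length=n_steps)
    return total


# Reverse-mode gradient of the scanned loop, compiled with jit.
grad_fn = jax.jit(jax.grad(power_sum))
g = grad_fn(jnp.float32(2.0))  # 1 + 2*2 + 3*4 + 4*8 + 5*16 = 129
```

The trade-off is that scan needs the length fixed at trace time; a loop that genuinely stops on a data-dependent condition would have to be padded out to a maximum length and masked.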