def jacfwd(
    func: Callable,
    argnums: argnums_t = 0,
    has_aux: bool = False,
    *,
    randomness: str = "error",
    chunk_size=None
):
    """
    Wrapper around torch._functorch.jacfwd that accepts chunk size.
    See the original torch documentation for more details on the arguments.
    **Arguments:**
    - `func`: A Python function that takes one or more arguments, one of which
        must be a Tensor, and returns one or more Tensors
    - `argnums`: An integer or a tuple of integers specifying which positional
        argument(s) to differentiate with respect to.
    - `has_aux`: If `True`, `func` is assumed to return a pair where the first
        element is considered the output of the original function to be
        differentiated and the second element is auxiliary data.
    - `randomness`: A string specifying how to handle randomness in `func`.
        Valid values are “different”, “same”, “error”.
    - `chunk_size`: An integer specifying the chunk size for vmap.
    """
    @wraps(func)
    def wrapper_fn(*args):
        primals = args if argnums is None else _slice_argnums(args, argnums)
        flat_primals, primals_spec = tree_flatten(primals)
        flat_primals_numels = tuple(p.numel() for p in flat_primals)
        flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
        basis = tree_unflatten(flat_basis, primals_spec)
        def push_jvp(basis):
            output = _jvp_with_argnums(
                func, args, basis, argnums=argnums, has_aux=has_aux
            )
            # output[0] is the output of `func(*args)`
            if has_aux:
                _, jvp_out, aux = output
                return jvp_out, aux
            _, jvp_out = output
            return jvp_out
        results = vmap(push_jvp, randomness=randomness, chunk_size=chunk_size)(basis)
        if has_aux:
            results, aux = results
            # aux is in the standard basis format, e.g. NxN matrix
            # We need to fetch the first element as original `func` output
            flat_aux, aux_spec = tree_flatten(aux)
            flat_aux = [value[0] for value in flat_aux]
            aux = tree_unflatten(flat_aux, aux_spec)
        jac_outs, spec = tree_flatten(results)
        # Most probably below output check can never raise an error
        # as jvp should test the output before
        # assert_non_empty_output(jac_outs, 'jacfwd(f, ...)(*args)')
        jac_outs_ins = tuple(
            tuple(
                safe_unflatten(jac_out_in, -1, primal.shape)
                for primal, jac_out_in in zip(
                    flat_primals,
                    jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1),
                )
            )
            for jac_out in jac_outs
        )
        jac_outs_ins = tuple(
            tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins
        )
        if isinstance(argnums, int):
            jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins)
        if has_aux:
            return tree_unflatten(jac_outs_ins, spec), aux
        return tree_unflatten(jac_outs_ins, spec)
    return wrapper_fn