JWT Debugging

· 791 words · 4 minute read

I need to remember not to stop at the following:

claims := jwt.Claims{}
if err := token.Claims(&claims); err != nil {
	panic(err)
}
log.Printf("%+v", claims)

but also try the following:

claims := map[string]interface{}{}
if err := token.Claims(&claims); err != nil {
	panic(err)
}
log.Printf("%+v", claims)

before I go dig myself a hole reimplementing JWT and JWE parsing and decrypting to figure out why it doesn’t work.

BTW, totally unrelatedly, here’s some JWT and JWE parsing code:

// jwt.go
package jwt

import (
	"crypto"
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"strings"

	"github.com/bernardigiri/go-pkcs7"
	"github.com/pkg/errors"
)

// Header values this package accepts. ParseJWEHeader and ParseJWTHeader
// return an error when the corresponding header field holds anything else.
const (
	JWTCTY        = "JWT"           // compared against the JWE "cty" and the JWT "typ" header
	RSAOAEP       = "RSA-OAEP"      // compared against the JWE "alg" header
	A128CBC_HS256 = "A128CBC-HS256" // compared against the JWE "enc" header
	RS256         = "RS256"         // compared against the JWT "alg" header
)

// jweparts holds the five dot-separated segments of a compact JWE exactly as
// they appear in the token (nothing is decoded yet).
type jweparts struct {
	header, ekey, iv, ciphertext, authtag string
}

// SplitJWE cuts token at every "." and maps the resulting segments, in order,
// onto header, ekey, iv, ciphertext and authtag.
//
// A token that does not consist of exactly five segments yields an error
// together with a zero-valued jweparts.
func SplitJWE(token string) (jweparts, error) {
	segments := strings.Split(token, ".")
	if len(segments) != 5 {
		return jweparts{}, fmt.Errorf("doesn't look like JWE")
	}

	return jweparts{
		header:     segments[0],
		ekey:       segments[1],
		iv:         segments[2],
		ciphertext: segments[3],
		authtag:    segments[4],
	}, nil
}

// rawjwe holds the five JWE segments after base64url decoding.
type rawjwe struct {
	header, ekey, iv, ciphertext, authtag []byte
}

// ParseJWE base64url-decodes (unpadded alphabet) every segment of parts.
//
// Segments are decoded in the order header, ekey, iv, ciphertext, authtag.
// Decoding stops at the first failure; the error is returned wrapped with a
// stack trace, alongside whatever had been decoded up to that point.
func ParseJWE(parts jweparts) (rawjwe, error) {
	var out rawjwe

	// Each entry pairs an encoded segment with the field receiving its bytes.
	steps := []struct {
		dst *[]byte
		src string
	}{
		{&out.header, parts.header},
		{&out.ekey, parts.ekey},
		{&out.iv, parts.iv},
		{&out.ciphertext, parts.ciphertext},
		{&out.authtag, parts.authtag},
	}

	for _, step := range steps {
		decoded, err := base64.RawURLEncoding.DecodeString(step.src)
		// Store before checking err so a failed segment keeps the bytes the
		// decoder managed to produce.
		*step.dst = decoded
		if err != nil {
			return out, errors.WithStack(err)
		}
	}

	return out, nil
}

// jwe is a JWE whose header has been parsed and validated and whose content
// encryption key has been recovered from the encrypted-key segment.
type jwe struct {
	header     jweheader
	key        []byte
	iv         []byte
	ciphertext []byte
	authtag    []byte
}

// jweheader is the subset of the JWE header fields that this package reads.
type jweheader struct {
	Alg string `json:"alg"`
	Cty string `json:"cty"`
	Enc string `json:"enc"`
	Kid string `json:"kid"`
}

// ParseJWEHeader unmarshals the JSON header of raw, checks that it announces
// the only combination this package handles (alg RSA-OAEP, cty JWT, enc
// A128CBC-HS256) and decrypts the encrypted key with privateKey using
// RSA-OAEP with SHA-1.
//
// The iv, ciphertext and authtag are copied over from raw untouched. An error
// is returned when the header is not valid JSON, when one of the three header
// checks fails, when privateKey is nil, or when the key cannot be decrypted.
func ParseJWEHeader(raw rawjwe, privateKey *rsa.PrivateKey) (jwe, error) {
	res := jwe{
		iv:         raw.iv,
		ciphertext: raw.ciphertext,
		authtag:    raw.authtag,
	}

	if err := json.Unmarshal(raw.header, &res.header); err != nil {
		return res, errors.WithStack(err)
	}

	if res.header.Alg != RSAOAEP {
		return res, fmt.Errorf("%s != %s", res.header.Alg, RSAOAEP)
	}

	if res.header.Cty != JWTCTY {
		return res, fmt.Errorf("%s != %s", res.header.Cty, JWTCTY)
	}

	if res.header.Enc != A128CBC_HS256 {
		return res, fmt.Errorf("%s != %s", res.header.Enc, A128CBC_HS256)
	}

	// rsa.DecryptOAEP dereferences the key, so a nil key would panic rather
	// than report an error.
	if privateKey == nil {
		return res, fmt.Errorf("private key is nil")
	}

	key, err := rsa.DecryptOAEP(sha1.New(), rand.Reader, privateKey, raw.ekey, nil)
	if err != nil {
		return res, errors.WithStack(err)
	}
	res.key = key

	return res, nil
}

// DecryptJWE decrypts the ciphertext of t with AES-128 in CBC mode, strips the
// PKCS#7 padding and returns the plaintext as a string.
//
// t.key must be exactly 32 bytes: the first 16 are the HMAC key (not used
// here) and the last 16 are the AES key. t.iv must be one AES block long and
// t.ciphertext a non-empty multiple of the block size; anything else is
// reported as an error instead of being handed to the cipher.
//
// NOTE(review): t.authtag is never checked, so the returned plaintext is
// unauthenticated. Verifying it needs the encoded protected header, which the
// jwe struct does not carry. Treat this as a debugging aid only.
//
// https://stackoverflow.com/questions/64809592/decrypting-jwt-encrypted-with-a128cbc-hs256-in-node-js
func DecryptJWE(t jwe) (string, error) {
	if len(t.key) != 32 {
		return "", fmt.Errorf("key length %d != 32", len(t.key))
	}
	// hmackey := t.key[0:16]
	aeskey := t.key[16:32]

	block, err := aes.NewCipher(aeskey)
	if err != nil {
		return "", errors.WithStack(err)
	}

	// cipher.NewCBCDecrypter panics on a wrong-sized IV and CryptBlocks panics
	// on input that is not a whole number of blocks. Both lengths come straight
	// from the token, so validate them first.
	if len(t.iv) != aes.BlockSize {
		return "", fmt.Errorf("iv length %d != %d", len(t.iv), aes.BlockSize)
	}
	if len(t.ciphertext) == 0 || len(t.ciphertext)%aes.BlockSize != 0 {
		return "", fmt.Errorf("ciphertext length %d is not a positive multiple of %d", len(t.ciphertext), aes.BlockSize)
	}

	padded := make([]byte, len(t.ciphertext))
	cipher.NewCBCDecrypter(block, t.iv).CryptBlocks(padded, t.ciphertext)

	plaintext, err := pkcs7.Unpad(padded, aes.BlockSize)
	if err != nil {
		return "", errors.WithStack(err)
	}

	return string(plaintext), nil
}

// jwtparts holds the three dot-separated segments of a compact JWT exactly as
// they appear in the token (nothing is decoded yet).
type jwtparts struct {
	header, payload, signature string
}

// SplitJWT cuts t at every "." and maps the resulting segments, in order, onto
// header, payload and signature.
//
// A token that does not consist of exactly three segments yields an error
// together with a zero-valued jwtparts.
func SplitJWT(t string) (jwtparts, error) {
	segments := strings.Split(t, ".")
	if len(segments) != 3 {
		return jwtparts{}, fmt.Errorf("doesn't look like a JWT")
	}

	return jwtparts{
		header:    segments[0],
		payload:   segments[1],
		signature: segments[2],
	}, nil
}

// rawjwt holds the three JWT segments after base64url decoding.
type rawjwt struct {
	header, payload, signature []byte
}

// ParseJWT base64url-decodes (unpadded alphabet) every segment of p.
//
// Segments are decoded in the order header, payload, signature. Decoding stops
// at the first failure; the error is returned wrapped with a stack trace,
// alongside whatever had been decoded up to that point.
func ParseJWT(p jwtparts) (rawjwt, error) {
	var out rawjwt

	// Each entry pairs an encoded segment with the field receiving its bytes.
	steps := []struct {
		dst *[]byte
		src string
	}{
		{&out.header, p.header},
		{&out.payload, p.payload},
		{&out.signature, p.signature},
	}

	for _, step := range steps {
		decoded, err := base64.RawURLEncoding.DecodeString(step.src)
		// Store before checking err so a failed segment keeps the bytes the
		// decoder managed to produce.
		*step.dst = decoded
		if err != nil {
			return out, errors.WithStack(err)
		}
	}

	return out, nil
}

// jwtheader is the subset of the JWT header fields that this package reads.
type jwtheader struct {
	Typ string `json:"typ"`
	Kid string `json:"kid"`
	Alg string `json:"alg"`
}

// parsedjwt is a JWT with its header and payload unmarshalled from JSON.
// signatureValid is set only after the RS256 signature has been verified.
type parsedjwt struct {
	header         jwtheader
	payload        map[string]interface{}
	signatureValid bool
}

// ParseJWTHeader unmarshals the header and payload of r, requires typ "JWT"
// and alg "RS256", and verifies the signature with publicKey
// (RSASSA-PKCS1-v1_5 over SHA-256).
//
// The signed data is rebuilt as base64url(header) + "." + base64url(payload)
// from the decoded bytes held in r. Both halves must use the unpadded URL-safe
// alphabet: that is how the token was decoded, and encoding the payload with
// the standard alphabet produces different bytes whenever '+' or '/' would
// appear, making valid signatures fail.
//
// An error is returned when either JSON document is malformed, when typ or alg
// do not match, when publicKey is nil, or when the signature does not verify.
func ParseJWTHeader(r rawjwt, publicKey *rsa.PublicKey) (parsedjwt, error) {
	res := parsedjwt{
		payload: map[string]interface{}{},
	}

	if err := json.Unmarshal(r.header, &res.header); err != nil {
		return res, errors.WithStack(err)
	}

	if res.header.Typ != JWTCTY {
		return res, fmt.Errorf("wrong typ %s", res.header.Typ)
	}

	if res.header.Alg != RS256 {
		return res, fmt.Errorf("wrong alg %s", res.header.Alg)
	}

	if err := json.Unmarshal(r.payload, &res.payload); err != nil {
		return res, errors.WithStack(err)
	}

	// rsa.VerifyPKCS1v15 dereferences the key, so a nil key would panic rather
	// than report an error.
	if publicKey == nil {
		return res, fmt.Errorf("public key is nil")
	}

	signingInput := base64.RawURLEncoding.EncodeToString(r.header) +
		"." +
		base64.RawURLEncoding.EncodeToString(r.payload)

	digest := sha256.Sum256([]byte(signingInput))

	if err := rsa.VerifyPKCS1v15(publicKey, crypto.SHA256, digest[:], r.signature); err != nil {
		return res, errors.WithStack(err)
	}

	res.signatureValid = true

	return res, nil
}
// jwt_test.go
package jwt

import (
	"testing"
)

// testjwe is the compact JWE token that TestJWT feeds into SplitJWE. The
// value below is a placeholder and has to be replaced with a real token.
const testjwe = `ey...` // you'll need a JWE token

// checkErr aborts the calling test with the verbose ("%+v") rendering of err
// when err is non-nil and does nothing otherwise. It is marked as a helper so
// the failure is attributed to the caller's line.
func checkErr(t *testing.T, err error) {
	t.Helper()

	if err == nil {
		return
	}

	t.Fatalf("error %+v", err)
}

// TestJWT runs testjwe through the whole pipeline: split and decode the JWE,
// validate its header and recover the key, decrypt the nested JWT, then split,
// decode and verify that JWT. The result is logged; any failing step aborts
// the test.
func TestJWT(t *testing.T) {
	// Outer layer: the encrypted token.
	encParts, err := SplitJWE(testjwe)
	checkErr(t, err)

	encRaw, err := ParseJWE(encParts)
	checkErr(t, err)

	encrypted, err := ParseJWEHeader(encRaw, privateKey) // you'll need privateKey
	checkErr(t, err)

	inner, err := DecryptJWE(encrypted)
	checkErr(t, err)

	// Inner layer: the signed token carried as the JWE plaintext.
	sigParts, err := SplitJWT(inner)
	checkErr(t, err)

	sigRaw, err := ParseJWT(sigParts)
	checkErr(t, err)

	verified, err := ParseJWTHeader(sigRaw, publicKey) // you'll need publicKey
	checkErr(t, err)

	t.Logf("%+v", verified)
}