@@ -169,6 +169,7 @@ type FakeIDP struct {
169169 // clientID to be used by coderd
170170 clientID string
171171 clientSecret string
172+ pkce bool // TODO(Emyrk): Implement for refresh token flow as well
172173 // externalProviderID is optional to match the provider in coderd for
173174 // redirectURLs.
174175 externalProviderID string
@@ -181,6 +182,8 @@ type FakeIDP struct {
181182 // These maps are used to control the state of the IDP.
182183 // That is the various access tokens, refresh tokens, states, etc.
183184 codeToStateMap * syncmap.Map [string , string ]
185+ // Code -> PKCE Challenge
186+ codeToChallengeMap * syncmap.Map [string , string ]
184187 // Token -> Email
185188 accessTokens * syncmap.Map [string , token ]
186189 // Refresh Token -> Email
@@ -239,6 +242,12 @@ func (s statusHookError) Error() string {
239242
240243type FakeIDPOpt func (idp * FakeIDP )
241244
245+ func WithPKCE () func (* FakeIDP ) {
246+ return func (f * FakeIDP ) {
247+ f .pkce = true
248+ }
249+ }
250+
242251func WithAuthorizedRedirectURL (hook func (redirectURL string ) error ) func (* FakeIDP ) {
243252 return func (f * FakeIDP ) {
244253 f .hookValidRedirectURL = hook
@@ -450,6 +459,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
450459 clientSecret : uuid .NewString (),
451460 logger : slog .Make (),
452461 codeToStateMap : syncmap .New [string , string ](),
462+ codeToChallengeMap : syncmap .New [string , string ](),
453463 accessTokens : syncmap .New [string , token ](),
454464 refreshTokens : syncmap .New [string , string ](),
455465 refreshTokensUsed : syncmap .New [string , bool ](),
@@ -557,8 +567,16 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
557567func (f * FakeIDP ) GenerateAuthenticatedToken (claims jwt.MapClaims ) (* oauth2.Token , error ) {
558568 state := uuid .NewString ()
559569 f .stateToIDTokenClaims .Store (state , claims )
560- code := f .newCode (state )
561- return f .locked .Config ().Exchange (oidc .ClientContext (context .Background (), f .HTTPClient (nil )), code )
570+
571+ exchangeOpts := []oauth2.AuthCodeOption {}
572+ verifier := ""
573+ if f .pkce {
574+ verifier = oauth2 .GenerateVerifier ()
575+ exchangeOpts = append (exchangeOpts , oauth2 .VerifierOption (verifier ))
576+ }
577+ code := f .newCode (state , oauth2 .S256ChallengeFromVerifier (verifier ))
578+
579+ return f .locked .Config ().Exchange (oidc .ClientContext (context .Background (), f .HTTPClient (nil )), code , exchangeOpts ... )
562580}
563581
564582// Login does the full OIDC flow starting at the "LoginButton".
@@ -756,10 +774,16 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
756774 panic ("cannot use OIDCCallback with WithServing. This is only for the in memory usage" )
757775 }
758776
777+ opts := []oauth2.AuthCodeOption {}
778+ if f .pkce {
779+ verifier := oauth2 .GenerateVerifier ()
780+ opts = append (opts , oauth2 .S256ChallengeOption (oauth2 .S256ChallengeFromVerifier (verifier )))
781+ }
782+
759783 f .stateToIDTokenClaims .Store (state , idTokenClaims )
760784
761785 cli := f .HTTPClient (nil )
762- u := f .locked .Config ().AuthCodeURL (state )
786+ u := f .locked .Config ().AuthCodeURL (state , opts ... )
763787 req , err := http .NewRequest ("GET" , u , nil )
764788 require .NoError (t , err )
765789
@@ -790,9 +814,10 @@ type ProviderJSON struct {
790814
791815// newCode enforces the code exchanged is actually a valid code
792816// created by the IDP.
793- func (f * FakeIDP ) newCode (state string ) string {
817+ func (f * FakeIDP ) newCode (state string , challenge string ) string {
794818 code := uuid .NewString ()
795819 f .codeToStateMap .Store (code , state )
820+ f .codeToChallengeMap .Store (code , challenge )
796821 return code
797822}
798823
@@ -918,6 +943,22 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
918943 mux .Handle (authorizePath , http .HandlerFunc (func (rw http.ResponseWriter , r * http.Request ) {
919944 f .logger .Info (r .Context (), "http call authorize" , slogRequestFields (r )... )
920945
946+ challenge := ""
947+ if f .pkce {
948+ method := r .URL .Query ().Get ("code_challenge_method" )
949+ challenge = r .URL .Query ().Get ("code_challenge" )
950+
951+ if method == "" {
952+ httpError (rw , http .StatusBadRequest , xerrors .New ("missing code_challenge_method" ))
953+ return
954+ }
955+
956+ if challenge == "" {
957+ httpError (rw , http .StatusBadRequest , xerrors .New ("missing code_challenge" ))
958+ return
959+ }
960+ }
961+
921962 clientID := r .URL .Query ().Get ("client_id" )
922963 if ! assert .Equal (t , f .clientID , clientID , "unexpected client_id" ) {
923964 httpError (rw , http .StatusBadRequest , xerrors .New ("invalid client_id" ))
@@ -959,7 +1000,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
9591000
9601001 q := ru .Query ()
9611002 q .Set ("state" , state )
962- q .Set ("code" , f .newCode (state ))
1003+ q .Set ("code" , f .newCode (state , challenge ))
9631004 ru .RawQuery = q .Encode ()
9641005
9651006 http .Redirect (rw , r , ru .String (), http .StatusTemporaryRedirect )
@@ -1009,8 +1050,28 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
10091050 http .Error (rw , "invalid code" , http .StatusBadRequest )
10101051 return
10111052 }
1053+
1054+ if f .pkce {
1055+ challenge , ok := f .codeToChallengeMap .Load (code )
1056+ if ! ok {
1057+ httpError (rw , http .StatusBadRequest , xerrors .New ("pkce: challenge not found for code" ))
1058+ return
1059+ }
1060+ codeVerifier := values .Get ("code_verifier" )
1061+ if codeVerifier == "" {
1062+ httpError (rw , http .StatusBadRequest , xerrors .New ("pkce: missing code_verifier" ))
1063+ return
1064+ }
1065+ expecter := oauth2 .S256ChallengeFromVerifier (codeVerifier )
1066+ if challenge != expecter {
1067+ httpError (rw , http .StatusBadRequest , xerrors .New ("pkce: invalid code verifier" ))
1068+ return
1069+ }
1070+ }
1071+
10121072 // Always invalidate the code after it is used.
10131073 f .codeToStateMap .Delete (code )
1074+ f .codeToChallengeMap .Delete (code )
10141075
10151076 idTokenClaims , ok := f .getClaims (f .stateToIDTokenClaims , stateStr )
10161077 if ! ok {
0 commit comments