import { afterEach, beforeEach, describe, expect, test, vi } from "bun:test";
import * as fs from "node:fs/promises";
import * as os from "node:os";
import * as path from "node:path";
import {
type AuthCredentialStore,
AuthStorage,
type CredentialDisabledEvent,
SqliteAuthCredentialStore,
} from "../src/auth-storage";
import * as oauthUtils from "../src/utils/oauth";
import { withEnv } from "./helpers";
/**
 * Environment overrides handed to `withEnv` around the Anthropic lookups in this file.
 * Both variables are mapped to `undefined`.
 * NOTE(review): presumably `withEnv` unsets them for the duration of the callback so ambient
 * credentials cannot short-circuit the stored OAuth credential — confirm against `./helpers`.
 */
const SUPPRESS_ANTHROPIC_ENV = { ANTHROPIC_API_KEY: undefined, ANTHROPIC_OAUTH_TOKEN: undefined } as const;
describe("AuthStorage OAuth refresh race", () => {
// Per-test fixtures shared by the hooks and tests in this suite.
// Temp directory created in beforeEach; reset to "" once afterEach has removed it.
let tempDir = "";
// Credential store opened in beforeEach; closed and set back to null in afterEach.
let store: AuthCredentialStore | null = null;
// AuthStorage under test, constructed on top of `store` in beforeEach; nulled in afterEach.
let authStorage: AuthStorage | null = null;
// Events received through the onCredentialDisabled callback; emptied in beforeEach.
let events: CredentialDisabledEvent[] = [];
// Builds fresh fixtures for every test: a new temp directory, a SqliteAuthCredentialStore opened on
// `agent.db` inside it, an empty event log, and an AuthStorage whose onCredentialDisabled callback
// appends to that log.
beforeEach(async () => {
	const tempPrefix = path.join(os.tmpdir(), "pi-ai-auth-oauth-race-");
	tempDir = await fs.mkdtemp(tempPrefix);
	const databasePath = path.join(tempDir, "agent.db");
	store = await SqliteAuthCredentialStore.open(databasePath);
	events = [];
	authStorage = new AuthStorage(store, {
		onCredentialDisabled(event) {
			events.push(event);
		},
	});
});
// Tears down what the test touched: restores all mocks, unregisters OAuth providers registered
// under this file's source id, closes and drops the store, drops the AuthStorage, and removes the
// temp directory if one was created.
afterEach(async () => {
	vi.restoreAllMocks();
	oauthUtils.unregisterOAuthProviders("auth-storage-oauth-refresh-race-test");
	if (store !== null) {
		store.close();
	}
	store = null;
	authStorage = null;
	// Nothing on disk to clean up when no temp directory was recorded.
	if (!tempDir) return;
	await fs.rm(tempDir, { recursive: true, force: true });
	tempDir = "";
});
// Scenario: AuthStorage is seeded with an expired OAuth credential, then the row is overwritten
// directly in the store with fresh tokens (standing in for a peer process). The mocked
// getOAuthApiKey rejects only the stale refresh token. Expected: the fresh access token is
// returned, no disable event fires, and the same row keeps the peer's refresh token.
test("does not disable a credential another process already rotated", async () => {
	if (!authStorage || !store) throw new Error("test setup failed");
	const storage = authStorage;
	const credentialStore = store;

	await storage.set("anthropic", [
		{
			type: "oauth",
			access: "stale-access",
			refresh: "stale-refresh",
			expires: Date.now() - 60_000,
		},
	]);

	const seededRows = credentialStore.listAuthCredentials("anthropic");
	expect(seededRows).toHaveLength(1);
	const [seededRow] = seededRows;
	const credentialId = seededRow!.id;

	// Write the rotated tokens straight to the store, bypassing the AuthStorage instance.
	credentialStore.updateAuthCredential(credentialId, {
		type: "oauth",
		access: "fresh-access-from-peer",
		refresh: "fresh-refresh-from-peer",
		expires: Date.now() + 60 * 60_000,
	});

	vi.spyOn(oauthUtils, "getOAuthApiKey").mockImplementation(async (provider, creds) => {
		const credential = creds[provider];
		if (credential?.refresh !== "stale-refresh") {
			return { newCredentials: credential!, apiKey: credential!.access };
		}
		throw new Error(
			'HTTP 400 invalid_grant {"error":"invalid_grant","error_description":"Refresh token not found or invalid"}',
		);
	});

	await withEnv(SUPPRESS_ANTHROPIC_ENV, async () => {
		const apiKey = await storage.getApiKey("anthropic", "session-race");
		expect(apiKey).toBe("fresh-access-from-peer");
		expect(events).toHaveLength(0);
		expect(storage.list()).toContain("anthropic");

		const rowsAfter = credentialStore.listAuthCredentials("anthropic");
		expect(rowsAfter).toHaveLength(1);
		const survivor = rowsAfter[0];
		expect(survivor?.id).toBe(credentialId);
		expect(survivor?.credential.type).toBe("oauth");
		// The type assertion above fails first on a mismatch; this check is for narrowing.
		if (survivor?.credential.type === "oauth") {
			expect(survivor.credential.refresh).toBe("fresh-refresh-from-peer");
		}
	});
});
// Scenario: the stored credential is stale and the mocked getOAuthApiKey rejects its refresh token.
// The spy on tryDisableAuthCredentialIfMatches writes fresh tokens via updateAuthCredential just
// before delegating to the real method, so the rotation lands between the caller's check and the
// conditional disable. Expected: the fresh access token is returned, no disable event fires, the
// spy was invoked, and the same row holds the fresh tokens.
test("does not disable when peer rotates between pre-check and CAS disable", async () => {
	if (!authStorage || !store) throw new Error("test setup failed");
	const storage = authStorage;
	const peerStore = store;

	await storage.set("anthropic", [
		{
			type: "oauth",
			access: "stale-access",
			refresh: "stale-refresh",
			expires: Date.now() - 60_000,
		},
	]);

	const seededRows = peerStore.listAuthCredentials("anthropic");
	expect(seededRows).toHaveLength(1);
	const [seededRow] = seededRows;
	const credentialId = seededRow!.id;

	vi.spyOn(oauthUtils, "getOAuthApiKey").mockImplementation(async (provider, creds) => {
		const credential = creds[provider];
		if (credential?.refresh !== "stale-refresh") {
			return { newCredentials: credential!, apiKey: credential!.access };
		}
		throw new Error(
			'HTTP 400 invalid_grant {"error":"invalid_grant","error_description":"Refresh token not found or invalid"}',
		);
	});

	// Bind the real method before spying so the mock can still delegate to it.
	const realTryDisable = peerStore.tryDisableAuthCredentialIfMatches.bind(peerStore);
	const tryDisableSpy = vi.spyOn(peerStore, "tryDisableAuthCredentialIfMatches");
	tryDisableSpy.mockImplementation((id, expectedData, disabledCause) => {
		peerStore.updateAuthCredential(id, {
			type: "oauth",
			access: "fresh-access-from-peer",
			refresh: "fresh-refresh-from-peer",
			expires: Date.now() + 60 * 60_000,
		});
		return realTryDisable(id, expectedData, disabledCause);
	});

	await withEnv(SUPPRESS_ANTHROPIC_ENV, async () => {
		const apiKey = await storage.getApiKey("anthropic", "session-cas-race");
		expect(apiKey).toBe("fresh-access-from-peer");
		expect(events).toHaveLength(0);
		expect(tryDisableSpy).toHaveBeenCalled();

		const rowsAfter = peerStore.listAuthCredentials("anthropic");
		expect(rowsAfter).toHaveLength(1);
		const survivor = rowsAfter[0];
		expect(survivor?.id).toBe(credentialId);
		expect(survivor?.credential.type).toBe("oauth");
		// The type assertion above fails first on a mismatch; this check is for narrowing.
		if (survivor?.credential.type === "oauth") {
			expect(survivor.credential.refresh).toBe("fresh-refresh-from-peer");
			expect(survivor.credential.access).toBe("fresh-access-from-peer");
		}
	});
});
// Control case: nothing rotates the credential and the mocked getOAuthApiKey always rejects with
// invalid_grant. Expected: no API key is returned and exactly one disable event is recorded whose
// disabledCause mentions "invalid_grant".
test("still disables when the failure is real (no concurrent rotation)", async () => {
	if (!authStorage) throw new Error("test setup failed");
	const storage = authStorage;

	await storage.set("anthropic", [
		{
			type: "oauth",
			access: "expired-access",
			refresh: "stale-refresh",
			expires: Date.now() - 60_000,
		},
	]);

	vi.spyOn(oauthUtils, "getOAuthApiKey").mockImplementation(() =>
		Promise.reject(new Error('invalid_grant {"error":"invalid_grant"}')),
	);

	await withEnv(SUPPRESS_ANTHROPIC_ENV, async () => {
		const apiKey = await storage.getApiKey("anthropic", "session-real-failure");
		expect(apiKey).toBeUndefined();
		expect(events).toHaveLength(1);
		const [disabledEvent] = events;
		expect(disabledEvent?.disabledCause).toContain("invalid_grant");
	});
});
// Registers a throwaway OAuth provider whose refreshToken appends "-rotated" to both tokens, seeds
// two expired credentials for it, and asks for an API key. Expected: the key is the first
// credential's rotated access token and both stored rows carry rotated refresh tokens.
test("persists every credential refreshed during candidate preflight", async () => {
	if (!authStorage || !store) throw new Error("test setup failed");
	const storage = authStorage;
	const credentialStore = store;
	const expires = Date.now() - 60_000;
	const refreshedExpires = Date.now() + 60 * 60_000;

	oauthUtils.registerOAuthProvider({
		id: "unit-oauth-preflight",
		name: "Unit OAuth Preflight",
		// Matches the source id that afterEach unregisters.
		sourceId: "auth-storage-oauth-refresh-race-test",
		login: async () => ({ access: "unused", refresh: "unused", expires: refreshedExpires }),
		async refreshToken(credentials) {
			const { access, refresh } = credentials;
			return {
				...credentials,
				access: `${access}-rotated`,
				refresh: `${refresh}-rotated`,
				expires: refreshedExpires,
			};
		},
		getApiKey: credentials => credentials.access,
	});

	await storage.set("unit-oauth-preflight", [
		{ type: "oauth", access: "access-a", refresh: "refresh-a", expires },
		{ type: "oauth", access: "access-b", refresh: "refresh-b", expires },
	]);

	const apiKey = await storage.getApiKey("unit-oauth-preflight");
	expect(apiKey).toBe("access-a-rotated");

	const rows = credentialStore.listAuthCredentials("unit-oauth-preflight");
	expect(rows).toHaveLength(2);
	// Collect refresh tokens from the OAuth rows only; sorted so row order does not matter.
	const refreshTokens = rows.flatMap(({ credential }) => (credential.type === "oauth" ? [credential.refresh] : []));
	refreshTokens.sort();
	expect(refreshTokens).toEqual(["refresh-a-rotated", "refresh-b-rotated"]);
});
// Registers a provider whose refreshToken counts its invocations, signals that it has started, and
// then blocks until the test releases it. Two getApiKey calls for the same session are issued
// before the refresh is released. Expected: both resolve to the rotated access token and
// refreshToken ran exactly once.
test("coalesces concurrent refreshes for the same credential", async () => {
	if (!authStorage) throw new Error("test setup failed");
	const storage = authStorage;
	const expires = Date.now() - 60_000;
	const refreshedExpires = Date.now() + 60 * 60_000;
	const { promise: refreshStarted, resolve: signalRefreshStarted } = Promise.withResolvers<void>();
	const { promise: refreshGate, resolve: openRefreshGate } = Promise.withResolvers<void>();
	let refreshCalls = 0;

	oauthUtils.registerOAuthProvider({
		id: "unit-oauth-mutex",
		name: "Unit OAuth Mutex",
		// Matches the source id that afterEach unregisters.
		sourceId: "auth-storage-oauth-refresh-race-test",
		login: async () => ({ access: "unused", refresh: "unused", expires: refreshedExpires }),
		async refreshToken(credentials) {
			refreshCalls += 1;
			signalRefreshStarted();
			// Hold the refresh open until the test has both getApiKey calls in flight.
			await refreshGate;
			return {
				...credentials,
				access: "access-rotated",
				refresh: "refresh-rotated",
				expires: refreshedExpires,
			};
		},
		getApiKey: credentials => credentials.access,
	});

	await storage.set("unit-oauth-mutex", [
		{ type: "oauth", access: "access-old", refresh: "refresh-old", expires },
	]);

	const firstLookup = storage.getApiKey("unit-oauth-mutex", "same-session");
	const secondLookup = storage.getApiKey("unit-oauth-mutex", "same-session");

	await refreshStarted;
	openRefreshGate();

	await expect(firstLookup).resolves.toBe("access-rotated");
	await expect(secondLookup).resolves.toBe("access-rotated");
	expect(refreshCalls).toBe(1);
});
// Picks the first of up to 32 candidate session ids whose xxHash32 is even, seeds two unexpired
// OAuth credentials, and mocks getOAuthApiKey to echo the credential's access token. Expected: the
// first lookup yields "access-a"; after invalidating that key for the session, the retry yields
// "access-b".
// NOTE(review): the even-hash requirement presumably pins the session to the first credential —
// confirm against AuthStorage's selection logic.
test("invalidating a session-sticky OAuth credential rotates the retry to another active credential", async () => {
	if (!authStorage) throw new Error("test setup failed");
	const storage = authStorage;

	const candidates = Array.from({ length: 32 }, (_, index) => `session-auth-retry-${index}`);
	const sessionId = candidates.find(candidate => Bun.hash.xxHash32(candidate) % 2 === 0);
	if (!sessionId) throw new Error("could not find test session id");

	await storage.set("unit-oauth-rotation", [
		{
			type: "oauth",
			access: "access-a",
			refresh: "refresh-a",
			expires: Date.now() + 60 * 60_000,
		},
		{
			type: "oauth",
			access: "access-b",
			refresh: "refresh-b",
			expires: Date.now() + 60 * 60_000,
		},
	]);

	vi.spyOn(oauthUtils, "getOAuthApiKey").mockImplementation(async (provider, credentials) => {
		const credential = credentials[provider];
		return credential ? { newCredentials: credential, apiKey: credential.access } : null;
	});

	const firstKey = await storage.getApiKey("unit-oauth-rotation", sessionId);
	expect(firstKey).toBe("access-a");

	const invalidated = await storage.invalidateCredentialMatching("unit-oauth-rotation", "access-a", {
		sessionId,
	});
	expect(invalidated).toBe(true);

	const retryKey = await storage.getApiKey("unit-oauth-rotation", sessionId);
	expect(retryKey).toBe("access-b");
});
});