summaryrefslogtreecommitdiff
path: root/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc94
1 files changed, 54 insertions, 40 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
index 335bfecfa5..e4a9e5aed7 100644
--- a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
@@ -22,8 +22,11 @@ namespace nss_test {
class TlsHandshakeSkipFilter : public TlsRecordFilter {
public:
// A TLS record filter that skips handshake messages of the identified type.
- TlsHandshakeSkipFilter(uint8_t handshake_type)
- : handshake_type_(handshake_type), skipped_(false) {}
+ TlsHandshakeSkipFilter(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type)
+ : TlsRecordFilter(agent),
+ handshake_type_(handshake_type),
+ skipped_(false) {}
protected:
// Takes a record; if it is a handshake record, it removes the first handshake
@@ -92,9 +95,14 @@ class TlsSkipTest : public TlsConnectTestBase,
TlsSkipTest()
: TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+ void SetUp() override {
+ TlsConnectTestBase::SetUp();
+ EnsureTlsSetup();
+ }
+
void ServerSkipTest(std::shared_ptr<PacketFilter> filter,
uint8_t alert = kTlsAlertUnexpectedMessage) {
- server_->SetPacketFilter(filter);
+ server_->SetFilter(filter);
ConnectExpectAlert(client_, alert);
}
};
@@ -105,9 +113,14 @@ class Tls13SkipTest : public TlsConnectTestBase,
Tls13SkipTest()
: TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
- void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
+ void SetUp() override {
+ TlsConnectTestBase::SetUp();
EnsureTlsSetup();
- server_->SetTlsRecordFilter(filter);
+ }
+
+ void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
+ filter->EnableDecryption();
+ server_->SetFilter(filter);
ExpectAlert(client_, kTlsAlertUnexpectedMessage);
ConnectExpectFail();
client_->CheckErrorCode(error);
@@ -115,8 +128,8 @@ class Tls13SkipTest : public TlsConnectTestBase,
}
void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
- EnsureTlsSetup();
- client_->SetTlsRecordFilter(filter);
+ filter->EnableDecryption();
+ client_->SetFilter(filter);
server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
ConnectExpectFailOneSide(TlsAgent::SERVER);
@@ -129,48 +142,49 @@ class Tls13SkipTest : public TlsConnectTestBase,
TEST_P(TlsSkipTest, SkipCertificateRsa) {
EnableOnlyStaticRsaCiphers();
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipCertificateDhe) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipCertificateEcdhe) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipCertificateEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipServerKeyExchange) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipCertAndKeyExch) {
- auto chain = std::make_shared<ChainedPacketFilter>(ChainedPacketFilterInit{
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
- std::make_shared<TlsHandshakeSkipFilter>(
- kTlsHandshakeServerKeyExchange)});
+ auto chain = std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit{std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate),
+ std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange)});
ServerSkipTest(chain);
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
@@ -178,48 +192,48 @@ TEST_P(TlsSkipTest, SkipCertAndKeyExch) {
TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
auto chain = std::make_shared<ChainedPacketFilter>();
- chain->Add(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
- chain->Add(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
+ chain->Add(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
+ chain->Add(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
ServerSkipTest(chain);
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(Tls13SkipTest, SkipEncryptedExtensions) {
ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
- kTlsHandshakeEncryptedExtensions),
+ server_, kTlsHandshakeEncryptedExtensions),
SSL_ERROR_RX_UNEXPECTED_CERTIFICATE);
}
TEST_P(Tls13SkipTest, SkipServerCertificate) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
- SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate),
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
}
TEST_P(Tls13SkipTest, SkipServerCertificateVerify) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify),
- SSL_ERROR_RX_UNEXPECTED_FINISHED);
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificateVerify),
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
}
TEST_P(Tls13SkipTest, SkipClientCertificate) {
client_->SetupClientAuth();
server_->RequestClientAuth(true);
client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
- ClientSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
- SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+ ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ client_, kTlsHandshakeCertificate),
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
}
TEST_P(Tls13SkipTest, SkipClientCertificateVerify) {
client_->SetupClientAuth();
server_->RequestClientAuth(true);
client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
- ClientSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify),
- SSL_ERROR_RX_UNEXPECTED_FINISHED);
+ ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ client_, kTlsHandshakeCertificateVerify),
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
}
INSTANTIATE_TEST_CASE_P(