diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_agent.h')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_agent.h | 53 |
1 files changed, 35 insertions, 18 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.h b/security/nss/gtests/ssl_gtest/tls_agent.h index 4bccb9a849..6cd6d5073b 100644 --- a/security/nss/gtests/ssl_gtest/tls_agent.h +++ b/security/nss/gtests/ssl_gtest/tls_agent.h @@ -14,7 +14,6 @@ #include <iostream> #include "test_io.h" -#include "tls_filter.h" #define GTEST_HAS_RTTI 0 #include "gtest/gtest.h" @@ -37,7 +36,10 @@ enum SessionResumptionMode { RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET }; +class PacketFilter; class TlsAgent; +class TlsCipherSpec; +struct TlsRecord; const extern std::vector<SSLNamedGroup> kAllDHEGroups; const extern std::vector<SSLNamedGroup> kECDHEGroups; @@ -66,7 +68,6 @@ class TlsAgent : public PollTarget { static const std::string kServerRsaSign; static const std::string kServerRsaPss; static const std::string kServerRsaDecrypt; - static const std::string kServerRsaChain; // A cert that requires a chain. static const std::string kServerEcdsa256; static const std::string kServerEcdsa384; static const std::string kServerEcdsa521; @@ -81,20 +82,15 @@ class TlsAgent : public PollTarget { adapter_->SetPeer(peer->adapter_); } - void SetTlsRecordFilter(std::shared_ptr<TlsRecordFilter> filter) { - filter->SetAgent(this); + void SetFilter(std::shared_ptr<PacketFilter> filter) { adapter_->SetPacketFilter(filter); } - - void SetPacketFilter(std::shared_ptr<PacketFilter> filter) { - adapter_->SetPacketFilter(filter); - } - - void DeletePacketFilter() { adapter_->SetPacketFilter(nullptr); } + void ClearFilter() { adapter_->SetPacketFilter(nullptr); } void StartConnect(PRFileDesc* model = nullptr); void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group, size_t kea_size = 0) const; + void CheckOriginalKEA(SSLNamedGroup kea_group) const; void CheckAuthType(SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const; @@ -121,12 +117,10 @@ class TlsAgent : public PollTarget { void SetupClientAuth(); void RequestClientAuth(bool requireAuth); + void SetOption(int32_t option, int value); void ConfigureSessionCache(SessionResumptionMode mode); - void SetSessionTicketsEnabled(bool en); - void SetSessionCacheEnabled(bool en); void Set0RttEnabled(bool en); void SetFallbackSCSVEnabled(bool en); - void SetShortHeadersEnabled(); void SetVersionRange(uint16_t minver, uint16_t maxver); void GetVersionRange(uint16_t* minver, uint16_t* maxver); void CheckPreliminaryInfo(); @@ -136,7 +130,6 @@ class TlsAgent : public PollTarget { void ExpectReadWriteError(); void EnableFalseStart(); void ExpectResumption(); - void ExpectShortHeaders(); void SkipVersionChecks(); void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count); void EnableAlpn(const uint8_t* val, size_t len); @@ -149,27 +142,49 @@ class TlsAgent : public PollTarget { // Send data on the socket, encrypting it. void SendData(size_t bytes, size_t blocksize = 1024); void SendBuffer(const DataBuffer& buf); + bool SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec, + uint16_t wireVersion, uint64_t seq, uint8_t ct, + const DataBuffer& buf); // Send data directly to the underlying socket, skipping the TLS layer. void SendDirect(const DataBuffer& buf); + void SendRecordDirect(const TlsRecord& record); void ReadBytes(size_t max = 16384U); void ResetSentBytes(); // Hack to test drops. void EnableExtendedMasterSecret(); void CheckExtendedMasterSecret(bool expected); void CheckEarlyDataAccepted(bool expected); - void DisableRollbackDetection(); - void EnableCompression(); void SetDowngradeCheckVersion(uint16_t version); void CheckSecretsDestroyed(); void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups); void DisableECDHEServerKeyReuse(); bool GetPeerChainLength(size_t* count); void CheckCipherSuite(uint16_t cipher_suite); + void SetResumptionTokenCallback(); + bool MaybeSetResumptionToken(); + void SetResumptionToken(const std::vector<uint8_t>& resumption_token) { + resumption_token_ = resumption_token; + } + const std::vector<uint8_t>& GetResumptionToken() const { + return resumption_token_; + } + void GetTokenInfo(ScopedSSLResumptionTokenInfo& token) { + SECStatus rv = SSL_GetResumptionTokenInfo( + resumption_token_.data(), resumption_token_.size(), token.get(), + sizeof(SSLResumptionTokenInfo)); + ASSERT_EQ(SECSuccess, rv); + } + void SetResumptionCallbackCalled() { resumption_callback_called_ = true; } + bool resumption_callback_called() const { + return resumption_callback_called_; + } const std::string& name() const { return name_; } Role role() const { return role_; } std::string role_str() const { return role_ == SERVER ? "server" : "client"; } + SSLProtocolVariant variant() const { return variant_; } + State state() const { return state_; } const CERTCertificate* peer_cert() const { @@ -253,6 +268,7 @@ class TlsAgent : public PollTarget { const static char* states[]; void SetState(State state); + void ValidateCipherSpecs(); // Dummy auth certificate hook. static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd, @@ -378,6 +394,7 @@ class TlsAgent : public PollTarget { uint8_t expected_sent_alert_; uint8_t expected_sent_alert_level_; bool handshake_callback_called_; + bool resumption_callback_called_; SSLChannelInfo info_; SSLCipherSuiteInfo csinfo_; SSLVersionRange vrange_; @@ -388,8 +405,8 @@ class TlsAgent : public PollTarget { HandshakeCallbackFunction handshake_callback_; AuthCertificateCallbackFunction auth_certificate_callback_; SniCallbackFunction sni_callback_; - bool expect_short_headers_; bool skip_version_checks_; + std::vector<uint8_t> resumption_token_; }; inline std::ostream& operator<<(std::ostream& stream, @@ -440,7 +457,7 @@ class TlsAgentTestBase : public ::testing::Test { void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state, int32_t error_code = 0); - std::unique_ptr<TlsAgent> agent_; + std::shared_ptr<TlsAgent> agent_; TlsAgent::Role role_; SSLProtocolVariant variant_; uint16_t version_; |