- Published on
Building a PostgreSQL Wire Protocol Server using Vanilla, Modern Java 21
- Authors
- Name
- Gavin Ray
- @GavinRayDev
Table of Contents
- Outline
- Introduction to the PostgreSQL Wire Protocol
- Step 1: Doing The Simplest Thing That Could Possibly Work (TM)
- Step 1.1 - Initial Server Skeleton
- Connecting to the Server with psql
- Step 1.2 - Responding to the SSL Negotiation message and the Startup Message
- Step 1.3 - Differentiating between the SSL/Authentication request, and Command messages
- Step 1.4 - Handle a query and return data rows
- Step 2: Refactoring
- Refactoring into Sealed Interfaces with Records and Message Decoders (Pt 1. Client Side)
- Refactoring into Sealed Interfaces with Records and Message Decoders (Pt 2. Server Side)
This tutorial is meant to serve as a dual purpose guide:
- Demonstrate how to implement the Simple Query Protocol, where Java is an implementation detail
- Show practical examples of most of the new features since JDK 17, including:
- Records (JEP 395)
- Sealed Types (JEP 360/JEP 409)
- Pattern Matching for Switch (JEP 406)
- Virtual Threads aka Project Loom (JEP 425)
- Foreign-Function & Memory API (FFM) aka Project Panama (JEP 424)
- (Also give a practical example of java.nio's AsynchronousChannelGroup and AsynchronousServerSocketChannel, for which there are few examples online)
Outline
This tutorial will be broken up into a series of 2 parts:
Because we are diligent coders who know to Do The Simplest Thing That Could Possibly Work™, our first step will be a horrifically ugly, but functional, implementation of the PostgreSQL wire protocol server.
Then, we will refactor our code to be more readable and maintainable.
Introduction to the PostgreSQL Wire Protocol
The PostgreSQL wire protocol is a binary protocol that is used to communicate between a PostgreSQL client and server.
The protocol is documented in the PostgreSQL Protocol Documentation.
In my opinion, this documentation is not the most understandable. If you want to learn more about the protocol, I recommend the following presentation:
- https://beta.pgcon.org/2014/schedule/attachments/330_postgres-for-the-wire.pdf
- https://www.youtube.com/watch?v=qa22SouCr5E
What we're primarily concerned with today are the following pieces of information:
- Postgres clients send two types of message to the server: Startup Messages and Commands
- Optionally, the Startup message can be preceded by an SSL Negotiation message, where the client asks the server if it supports SSL
Visualized as a Mermaid diagram:
There are many types of commands, but today we'll only be concerned with the Query command, which is used to execute a SQL query.
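To make the framing concrete, here is a minimal sketch of how a client lays out two of these messages on the wire. The class and method names (WireFramingSketch, queryMessage, sslRequest) are hypothetical and only for illustration. It assumes the layout from the protocol documentation: a Query starts with a 'Q' type byte followed by an int32 length that counts itself but not the type byte, while the SSL Negotiation message has no type byte at all.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

final class WireFramingSketch {
    // SSLRequest code from the protocol documentation: (1234 << 16) | 5679
    static final int SSL_REQUEST_CODE = 80877103;

    /** Frames a Query message: 'Q', int32 length, then the null-terminated SQL text. */
    static ByteBuffer queryMessage(String sql) {
        byte[] text = sql.getBytes(StandardCharsets.UTF_8);
        int length = 4 + text.length + 1; // counts itself and the body, not the type byte
        ByteBuffer buffer = ByteBuffer.allocate(1 + length);
        buffer.put((byte) 'Q').putInt(length).put(text).put((byte) 0);
        buffer.flip();
        return buffer;
    }

    /** Frames an SSLRequest: int32 length (always 8) and the request code, with no type byte. */
    static ByteBuffer sslRequest() {
        ByteBuffer buffer = ByteBuffer.allocate(8);
        buffer.putInt(8).putInt(SSL_REQUEST_CODE);
        buffer.flip();
        return buffer;
    }

    public static void main(String[] args) {
        System.out.println("Query bytes: " + queryMessage("select 1;").remaining());
        System.out.println("SSLRequest bytes: " + sslRequest().remaining());
    }
}
```

The two sizes this prints (15 and 8) are the same buffer limits that show up in the server logs later in this tutorial.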
Knowing this, we can now start implementing our server.
Step 1: Doing The Simplest Thing That Could Possibly Work (TM)
We'll begin by implementing a basic java.nio.channels.AsynchronousServerSocketChannel server, which will accept connections and print out the messages it receives:
Step 1.1 - Initial Server Skeleton
Below is the initial skeleton of our server.
- We create a java.nio.channels.AsynchronousServerSocketChannel and bind it to localhost and the default Postgres port (5432).
- An ExecutorService is created, which will be used to create a java.nio.channels.AsynchronousChannelGroup for our server.
- We use the newVirtualThreadPerTaskExecutor method, so every task the channel group submits to the pool runs on a new Loom virtual thread.
- Then, we accept connections and print out the messages we receive.
package postgres.wire.protocol;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousChannelGroup;
import java.nio.channels.AsynchronousServerSocketChannel;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.CompletionHandler;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
class AsynchronousSocketServer {
private static final String HOST = "localhost";
private static final int PORT = 5432;
public static void main(String[] args) throws Exception {
ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
AsynchronousChannelGroup group = AsynchronousChannelGroup.withThreadPool(executor);
try (AsynchronousServerSocketChannel server = AsynchronousServerSocketChannel.open(group)) {
server.bind(new InetSocketAddress(HOST, PORT));
System.out.println("[SERVER] Listening on " + HOST + ":" + PORT);
for (;;) {
Future<AsynchronousSocketChannel> future = server.accept();
AsynchronousSocketChannel client = future.get();
System.out.println("[SERVER] Accepted connection from " + client.getRemoteAddress());
ByteBuffer buffer = ByteBuffer.allocate(1024);
client.read(buffer, buffer, new CompletionHandler<>() {
@Override
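public void completed(Integer result, ByteBuffer attachment) {
if (result == -1) {
// End of stream: the client disconnected, so stop reading and release the socket
try {
client.close();
} catch (java.io.IOException e) {
System.err.println("[SERVER] Failed to close client: " + e);
}
return;
}
attachment.flip();
onMessageReceived(client, attachment);
attachment.clear();
client.read(attachment, attachment, this);
}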
@Override
public void failed(Throwable exc, ByteBuffer attachment) {
System.err.println("[SERVER] Failed to read from client: " + exc);
exc.printStackTrace();
}
});
}
}
}
private static void onMessageReceived(AsynchronousSocketChannel client, ByteBuffer buffer) {
System.out.println("[SERVER] Received message from client: " + client);
System.out.println("[SERVER] Buffer: " + buffer);
}
}
class MainSimplest {
public static void main(String[] args) throws Exception {
AsynchronousSocketServer.main(args);
}
}
If we start this up, we should see:
Running Gradle on WSL...
> Task :app:compileJava
> Task :app:processResources NO-SOURCE
Note: Some input files use preview features of Java SE 21.
Note: Recompile with -Xlint:preview for details.
> Task :app:classes
> Task :app:MainSimplestWalkThrough.main()
[SERVER] Listening on localhost:5432
Connecting to the Server with psql
Now, we can connect to our server using psql:
$ psql -h localhost -p 5432 -U postgres
We should see psql hang at the prompt, and the server should print out the following:
[SERVER] Accepted connection from /127.0.0.1:41826
[SERVER] Received message from client: sun.nio.ch.UnixAsynchronousSocketChannelImpl[connected local=/127.0.0.1:5432 remote=/127.0.0.1:41826]
[SERVER] Buffer: java.nio.HeapByteBuffer[pos=0 lim=8 cap=1024]
Great! We're now able to receive messages from the client.
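The log line above only shows the buffer's position and limit, not its contents. To look at the bytes themselves, a small hex-dump helper can be handy. This is a sketch only: HexDumpSketch and hexDump are hypothetical names, not part of the tutorial's code. The sample bytes are an SSL Negotiation message (length 8, then the request code 80877103), the 8-byte message that psql sends first.

```java
import java.nio.ByteBuffer;

final class HexDumpSketch {
    /** Renders the readable bytes (position..limit) as hex without moving the position. */
    static String hexDump(ByteBuffer buffer) {
        StringBuilder out = new StringBuilder();
        for (int i = buffer.position(); i < buffer.limit(); i++) {
            if (out.length() > 0) {
                out.append(' ');
            }
            out.append(String.format("%02x", buffer.get(i) & 0xff)); // absolute get: no side effects
        }
        return out.toString();
    }

    public static void main(String[] args) {
        // Same shape as the server's read buffer: capacity 1024, 8 readable bytes after flip()
        ByteBuffer sample = ByteBuffer.allocate(1024);
        sample.putInt(8).putInt(80877103);
        sample.flip();
        System.out.println(hexDump(sample)); // prints "00 00 00 08 04 d2 16 2f"
    }
}
```

Calling something like hexDump(buffer) inside onMessageReceived would print the raw message next to the existing log lines.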
Step 1.2 - Responding to the SSL Negotiation message and the Startup Message
What we want to do is just make sure that we're able to receive messages from the client, and respond to the:
- Initial SSL Negotiation message with an 'N' byte (for No)
- Startup Message with an AuthenticationOk message
Finally, we'll write a:
- BackendKeyData message, which gives the client the process ID and secret key that identify this connection (it needs them to cancel a running query later)
- ReadyForQuery message, which indicates that the server is ready to accept commands.
Below is the updated code:
private static void onMessageReceived(AsynchronousSocketChannel client, ByteBuffer buffer) {
System.out.println("[SERVER] Received message from client: " + client);
System.out.println("[SERVER] Buffer: " + buffer);
// First, write 'N' for SSL negotiation
ByteBuffer response = ByteBuffer.allocate(1);
response.put((byte) 'N');
response.flip();
Future<Integer> writeResult = client.write(response);
// Then, write AuthenticationOk
ByteBuffer authOk = ByteBuffer.allocate(9);
authOk.put((byte) 'R'); // 'R' for AuthenticationOk
authOk.putInt(8); // Length
authOk.putInt(0); // AuthenticationOk
authOk.flip();
writeResult = client.write(authOk);
// Then, write BackendKeyData
ByteBuffer backendKeyData = ByteBuffer.allocate(13); // 1 (type) + 4 (length) + 4 (process ID) + 4 (secret key)
backendKeyData.put((byte) 'K'); // Message type
backendKeyData.putInt(12); // Message length
backendKeyData.putInt(1234); // Process ID
backendKeyData.putInt(5678); // Secret key
backendKeyData.flip();
writeResult = client.write(backendKeyData);
// Then, write ReadyForQuery
ByteBuffer readyForQuery = ByteBuffer.allocate(6);
readyForQuery.put((byte) 'Z'); // 'Z' for ReadyForQuery
readyForQuery.putInt(5); // Length
readyForQuery.put((byte) 'I'); // Transaction status indicator, 'I' for idle
readyForQuery.flip();
writeResult = client.write(readyForQuery);
try {
writeResult.get();
} catch (Exception e) {
System.err.println("[SERVER] Failed to write to client: " + e);
}
}
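One caveat with the code above: AsynchronousSocketChannel.write throws a WritePendingException if it is called while an earlier write on the same channel has not completed, and here several writes are started back to back. With tiny buffers on loopback this tends to go unnoticed. One way to sidestep it, sketched below under the assumption that all responses are ready up front, is to copy them into a single buffer and issue one write. SingleWriteSketch, concat, and startupResponse are hypothetical names for illustration.

```java
import java.nio.ByteBuffer;

final class SingleWriteSketch {
    /** Copies the readable bytes of each part into one buffer, ready for a single write. */
    static ByteBuffer concat(ByteBuffer... parts) {
        int total = 0;
        for (ByteBuffer part : parts) {
            total += part.remaining();
        }
        ByteBuffer combined = ByteBuffer.allocate(total);
        for (ByteBuffer part : parts) {
            combined.put(part.duplicate()); // duplicate() leaves the caller's position untouched
        }
        combined.flip();
        return combined;
    }

    static ByteBuffer authenticationOk() {
        ByteBuffer b = ByteBuffer.allocate(9);
        b.put((byte) 'R').putInt(8).putInt(0);
        b.flip();
        return b;
    }

    static ByteBuffer backendKeyData(int processId, int secretKey) {
        ByteBuffer b = ByteBuffer.allocate(13);
        b.put((byte) 'K').putInt(12).putInt(processId).putInt(secretKey);
        b.flip();
        return b;
    }

    static ByteBuffer readyForQuery() {
        ByteBuffer b = ByteBuffer.allocate(6);
        b.put((byte) 'Z').putInt(5).put((byte) 'I');
        b.flip();
        return b;
    }

    /** AuthenticationOk + BackendKeyData + ReadyForQuery as one contiguous buffer. */
    static ByteBuffer startupResponse(int processId, int secretKey) {
        return concat(authenticationOk(), backendKeyData(processId, secretKey), readyForQuery());
    }

    public static void main(String[] args) {
        System.out.println(startupResponse(1234, 5678).remaining()); // 9 + 13 + 6 = 28
    }
}
```

The server would then call client.write(combined) once. A single write may still send fewer bytes than requested, so a more careful version would keep writing while combined.hasRemaining() is true.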
From this point on, it'll be useful to be able to visualize the messages that we're sending and receiving.
There are a few tools that you can use to do this:
- Wireshark/tshark
- pgs-debug (part of pgshark)
- pgmockproxy
I recommend using Wireshark's GUI, since it's the easiest to use. For this tutorial I'll be using pgs-debug, for two reasons:
- Wireshark doesn't work on WSL
- I want to be able to paste the ASCII output into the tutorial
NOTE: If you want a video tutorial on how to capture Postgres traffic with Wireshark, I have a short demo on my pgprotokt repo.
To capture the output with pgs-debug, I'll be using the command:
# Capture on loopback interface
$ sudo pgs-debug --interface lo
If we start up the server and connect with psql, we should see the following:
- psql client:
[user@MSI ~]$ psql -h localhost -p 5432 -U postgres
psql (15.0, server 0.0.0)
WARNING: psql major version 15, server major version 0.0.
Some psql features might not work.
Type "help" for help.
postgres=>
- pgs-debug output:
[user@MSI ~]$ sudo pgs-debug --interface lo
Packet: t=1673886702.924458, session=213070643347544
PGSQL: type=SSLRequest, F -> B
SSL REQUEST
Packet: t=1673886702.928187, session=213070643347544
PGSQL: type=SSLAnswer, B -> F
SSL BACKEND ANSWER: N
Packet: t=1673886702.928222, session=213070643347544
PGSQL: type=StartupMessage, F -> B
STARTUP MESSAGE version: 3
application_name=psql
database=postgres
client_encoding=UTF8
user=postgres
Packet: t=1673886702.928318, session=213070643347544
PGSQL: type=AuthenticationOk, B -> F
AUTHENTIFICATION REQUEST code=0 (SUCCESS)
Packet: t=1673886702.970239, session=213070643347544
PGSQL: type=BackendKeyData, B -> F
BACKEND KEY DATA pid=1234, key=5678
Packet: t=1673886702.970239, session=213070643347544
PGSQL: type=ReadyForQuery, B -> F
READY FOR QUERY type=<IDLE>
- Server output:
[SERVER] Listening on localhost:5432
[SERVER] Accepted connection from /127.0.0.1:47544
[SERVER] Received message from client: sun.nio.ch.UnixAsynchronousSocketChannelImpl[connected local=/127.0.0.1:5432 remote=/127.0.0.1:47544]
[SERVER] Buffer: java.nio.HeapByteBuffer[pos=0 lim=8 cap=1024]
[SERVER] Received message from client: sun.nio.ch.UnixAsynchronousSocketChannelImpl[connected local=/127.0.0.1:5432 remote=/127.0.0.1:47544]
[SERVER] Buffer: java.nio.HeapByteBuffer[pos=0 lim=84 cap=1024]
Step 1.3 - Differentiating between the SSL/Authentication request, and Command messages
We need to be able to differentiate between the SSL Negotiation message, the Startup Message (our authentication request), and the standard command messages.
This is so that we can route each message to the appropriate handler. Otherwise every message, including each query, would be answered with the same startup responses.
To do this, we can create some predicate helpers that test the raw bytes to determine whether a message is the SSL Request or the Startup Message.
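// Requires: import java.util.function.Predicate;
static Predicate<ByteBuffer> isSSLRequest = (ByteBuffer b) -> {
// Check the size first: an absolute get() past the buffer's limit throws IndexOutOfBoundsException
return b.remaining() >= 8
&& b.get(4) == 0x04 // Bytes 4-7 hold the SSLRequest code 80877103 (0x04D2162F)
&& b.get(5) == (byte) 0xd2
&& b.get(6) == 0x16
&& b.get(7) == 0x2f;
};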
static Predicate<ByteBuffer> isStartupMessage = (ByteBuffer b) -> {
return b.remaining() > 8
&& b.get(4) == 0x00
&& b.get(5) == 0x03 // Protocol version 3
&& b.get(6) == 0x00
&& b.get(7) == 0x00;
};
private static void onMessageReceived(AsynchronousSocketChannel client, ByteBuffer buffer) {
System.out.println("[SERVER] Received message from client: " + client);
System.out.println("[SERVER] Buffer: " + buffer);
Future<Integer> writeResult = null;
if (isSSLRequest.test(buffer)) {
System.out.println("[SERVER] SSL Request");
ByteBuffer sslResponse = ByteBuffer.allocate(1);
sslResponse.put((byte) 'N');
sslResponse.flip();
writeResult = client.write(sslResponse);
} else if (isStartupMessage.test(buffer)) {
System.out.println("[SERVER] Startup Message");
// Then, write AuthenticationOk
ByteBuffer authOk = ByteBuffer.allocate(9);
authOk.put((byte) 'R'); // 'R' for AuthenticationOk
authOk.putInt(8); // Length
authOk.putInt(0); // AuthenticationOk
authOk.flip();
writeResult = client.write(authOk);
// Then, write BackendKeyData
ByteBuffer backendKeyData = ByteBuffer.allocate(13); // 1 (type) + 4 (length) + 4 (process ID) + 4 (secret key)
backendKeyData.put((byte) 'K'); // Message type
backendKeyData.putInt(12); // Message length
backendKeyData.putInt(1234); // Process ID
backendKeyData.putInt(5678); // Secret key
backendKeyData.flip();
writeResult = client.write(backendKeyData);
// Then, write ReadyForQuery
ByteBuffer readyForQuery = ByteBuffer.allocate(6);
readyForQuery.put((byte) 'Z'); // 'Z' for ReadyForQuery
readyForQuery.putInt(5); // Length
readyForQuery.put((byte) 'I'); // Transaction status indicator, 'I' for idle
readyForQuery.flip();
writeResult = client.write(readyForQuery);
} else {
System.out.println("[SERVER] Unknown message");
}
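try {
// writeResult stays null when the message was not recognized and nothing was written
if (writeResult != null) {
System.out.println("[SERVER] Write result: " + writeResult.get());
}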
} catch (Exception e) {
System.err.println("[SERVER] Failed to write to client: " + e);
e.printStackTrace();
}
}
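The byte-by-byte comparisons above can also be written as a single int32 comparison, since ByteBuffer reads big-endian by default, which matches the protocol's network byte order. The sketch below shows that variant of the same heuristic. StartupClassifierSketch, Kind, and classify are hypothetical names, and like the predicates above it assumes the buffer holds exactly one complete message.

```java
import java.nio.ByteBuffer;

final class StartupClassifierSketch {
    enum Kind { SSL_REQUEST, STARTUP_MESSAGE, OTHER }

    static final int SSL_REQUEST_CODE = 80877103; // 0x04D2162F
    static final int PROTOCOL_3_0 = 196608;       // 0x00030000: major version 3, minor version 0

    /** Classifies a message by the int32 that follows the 4-byte length field. */
    static Kind classify(ByteBuffer buffer) {
        if (buffer.remaining() < 8) {
            return Kind.OTHER; // too short to carry a length plus a code/version
        }
        int code = buffer.getInt(buffer.position() + 4); // absolute read, position unchanged
        if (code == SSL_REQUEST_CODE) {
            return Kind.SSL_REQUEST;
        }
        if (code == PROTOCOL_3_0) {
            return Kind.STARTUP_MESSAGE;
        }
        return Kind.OTHER;
    }

    public static void main(String[] args) {
        ByteBuffer ssl = ByteBuffer.allocate(8);
        ssl.putInt(8).putInt(SSL_REQUEST_CODE);
        ssl.flip();
        System.out.println(classify(ssl)); // prints "SSL_REQUEST"
    }
}
```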
If we restart the server and reconnect with psql, we should now see:
[SERVER] Listening on localhost:5432
[SERVER] Accepted connection from /127.0.0.1:35090
[SERVER] Received message from client: sun.nio.ch.UnixAsynchronousSocketChannelImpl[connected local=/127.0.0.1:5432 remote=/127.0.0.1:35090]
[SERVER] Buffer: java.nio.HeapByteBuffer[pos=0 lim=8 cap=1024]
[SERVER] SSL Request
[SERVER] Write result: 1
[SERVER] Received message from client: sun.nio.ch.UnixAsynchronousSocketChannelImpl[connected local=/127.0.0.1:5432 remote=/127.0.0.1:35090]
[SERVER] Buffer: java.nio.HeapByteBuffer[pos=0 lim=84 cap=1024]
[SERVER] Startup Message
[SERVER] Write result: 6
Step 1.4 - Handle a query and return data rows
Now, the moment you've likely all been waiting for. Let's handle a query and return some data rows.
"Handle" in this case means that we'll just return a hard-coded set of rows, and not actually query a database (sorry to disappoint!). I did say "no libraries" =(
To do this, we'll need to handle the Query message, and then send a RowDescription message, followed by a DataRow message for each row, and finally a CommandComplete message.
- For our RowDescription message, we'll send two columns, with names "id" and "name"
- For our DataRow messages, we'll send two rows, with values (1, "one") and (2, "two")
To complete the cycle, we lastly need to follow up with a ReadyForQuery message.
This is the hairiest part of the protocol, so the annotated code below should hopefully help you understand what's going on:
} else {
System.out.println("[SERVER] Unknown message");
// Let's assume it's a query message, and just send a simple response
// First we send a RowDescription. We'll send two columns, with names "id" and "name"
ByteBuffer rowDescription = ByteBuffer.allocate(51);
rowDescription.put((byte) 'T'); // 'T' for RowDescription
rowDescription.putInt(50); // Length
rowDescription.putShort((short) 2); // Number of fields/columns
// For each field/column:
rowDescription.put("id".getBytes()).put((byte) 0); // Column name of column 1 (null-terminated)
rowDescription.putInt(0); // Object ID of column 1
rowDescription.putShort((short) 0); // Attribute number of column 1
rowDescription.putInt(23); // Data type OID of column 1
rowDescription.putShort((short) 4); // Data type size of column 1
rowDescription.putInt(-1); // Type modifier of column 1
rowDescription.putShort((short) 0); // Format code of column 1
rowDescription.put("name".getBytes()).put((byte) 0); // Column name of column 2 (null-terminated)
rowDescription.putInt(0); // Object ID of column 2
rowDescription.putShort((short) 0); // Attribute number of column 2
rowDescription.putInt(25); // Data type OID of column 2
rowDescription.putShort((short) -1); // Data type size of column 2
rowDescription.putInt(-1); // Type modifier of column 2
rowDescription.putShort((short) 0); // Format code of column 2
rowDescription.flip();
writeResult = client.write(rowDescription);
// Then we send a DataRow for each row. We'll send two rows, with values (1, "one") and (2, "two")
ByteBuffer dataRow1 = ByteBuffer.allocate(19);
dataRow1.put((byte) 'D'); // 'D' for DataRow
dataRow1.putInt(18); // Length (1-4)
dataRow1.putShort((short) 2); // Number of columns (5-6)
dataRow1.putInt(1); // Length of column 1 (7-10)
dataRow1.put((byte) '1'); // Value of column 1 (11-11)
dataRow1.putInt(3); // Length of column 2 (12-15)
dataRow1.put("one".getBytes()); // Value of column 2 (16-18)
dataRow1.flip();
writeResult = client.write(dataRow1);
ByteBuffer dataRow2 = ByteBuffer.allocate(19);
dataRow2.put((byte) 'D'); // 'D' for DataRow
dataRow2.putInt(18); // Length
dataRow2.putShort((short) 2); // Number of columns
dataRow2.putInt(1); // Length of column 1
dataRow2.put((byte) '2'); // Value of column 1
dataRow2.putInt(3); // Length of column 2
dataRow2.put("two".getBytes()); // Value of column 2
dataRow2.flip();
writeResult = client.write(dataRow2);
// We send a CommandComplete
ByteBuffer commandComplete = ByteBuffer.allocate(14);
commandComplete.put((byte) 'C'); // 'C' for CommandComplete
commandComplete.putInt(13); // Length
commandComplete.put("SELECT 2".getBytes()); // Command tag
commandComplete.put((byte) 0); // Null terminator
commandComplete.flip();
writeResult = client.write(commandComplete);
// Finally, write ReadyForQuery
ByteBuffer readyForQuery = ByteBuffer.allocate(6);
readyForQuery.put((byte) 'Z'); // 'Z' for ReadyForQuery
readyForQuery.putInt(5); // Length
readyForQuery.put((byte) 'I'); // Transaction status indicator, 'I' for idle
readyForQuery.flip();
writeResult = client.write(readyForQuery);
}
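The hard-coded sizes above (50 for the RowDescription, 18 for each DataRow) follow from simple arithmetic, and the sketch below spells it out. RowLengthSketch and its methods are hypothetical names. It assumes text-format values, and it also shows how a NULL column is sent (length -1 with no value bytes, per the protocol documentation), which the tutorial's code does not need.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

final class RowLengthSketch {
    /** RowDescription length: 4 (itself) + 2 (field count) + per field: name, '\0', 18 fixed bytes. */
    static int rowDescriptionLength(String... columnNames) {
        int length = 4 + 2;
        for (String name : columnNames) {
            // 18 = table OID (4) + attribute number (2) + type OID (4) + type size (2) + type modifier (4) + format code (2)
            length += name.getBytes(StandardCharsets.UTF_8).length + 1 + 18;
        }
        return length;
    }

    /** Encodes one DataRow in text format; a null entry is sent as length -1 with no value bytes. */
    static ByteBuffer dataRow(String... values) {
        byte[][] encoded = new byte[values.length][];
        int length = 4 + 2; // length field + column count
        for (int i = 0; i < values.length; i++) {
            length += 4; // every column carries an int32 length, even NULL
            if (values[i] != null) {
                encoded[i] = values[i].getBytes(StandardCharsets.UTF_8);
                length += encoded[i].length;
            }
        }
        ByteBuffer buffer = ByteBuffer.allocate(1 + length); // +1 for the 'D' type byte
        buffer.put((byte) 'D').putInt(length).putShort((short) values.length);
        for (byte[] value : encoded) {
            if (value == null) {
                buffer.putInt(-1);
            } else {
                buffer.putInt(value.length).put(value);
            }
        }
        buffer.flip();
        return buffer;
    }

    public static void main(String[] args) {
        System.out.println(rowDescriptionLength("id", "name")); // 6 + (3 + 18) + (5 + 18) = 50
        System.out.println(dataRow("1", "one").remaining());    // 1 + 18 = 19
    }
}
```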
If we run this, we should see the following output:
- psql client:
$ psql -h localhost -p 5432 -U postgres
psql (15.0, server 0.0.0)
WARNING: psql major version 15, server major version 0.0.
Some psql features might not work.
Type "help" for help.
postgres=> select 1;
id | name
----+------
1 | one
2 | two
(2 rows)
postgres=>
- pgs-debug output:
Packet: t=1673887134.207655, session=213070643346422
PGSQL: type=SSLRequest, F -> B
SSL REQUEST
Packet: t=1673887134.210943, session=213070643346422
PGSQL: type=SSLAnswer, B -> F
SSL BACKEND ANSWER: N
Packet: t=1673887134.210986, session=213070643346422
PGSQL: type=StartupMessage, F -> B
STARTUP MESSAGE version: 3
client_encoding=UTF8
user=postgres
database=postgres
application_name=psql
Packet: t=1673887134.211448, session=213070643346422
PGSQL: type=AuthenticationOk, B -> F
AUTHENTIFICATION REQUEST code=0 (SUCCESS)
Packet: t=1673887134.260990, session=213070643346422
PGSQL: type=BackendKeyData, B -> F
BACKEND KEY DATA pid=1234, key=5678
Packet: t=1673887134.260990, session=213070643346422
PGSQL: type=ReadyForQuery, B -> F
READY FOR QUERY type=<IDLE>
Packet: t=1673887136.771401, session=213070643346422
PGSQL: type=Query, F -> B
QUERY query=select 1;
Packet: t=1673887136.772593, session=213070643346422
PGSQL: type=RowDescription, B -> F
ROW DESCRIPTION: num_fields=2
---[Field 01]---
name='id'
type=23
type_len=4
type_mod=4294967295
relid=0
attnum=0
format=0
---[Field 02]---
name='name'
type=25
type_len=65535
type_mod=4294967295
relid=0
attnum=0
format=0
Packet: t=1673887136.772650, session=213070643346422
PGSQL: type=DataRow, B -> F
DATA ROW num_values=2
---[Value 0001]---
length=1
value='1'
---[Value 0002]---
length=3
value='one'
Packet: t=1673887136.772717, session=213070643346422
PGSQL: type=DataRow, B -> F
DATA ROW num_values=2
---[Value 0001]---
length=1
value='2'
---[Value 0002]---
length=3
value='two'
Packet: t=1673887136.772759, session=213070643346422
PGSQL: type=CommandComplete, B -> F
COMMAND COMPLETE command='SELECT 2'
Packet: t=1673887136.772792, session=213070643346422
PGSQL: type=ReadyForQuery, B -> F
READY FOR QUERY type=<IDLE>
- Server output:
[SERVER] Listening on localhost:5432
[SERVER] Accepted connection from /127.0.0.1:46422
[SERVER] Received message from client: sun.nio.ch.UnixAsynchronousSocketChannelImpl[connected local=/127.0.0.1:5432 remote=/127.0.0.1:46422]
[SERVER] Buffer: java.nio.HeapByteBuffer[pos=0 lim=8 cap=1024]
[SERVER] SSL Request
[SERVER] Write result: 1
[SERVER] Received message from client: sun.nio.ch.UnixAsynchronousSocketChannelImpl[connected local=/127.0.0.1:5432 remote=/127.0.0.1:46422]
[SERVER] Buffer: java.nio.HeapByteBuffer[pos=0 lim=84 cap=1024]
[SERVER] Startup Message
[SERVER] Write result: 6
[SERVER] Received message from client: sun.nio.ch.UnixAsynchronousSocketChannelImpl[connected local=/127.0.0.1:5432 remote=/127.0.0.1:46422]
[SERVER] Buffer: java.nio.HeapByteBuffer[pos=0 lim=15 cap=1024]
[SERVER] Unknown message
[SERVER] Write result: 6
Step 2: Refactoring
Now that we have a working server, we can start refactoring the code to make it more readable and maintainable.
This is where we'll pull in some of the more advanced features of recent Java versions, like Sealed Interfaces, Records, and Pattern Matching.
Our plan for refactoring is going to be to:
- Pull both the Client -> Server, and Server -> Client messages into their own sealed interfaces, where each message type is represented by a record.
- We'll then use encoders/decoders to convert between the raw bytes and the message types
  - And do this in a succinct way using pattern matching in a return switch + yield statement.
- Along the way, we'll see how we can use some of the API methods from the Foreign Function and Memory JEP to help us work with C-Strings from raw bytes.
In a Mermaid diagram, the hierarchy of sealed interface types for client and server messages looks like this:
Refactoring into Sealed Interfaces with Records and Message Decoders (Pt 1. Client Side)
We'll start with the client side, since it's a bit simpler.
Most of this code is pretty straightforward, though there's one interesting bit that I'll point out.
When we decode the StartupMessage, what we have are a series of null-terminated key/value string pairs.
The Foreign Function and Memory JEP provides a MemorySegment class that we can use to work with raw bytes. One of the convenience methods on MemorySegment is getUtf8String(long offset), which reads a null-terminated string starting at the given offset.
This method is incredibly useful: previously, reading a null-terminated C-String from a ByteBuffer required a lot of boilerplate. It returns a Java String directly, but it does not tell us how many bytes it consumed, so we have to advance the offset ourselves by the string's encoded length plus one for the terminator. (In the finalized API in JDK 22, this method is named getString.)
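For comparison, here is roughly what the ByteBuffer-only approach looks like. This is a sketch under the assumption that the buffer is positioned at the first parameter name (offset 8 of a Startup Message) and that the list ends with a lone zero byte. CStringSketch, readCString, and readParameters are hypothetical names, not part of the tutorial's code.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;

final class CStringSketch {
    /** Reads bytes up to (not including) the next 0 byte and advances past the terminator. */
    static String readCString(ByteBuffer buffer) {
        int start = buffer.position();
        int end = start;
        while (end < buffer.limit() && buffer.get(end) != 0) {
            end++; // scan with absolute gets so the position stays put
        }
        if (end == buffer.limit()) {
            throw new IllegalArgumentException("missing null terminator");
        }
        byte[] bytes = new byte[end - start];
        buffer.get(bytes); // relative bulk get: position is now at the terminator
        buffer.get();      // skip the terminator itself
        return new String(bytes, StandardCharsets.UTF_8);
    }

    /** Reads name/value pairs until the empty name (a lone 0 byte) that ends the list. */
    static Map<String, String> readParameters(ByteBuffer buffer) {
        Map<String, String> parameters = new LinkedHashMap<>();
        while (buffer.hasRemaining()) {
            String name = readCString(buffer);
            if (name.isEmpty()) {
                break;
            }
            parameters.put(name, readCString(buffer));
        }
        return parameters;
    }

    public static void main(String[] args) {
        byte[] body = "user\0postgres\0database\0postgres\0\0".getBytes(StandardCharsets.UTF_8);
        System.out.println(readParameters(ByteBuffer.wrap(body))); // prints "{user=postgres, database=postgres}"
    }
}
```

Because the position advances by the number of bytes actually read, this version needs no separate offset bookkeeping.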
With these records and decoders in place, our dispatch logic becomes a succinct pattern-matching switch:
private static void onMessageReceived(AsynchronousSocketChannel client, ByteBuffer buffer) {
PGClientMessage message = PGClientMessage.decode(buffer);
switch (message) {
case PGClientMessage.SSLNegotation ssl -> handleSSLRequest(ssl, client);
case PGClientMessage.StartupMessage startup -> handleStartupMessage(startup, client);
case PGClientMessage.QueryMessage query -> handleQueryMessage(query, client);
}
}
The full code for the client message types and decoders is below:
sealed interface PGClientMessage permits
PGClientMessage.SSLNegotation,
PGClientMessage.StartupMessage,
PGClientMessage.QueryMessage {
record SSLNegotation() implements PGClientMessage {
}
record StartupMessage(Map<String, String> parameters) implements PGClientMessage {
}
record QueryMessage(String query) implements PGClientMessage {
}
Predicate<ByteBuffer> isSSLRequest = (ByteBuffer b) -> {
return b.get(4) == 0x04
&& b.get(5) == (byte) 0xd2
&& b.get(6) == 0x16
&& b.get(7) == 0x2f;
};
Predicate<ByteBuffer> isStartupMessage = (ByteBuffer b) -> {
return b.remaining() > 8
&& b.get(4) == 0x00
&& b.get(5) == 0x03 // Protocol version 3
&& b.get(6) == 0x00
&& b.get(7) == 0x00;
};
static PGClientMessage decode(ByteBuffer buffer) {
if (isSSLRequest.test(buffer)) {
return new SSLNegotation();
} else if (isStartupMessage.test(buffer)) {
var segment = MemorySegment.ofBuffer(buffer);
var length = buffer.getInt(0);
var parameters = new HashMap<String, String>();
var offset = 8;
while (offset < length - 1) {
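var name = segment.getUtf8String(offset);
// Advance by the encoded byte count (not String.length(), which differs for non-ASCII text), plus the '\0'
offset += name.getBytes(java.nio.charset.StandardCharsets.UTF_8).length + 1;
var value = segment.getUtf8String(offset);
offset += value.getBytes(java.nio.charset.StandardCharsets.UTF_8).length + 1;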
parameters.put(name, value);
}
return new StartupMessage(parameters);
} else {
// Assume it's a query message
var query = MemorySegment.ofBuffer(buffer).getUtf8String(5);
return new QueryMessage(query);
}
}
}
// In "AsynchronousSocketServer"
private static void onMessageReceived(AsynchronousSocketChannel client, ByteBuffer buffer) {
System.out.println("[SERVER] Received message from client: " + client);
System.out.println("[SERVER] Buffer: " + buffer);
PGClientMessage message = PGClientMessage.decode(buffer);
switch (message) {
case PGClientMessage.SSLNegotation ssl -> handleSSLRequest(ssl, client);
case PGClientMessage.StartupMessage startup -> handleStartupMessage(startup, client);
case PGClientMessage.QueryMessage query -> handleQueryMessage(query, client);
}
}
// Where each of those methods contains the previous logic it held in the "if/else" statement, for example:
private static void handleSSLRequest(PGClientMessage.SSLNegotation sslRequest, AsynchronousSocketChannel client) {
System.out.println("[SERVER] SSL Request: " + sslRequest);
ByteBuffer sslResponse = ByteBuffer.allocate(1);
sslResponse.put((byte) 'N');
sslResponse.flip();
try {
client.write(sslResponse).get();
} catch (Exception e) {
e.printStackTrace();
}
}
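The decoder above treats anything that is not an SSL Request or Startup Message as a query. Every regular message starts with a type byte, though, so a stricter decoder could branch on it. The sketch below is one possible shape, not the tutorial's implementation: TypedDecodeSketch, Command, Terminate, and Unsupported are hypothetical names. 'X' is the Terminate message that a client sends when it disconnects, and the buffer is assumed to hold one complete message.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

final class TypedDecodeSketch {
    sealed interface Command permits Query, Terminate, Unsupported {}
    record Query(String sql) implements Command {}
    record Terminate() implements Command {}
    record Unsupported(char type) implements Command {}

    /** Decodes one regular (typed) message by looking at its leading type byte. */
    static Command decodeCommand(ByteBuffer buffer) {
        int start = buffer.position();
        char type = (char) (buffer.get(start) & 0xff);
        int length = buffer.getInt(start + 1); // counts itself and the body, not the type byte
        switch (type) {
            case 'Q': {
                if (length < 5) {
                    return new Unsupported(type); // no room for even an empty, terminated string
                }
                byte[] sql = new byte[length - 4 - 1]; // body minus the trailing '\0'
                ByteBuffer body = buffer.duplicate();  // independent position, shared content
                body.position(start + 5);
                body.get(sql);
                return new Query(new String(sql, StandardCharsets.UTF_8));
            }
            case 'X':
                return new Terminate();
            default:
                return new Unsupported(type);
        }
    }

    public static void main(String[] args) {
        ByteBuffer query = ByteBuffer.allocate(15);
        query.put((byte) 'Q').putInt(14).put("select 1;".getBytes(StandardCharsets.UTF_8)).put((byte) 0);
        query.flip();
        System.out.println(decodeCommand(query)); // prints "Query[sql=select 1;]"
    }
}
```

With Terminate as its own record, the server could close the channel when the client quits instead of trying to parse the message as SQL.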
Refactoring into Sealed Interfaces with Records and Message Decoders (Pt 2. Server Side)
The final section of our journey is just more of the same, so I'll keep it brief.
It does what it says on the tin: we refactor the server side to use sealed interfaces with records and message encoders.
sealed interface PGServerMessage permits
PGServerMessage.AuthenticationRequest, PGServerMessage.BackendKeyData,
PGServerMessage.ReadyForQuery, PGServerMessage.RowDescription,
PGServerMessage.DataRow, PGServerMessage.CommandComplete {
record AuthenticationRequest() implements PGServerMessage {
}
record BackendKeyData(int processId, int secretKey) implements PGServerMessage {
}
record ReadyForQuery() implements PGServerMessage {
}
record RowDescription(List<RowDescription.Field> fields) implements PGServerMessage {
record Field(
String name,
int tableObjectId,
int columnAttributeNumber,
int dataTypeObjectId,
int dataTypeSize,
int typeModifier,
int formatCode) {
int length() {
// 4 (int tableObjectId) + 2 (short columnAttributeNumber) + 4 (int dataTypeObjectId) + 2 (short dataTypeSize) + 4 (int typeModifier) + 2 (short formatCode)
// 4 + 2 + 4 + 2 + 4 + 2 = 18
// Add name length, plus 1 for null terminator '\0'
return 18 + name.getBytes(StandardCharsets.UTF_8).length + 1; // byte count, so non-ASCII names are sized correctly
}
}
}
record DataRow(List<ByteBuffer> values) implements PGServerMessage {
}
record CommandComplete(String commandTag) implements PGServerMessage {
}
static ByteBuffer encode(PGServerMessage message) {
return switch (message) {
case AuthenticationRequest auth -> {
var buffer = ByteBuffer.allocate(9);
buffer.put((byte) 'R'); // 'R' for AuthenticationRequest
buffer.putInt(8); // Length
buffer.putInt(0); // Authentication type, 0 for OK
buffer.flip();
yield buffer;
}
case BackendKeyData keyData -> {
var buffer = ByteBuffer.allocate(13);
buffer.put((byte) 'K'); // 'K' for BackendKeyData
buffer.putInt(12); // Length
buffer.putInt(keyData.processId()); // Process ID
buffer.putInt(keyData.secretKey()); // Secret key
buffer.flip();
yield buffer;
}
case ReadyForQuery ready -> {
var buffer = ByteBuffer.allocate(6);
buffer.put((byte) 'Z'); // 'Z' for ReadyForQuery
buffer.putInt(5); // Length
buffer.put((byte) 'I'); // Transaction status indicator, 'I' for idle
buffer.flip();
yield buffer;
}
case RowDescription rowDesc -> {
var fields = rowDesc.fields();
var length = 4 + 2 + fields.stream().mapToInt(RowDescription.Field::length).sum();
var buffer = ByteBuffer.allocate(length + 1)
.put((byte) 'T')
.putInt(length)
.putShort((short) fields.size());
fields.forEach(field -> buffer
.put(field.name().getBytes(StandardCharsets.UTF_8))
.put((byte) 0) // null-terminated
.putInt(field.tableObjectId())
.putShort((short) field.columnAttributeNumber())
.putInt(field.dataTypeObjectId())
.putShort((short) field.dataTypeSize())
.putInt(field.typeModifier())
.putShort((short) field.formatCode()));
buffer.flip();
yield buffer;
}
case DataRow dataRow -> {
var values = dataRow.values();
// For each value, we need to add 4 bytes for the length, plus the length of the value
var length = 4 + 2 + values.stream().map(it -> it.remaining() + 4).reduce(0, Integer::sum);
var buffer = ByteBuffer.allocate(length + 1) // +1 for msg type
.put((byte) 'D')
.putInt(length) // +4 for length
.putShort((short) values.size()); // +2 for number of columns
for (var value : values) {
buffer.putInt(value.remaining());
buffer.put(value);
}
buffer.flip();
yield buffer;
}
case CommandComplete cmdComplete -> {
var commandTag = cmdComplete.commandTag();
var length = 4 + commandTag.length() + 1;
yield ByteBuffer.allocate(length + 1) // +1 for msg type
.put((byte) 'C')
.putInt(length) // +4 for length
.put(commandTag.getBytes(StandardCharsets.UTF_8))
.put((byte) 0) // null terminator
.flip();
}
};
}
}
// In "AsynchronousSocketServer"
private static void handleStartupMessage(PGClientMessage.StartupMessage startup, AsynchronousSocketChannel client) {
System.out.println("[SERVER] Startup Message: " + startup);
Future<Integer> writeResult;
// Then, write AuthenticationOk
PGServerMessage.AuthenticationRequest authRequest = new PGServerMessage.AuthenticationRequest();
writeResult = client.write(PGServerMessage.encode(authRequest));
// Then, write BackendKeyData
PGServerMessage.BackendKeyData backendKeyData = new PGServerMessage.BackendKeyData(1234, 5678);
writeResult = client.write(PGServerMessage.encode(backendKeyData));
// Then, write ReadyForQuery
PGServerMessage.ReadyForQuery readyForQuery = new PGServerMessage.ReadyForQuery();
writeResult = client.write(PGServerMessage.encode(readyForQuery));
try {
writeResult.get();
} catch (Exception e) {
e.printStackTrace();
}
}
private static void handleQueryMessage(PGClientMessage.QueryMessage query, AsynchronousSocketChannel client) {
System.out.println("[SERVER] Query Message: " + query);
Future<Integer> writeResult;
// Let's assume it's a query message, and just send a simple response
// First we send a RowDescription. We'll send two columns, with names "id" and "name"
PGServerMessage.RowDescription rowDescription = new PGServerMessage.RowDescription(List.of(
new PGServerMessage.RowDescription.Field("id", 0, 0, 23, 4, -1, 0),
new PGServerMessage.RowDescription.Field("name", 0, 0, 25, -1, -1, 0)
));
writeResult = client.write(PGServerMessage.encode(rowDescription));
// Then we send a DataRow for each row. We'll send two rows, with values (1, "one") and (2, "two")
PGServerMessage.DataRow dataRow1 = new PGServerMessage.DataRow(List.of(
ByteBuffer.wrap("1".getBytes(StandardCharsets.UTF_8)),
ByteBuffer.wrap("one".getBytes(StandardCharsets.UTF_8))
));
writeResult = client.write(PGServerMessage.encode(dataRow1));
PGServerMessage.DataRow dataRow2 = new PGServerMessage.DataRow(List.of(
ByteBuffer.wrap("2".getBytes(StandardCharsets.UTF_8)),
ByteBuffer.wrap("two".getBytes(StandardCharsets.UTF_8))
));
writeResult = client.write(PGServerMessage.encode(dataRow2));
// We send a CommandComplete
PGServerMessage.CommandComplete commandComplete = new PGServerMessage.CommandComplete("SELECT 2");
writeResult = client.write(PGServerMessage.encode(commandComplete));
// Finally, write ReadyForQuery
PGServerMessage.ReadyForQuery readyForQuery = new PGServerMessage.ReadyForQuery();
writeResult = client.write(PGServerMessage.encode(readyForQuery));
try {
writeResult.get();
} catch (Exception e) {
e.printStackTrace();
}
}
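Since every encoder must get its length field exactly right, a small self-check is useful while refactoring. The sketch below walks a buffer of back-to-back server messages and returns their type bytes, failing if a length field does not line up with the bytes present. FrameCheckSketch and messageTypes are hypothetical names. It covers only typed messages, so the single 'N' byte answer to the SSL Request is out of scope.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

final class FrameCheckSketch {
    /** Walks back-to-back typed messages and returns their type bytes in order, e.g. "TDDCZ". */
    static String messageTypes(ByteBuffer stream) {
        StringBuilder types = new StringBuilder();
        int offset = stream.position();
        while (offset < stream.limit()) {
            if (stream.limit() - offset < 5) {
                throw new IllegalArgumentException("truncated header at offset " + offset);
            }
            int length = stream.getInt(offset + 1); // counts itself and the body, not the type byte
            if (length < 4 || offset + 1 + length > stream.limit()) {
                throw new IllegalArgumentException("bad length " + length + " at offset " + offset);
            }
            types.append((char) (stream.get(offset) & 0xff));
            offset += 1 + length; // jump to the next message's type byte
        }
        return types.toString();
    }

    public static void main(String[] args) {
        // CommandComplete("SELECT 2") followed by ReadyForQuery, laid out as in the encoders above
        ByteBuffer stream = ByteBuffer.allocate(20);
        stream.put((byte) 'C').putInt(13).put("SELECT 2".getBytes(StandardCharsets.UTF_8)).put((byte) 0);
        stream.put((byte) 'Z').putInt(5).put((byte) 'I');
        stream.flip();
        System.out.println(messageTypes(stream)); // prints "CZ"
    }
}
```

Feeding it the concatenated output of PGServerMessage.encode for a full query response should yield "TDDCZ" if all five length fields are consistent.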