GPT4MultiModal¶
The GPT4MultiModal class is a subclass of torch.nn.Module. It serves as a model for handling both image and text input, integrating a ViTransformerWrapper for image encoding and a Transformer for text decoding.
The primary aim of this class is to encode an image and use the resulting embeddings as context for generating a text sequence, hence the name GPT4MultiModal. Typical usage is to pass an image to the encoder and a sequence of tokens (corresponding to a language prompt) to the decoder. The class outputs a sequence of tokens; the length of that sequence depends on the transformer architecture used.
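The core mechanism is that the image patch embeddings produced by the encoder become the "memory" that the text decoder cross-attends to while generating tokens. The snippet below is a minimal, self-contained sketch of that pattern using plain PyTorch building blocks (nn.TransformerDecoder); it illustrates the idea only and is not the actual internals of GPT4MultiModal, and every shape and module choice in it is an assumption made for the example.

```python
import torch
import torch.nn as nn

# Toy illustration of "image embeddings as cross-attention context".
# Shapes and module choices are assumptions for this sketch, not the
# real GPT4MultiModal internals.
dim, num_tokens, num_patches, seq_len = 512, 20000, 64, 50

token_emb = nn.Embedding(num_tokens, dim)
decoder_layer = nn.TransformerDecoderLayer(d_model=dim, nhead=8, batch_first=True)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=2)
to_logits = nn.Linear(dim, num_tokens)

# Pretend the image encoder already produced one embedding per patch.
image_context = torch.randn(1, num_patches, dim)    # (batch, patches, dim)
text = torch.randint(0, num_tokens, (1, seq_len))   # (batch, seq_len) token ids

# The decoder cross-attends to the image embeddings ("memory") while processing text.
hidden = decoder(tgt=token_emb(text), memory=image_context)
logits = to_logits(hidden)                           # (batch, seq_len, num_tokens)
```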
Class Constructor¶
This class accepts the following parameters:
| Parameter | Keyword Argument | Type | Default Value | Description |
|---|---|---|---|---|
| image_size | image_size | int | 256 | Input image size |
| patch_size | patch_size | int | 32 | Size of each image patch |
| encoder_dim | encoder_dim | int | 512 | Dimension of the encoder |
| encoder_depth | encoder_depth | int | 6 | Depth (number of layers) of the encoder |
| encoder_heads | encoder_heads | int | 8 | Number of attention heads in the encoder |
| num_tokens | num_tokens | int | 20000 | Number of unique tokens in the vocabulary |
| max_seq_len | max_seq_len | int | 1024 | Maximum sequence length for text |
| decoder_dim | decoder_dim | int | 512 | Dimension of the decoder |
| decoder_depth | decoder_depth | int | 6 | Depth (number of layers) of the decoder |
| decoder_heads | decoder_heads | int | 8 | Number of attention heads in the decoder |
| alibi_num_heads | alibi_num_heads | int | 4 | Number of attention heads that use ALiBi positional bias |
| use_abs_pos_emb | use_abs_pos_emb | bool | False | If True, uses absolute positional embeddings for the input |
| cross_attend | cross_attend | bool | True | If True, enables cross-attention in the decoder |
| alibi_pos_bias | alibi_pos_bias | bool | True | If True, enables ALiBi positional bias |
| rotary_xpos | rotary_xpos | bool | True | If True, enables rotary xpos positional embeddings |
| attn_flash | attn_flash | bool | True | If True, enables Flash attention |
| qk_norm | qk_norm | bool | True | If True, enables query-key normalization |
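Because every argument has a default, the constructor can be called with no arguments at all, or with just a handful of keyword overrides. The following hedged sketch shows a smaller configuration for quick experiments; the specific values are illustrative, not recommended settings.

```python
from zeta.models import GPT4MultiModal

# All constructor arguments are keyword arguments with defaults; override only
# what you need. The values below are illustrative, not tuned recommendations.
small_model = GPT4MultiModal(
    image_size=128,   # smaller input images
    patch_size=16,    # finer patches
    encoder_depth=4,  # shallower encoder
    decoder_depth=4,  # shallower decoder
)
```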
Methods¶
The following methods are available in this class.
forward(self, img, text) -> Union[Tensor, str]¶
The forward method performs the forward pass of the GPT4MultiModal model. It accepts an image tensor and a sequence of tokens, and returns a sequence of tokens (or an error message string if an exception is encountered).
Parameters:
| Parameter | Keyword Argument | Type | Default Value | Description |
|---|---|---|---|---|
| img | img | Tensor | - | The input image tensor |
| text | text | Tensor | - | The input sequence of token ids |
Returns:
| Type | Description |
|---|---|
| Union[Tensor, str] | Output sequence of tokens, or an error message string if an exception is encountered |
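Since the documented return type is Union[Tensor, str], callers may want to distinguish a normal tensor result from the error-message string before using it. The helper below is a hypothetical wrapper (run_model_safely is not part of the library) sketching one way to do that, assuming the string case only occurs on error as described above.

```python
from torch import Tensor

def run_model_safely(model, img: Tensor, text: Tensor) -> Tensor:
    """Call the model and distinguish a Tensor result from an error string."""
    output = model(img, text)
    if isinstance(output, str):
        # The forward method returns an error message string when it hits an exception.
        raise RuntimeError(f"GPT4MultiModal forward pass failed: {output}")
    return output
```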
Example of Use¶
Consider an image tensor img of size (1, 3, 256, 256) (batch, channels, height, width, as PyTorch expects) and a text tensor text of size (1, 50). Here is an example of how to use GPT4MultiModal:
import torch
from zeta.models import GPT4MultiModal
# Initialize the model
model = GPT4MultiModal(
image_size=256,
patch_size=32,
encoder_dim=512,
encoder_depth=6,
encoder_heads=8,
num_tokens=20000,
max_seq_len=1024,
decoder_dim=512,
decoder_depth=6,
decoder_heads=8,
alibi_num_heads=4,
use_abs_pos_emb=False,
cross_attend=True,
alibi_pos_bias=True,
rotary_xpos=True,
attn_flash=True,
qk_norm=True,
)
# Create an example image tensor (batch, channels, height, width) and
# a text tensor of token ids (batch, sequence_length)
img = torch.randn(1, 3, 256, 256)
text = torch.randint(0, 20000, (1, 50))
# Run the model
output = model(img, text)
This encodes img with the ViTransformerWrapper and then uses the encoded embeddings as context for the Transformer to generate a sequence of tokens from text. The resulting sequence of tokens is returned as output.
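If the returned output is a logits tensor of shape (batch, sequence_length, num_tokens), as is typical for transformer language-model heads, the most likely token at each position can be read off with an argmax. That shape is an assumption here, so verify it against the actual model output before relying on this sketch.

```python
import torch

# Assumes `output` from the example above is a logits tensor of shape
# (batch, seq_len, num_tokens); verify against the real output first.
predicted_ids = output.argmax(dim=-1)  # (batch, seq_len) token ids
print(predicted_ids.shape)
print(predicted_ids[0, :10])           # first ten predicted token ids
```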