/**
 * Values interpolated into the generated Marqtune code samples.
 *
 * The optional ids/checkpoint are replaced by `<...>` placeholder text in the
 * snippets that use them when they are not provided.
 */
type SnippetPayload = {
  /** Base URL passed to the Python `Client(url=...)` constructor. */
  marqtuneClientUrl: string;
  /** API key passed to the Python `Client(api_key=...)` constructor. */
  apiKey: string;

  // --- resource identifiers (all optional) ---
  /** Not referenced by any snippet builder visible in this file; kept for callers. */
  inputDataId?: string;
  /** Dataset id used by the train-model snippet. */
  trainingDatasetId?: string;
  /** Dataset id used by the evaluate-model snippet. */
  evaluationDatasetId?: string;
  /** Model id used by the evaluate-model snippet. */
  modelId?: string;
  /** Checkpoint name used by the evaluate-model snippet. */
  checkpoint?: string;

  // --- rendering flags ---
  /** When true, the evaluate-model snippet appends a release step. */
  showReleaseCheckpointSnippet?: boolean;
};

/** Renders one code sample from the supplied payload. */
type SnippetFn = (snippetPayload: SnippetPayload) => string;

/** The full catalogue of Marqtune code samples. */
type Snippets = {
  /** Install instructions shown before any of the Python samples. */
  setup: string;
  /** Python sample builders, keyed by workflow step. */
  python: Record<"create_dataset" | "train_model" | "evaluate_model", SnippetFn>;
};

/**
 * Install instructions for the Marqtune Python client.
 *
 * Note: despite the `Python` prefix this is a shell (`pip`) command, not Python
 * code. The text ends with a trailing newline.
 */
const PythonMarqtuneSetup = [
  "# Install the Python marqtune client, if not already installed.",
  "pip install marqtune",
  "",
].join("\n");

/**
 * Builds the Python snippet that downloads the sample CSV files and registers
 * them as a Marqtune training dataset and an evaluation dataset.
 *
 * The URL and API key are emitted with `JSON.stringify`, which always produces a
 * valid Python string literal (quotes/backslashes are escaped), so unusual
 * values cannot break the generated code.
 *
 * @param payload - only `marqtuneClientUrl` and `apiKey` are used.
 * @returns the Python source text.
 */
const PythonCreateDataset: SnippetFn = ({ marqtuneClientUrl, apiKey }: SnippetPayload) =>
  `from marqtune.client import Client
from marqtune.enums import DatasetType
from urllib.request import urlopen
import gzip


marqtune_client = Client(
    url=${JSON.stringify(marqtuneClientUrl)},
    api_key=${JSON.stringify(apiKey)}
)

# The code below downloads some sample data to be used for training and evaluation; you can replace it with your own data.
base_path = "https://marqo-gcl-public.s3.us-west-2.amazonaws.com/marqtune_test/datasets/v1"
training_data = "gs_100k_training.csv"
eval_data = "gs_25k_eval.csv"
for file_name in (training_data, eval_data):
    # Write the decompressed bytes unchanged: binary mode avoids platform-dependent
    # text encodings and newline translation, and "with" closes both handles.
    with urlopen(f"{base_path}/{file_name}.gz") as response, open(file_name, "wb") as f:
        f.write(gzip.decompress(response.read()))

training_data_path = training_data
eval_data_path = eval_data
data_schema = {
    "query": "text", "title": "text", "image": "image_pointer", "score": "score"
}

training_dataset = marqtune_client.create_dataset(
    dataset_name="my_first_training_dataset",
    file_path=training_data_path,
    dataset_type=DatasetType.TRAINING,
    data_schema=data_schema,
    query_columns=["query"],
    result_columns=["title", "image"],
    wait_for_completion=True
)

eval_dataset = marqtune_client.create_dataset(
    dataset_name="my_first_eval_dataset",
    file_path=eval_data_path,
    dataset_type=DatasetType.EVALUATION,
    data_schema=data_schema,
    query_columns=["query"],
    result_columns=["title", "image"],
    wait_for_completion=True
)

print(f"Training dataset: {training_dataset.dataset_id}")
print(f"Evaluation dataset: {eval_dataset.dataset_id}")`;

/**
 * Builds the Python snippet that fine-tunes a base model on a training dataset
 * and then appends the shared release-checkpoint step for the new model.
 *
 * Interpolated values are emitted with `JSON.stringify`, which always produces a
 * valid Python string literal (quotes/backslashes are escaped).
 *
 * @param payload - uses `marqtuneClientUrl`, `apiKey` and, if present,
 *   `trainingDatasetId` (otherwise a `<training dataset id>` placeholder).
 * @returns the Python source text.
 */
const PythonTrainModel: SnippetFn = ({ marqtuneClientUrl, apiKey, trainingDatasetId }: SnippetPayload) =>
  `from marqtune.client import Client
from marqtune.enums import InstanceType


marqtune_client = Client(
    url=${JSON.stringify(marqtuneClientUrl)},
    api_key=${JSON.stringify(apiKey)}
)

training_dataset_id = ${JSON.stringify(trainingDatasetId ?? "<training dataset id>")}
train_task_params = {
    "epochs": 5,
    "rightKeys": ["image", "title"],
    "rightWeights": [0.9, 0.1],
    "leftKeys": ["query"],
    "leftWeights": [1],
    "weightKey": "score"
}

# Change this to a model/checkpoint of your choice:
base_model = "Marqo/ViT-B-32.laion400m_e31"

model = marqtune_client.train_model(
    dataset_id=training_dataset_id,
    model_name="my_first_trained_model",
    base_model=base_model,
    instance_type=InstanceType.BASIC,
    hyperparameters=train_task_params,
    wait_for_completion=True
)

print(model.model_id)

${releaseCheckpoint(`model.model_id`)}`;

/**
 * Builds the Python snippet that releases a model checkpoint so it can be used
 * in a Marqo index.
 *
 * Both arguments are inserted verbatim as Python expressions (not quoted), so
 * callers pass variable names such as `model_id` or pre-quoted literals.
 *
 * @param modelId - Python expression evaluating to the model id.
 * @param checkpoint - Python expression naming the checkpoint to release;
 *   defaults to the literal `"epoch_1"`.
 * @returns the Python source text (two lines, no trailing newline).
 */
const releaseCheckpoint = (modelId: string, checkpoint = `"epoch_1"`) =>
  `# Release the model to make it available for use in a Marqo Index
marqtune_client.model(${modelId}).release(${checkpoint})`;

/**
 * Builds the Python snippet that evaluates a model checkpoint against an
 * evaluation dataset.
 *
 * Missing `modelId`, `checkpoint` or `evaluationDatasetId` fall back to `<...>`
 * placeholders. When `showReleaseCheckpointSnippet` is true, a release step for
 * the evaluated checkpoint is appended; otherwise the snippet ends at the final
 * `print`.
 *
 * @returns the Python source text.
 */
const PythonEvaluateModel: SnippetFn = ({
  marqtuneClientUrl,
  apiKey,
  modelId,
  evaluationDatasetId,
  checkpoint,
  showReleaseCheckpointSnippet,
}: SnippetPayload) => {
  // Release the checkpoint that was just evaluated (the Python `checkpoint`
  // variable) rather than a hard-coded one, and only emit the separating blank
  // line when there is a release step to separate.
  const releaseSection = showReleaseCheckpointSnippet
    ? `\n\n${releaseCheckpoint("model_id", "checkpoint")}`
    : "";

  // JSON.stringify always yields a valid Python string literal, even when a
  // value contains quotes or backslashes.
  return `from marqtune.client import Client


marqtune_client = Client(
    url=${JSON.stringify(marqtuneClientUrl)},
    api_key=${JSON.stringify(apiKey)}
)

model_id = ${JSON.stringify(modelId ?? "<model id>")}
checkpoint = ${JSON.stringify(checkpoint ?? "<checkpoint>")}
evaluation_dataset_id = ${JSON.stringify(evaluationDatasetId ?? "<evaluation dataset id>")}
evaluate_task_params = {
    "leftKeys": ["query"],
    "leftWeights": [1],
    "rightKeys": ["image", "title"],
    "rightWeights": [0.9, 0.1],
    "weightKey": "score",
}

evaluation = marqtune_client.evaluate(
    model=f"{model_id}/{checkpoint}",
    dataset_id=evaluation_dataset_id,
    hyperparameters=evaluate_task_params,
    wait_for_completion=True
)

print(evaluation.evaluation_id)${releaseSection}`;
};

/** Python sample builders, keyed by workflow step. */
const pythonSnippets: Snippets["python"] = {
  create_dataset: PythonCreateDataset,
  train_model: PythonTrainModel,
  evaluate_model: PythonEvaluateModel,
};

/** The complete Marqtune snippet catalogue: install instructions plus Python samples. */
const MarqtuneSnippets: Snippets = {
  setup: PythonMarqtuneSetup,
  python: pythonSnippets,
};

export default MarqtuneSnippets;
