DataSource

A DataSource object in Relta has the following roles and responsibilities:

  1. It manages the connection to the external datasource
  2. It maintains a SemanticLayer object
  3. It creates a local transient database from the external datasource that matches its semantic layer

We support connecting to PostgreSQL, DuckDB, MySQL, CSV, and Parquet datasources.

During the setup of Relta, you will configure your DataSources' external connections (1) and their semantic layers (2). When deploying Relta, you will read in those persisted DataSources and create their local databases, usually limiting each one to the data for the current user you are serving (3).

The "local transient database" is intended to provide two benefits:

  1. Each user receives a sandboxed database containing only the data they are allowed to access. The LLM cannot mix users' data when running a query, because other users' data simply isn't there.
  2. No LLM-produced SQL is run against your production database -- this prevents long-running queries from locking up your database.

More on the local database

Under the hood, we use DuckDB to create the local database. It runs fully in process, allowing for fast query times.
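
If you want to inspect a local database outside of a running Relta process, you can open the transient DuckDB file directly. This is a minimal sketch: the file path is an assumption, as the actual location is determined by your Relta configuration's transient data directory and the DataSource's name.

import duckdb

# Open the transient database read-only so nothing is modified.
# The path below is an assumption -- check your Relta configuration for the
# transient data directory and use your DataSource's name.
con = duckdb.connect(".relta/data/week_invoice.duckdb", read_only=True)

# List the tables that were materialized for this DataSource.
print(con.sql("SELECT table_name FROM duckdb_tables()").fetchall())
con.close()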

Usage

Setting up a DataSource can be done directly with the library or through the CLI. Using the CLI for this is the recommended approach and is covered in the Getting Started guide. We will show how to do this with the library, as the same functions are also used when deploying Relta.

You should create a DataSource through the create_datasource method on the Relta Client. This persists the DataSource to your repository. There are also get_datasource and get_or_create_datasource methods for accessing existing DataSource objects.

import relta
rc = relta.Client()
source = rc.create_datasource("data/invoices.csv", name="week_invoice")
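
For idempotent setup scripts, get_or_create_datasource avoids failing when the DataSource already exists. A minimal sketch, assuming it accepts the same arguments as create_datasource:

import relta
rc = relta.Client()
# Returns the existing "week_invoice" DataSource if it was already created,
# otherwise creates it from the CSV.
source = rc.get_or_create_datasource("data/invoices.csv", name="week_invoice")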

Don't duplicate DataSources in the same Python process

Getting multiple copies of the same DataSource (e.g. by creating one and getting it right after) and assigning them to different variables will create two separate objects in memory. As they have separate semantic layer objects, this can cause unexpected behavior.

### DON'T DO THIS ###
source_original = rc.create_datasource("data/invoices.csv", name="week_invoice")
source_copy = rc.get_datasource("week_invoice")

### THIS IS OK ###
source = rc.create_datasource("data/invoices.csv", name="week_invoice")
source = rc.get_datasource("week_invoice")

Accessing a DataSource's semantic layer can be done through the semantic_layer property. This is usually only done to set up the semantic layer.

import relta
rc = relta.Client()
source = rc.get_datasource(name="week_invoice")
source.semantic_layer.propose("How many times did Alice make a payment?")

When you are deploying Relta, you should already have a semantic layer set up in your repository, so accessing the semantic layer is not necessary.

To pull in data from the external datasource once a semantic layer is set up, you can use the deploy method, which will create the local transient database.

import relta
rc = relta.Client()
source = rc.get_datasource(name="week_invoice")
print(source.semantic_layer.metrics) # this is automatically loaded in when you get the datasource
source.deploy()

We refer to this as deploying the semantic layer on the database. Once this is complete, you can create Chats on this DataSource and begin running queries.

import relta
rc = relta.Client()
source = rc.get_datasource(name="week_invoice")
source.deploy()

chat = rc.create_chat(source)
resp = chat.prompt("How many times did Omar pay")
print(resp.text)

API Reference

Bases: SQLModel

Source code in src/relta/datasource/datasource.py
class DataSource(SQLModel, table=True):
    conn: ClassVar[
        duckdb.DuckDBPyConnection
    ]  # populated when a Client is initialized. ClassVar is not serialized by SQLModel
    # this connection is for the external data source, not for accessing the Relta replica.

    id: Optional[int] = Field(default=source_id_seq.next_value(), primary_key=True)
    type: DataSourceType
    connection_uri: str = Field(sa_column=Column(String))
    name: str
    last_hydrated: Optional[datetime] = Field(default=None)

    # These private fields are populated on creation or load.
    # They aren't validated by SQLModel (even if you pass them in on __init__ as private attributes, no matter the ConfigDict(extra="allow"))
    _config: Configuration
    _semantic_layer: SemanticLayer

    chats: list["Chat"] = Relationship(back_populates="datasource")  # type: ignore # noqa: F821 # avoid circular import w quotes

    _encryption_key: ClassVar[bytes]  # Class-level encryption key

    @property
    def semantic_layer(self):
        """The semantic layer of the `DataSource`. Populated by `Client`"""
        return self._semantic_layer

    @property
    def decrypted_uri(self) -> str:
        """Access the decrypted connection URI"""
        return self._decrypt_uri(self.connection_uri, self._fernet)

    @classmethod
    def set_encryption_key(cls, key: bytes):
        """Set the encryption key for all DataSource instances"""
        cls._encryption_key = key
        cls._fernet = Fernet(key)

    def __init__(self, **data):
        if "connection_uri" in data:
            # Encrypt the connection URI before storing
            print(f"key is {self._encryption_key}")
            data["connection_uri"] = self._encrypt_uri(
                data["connection_uri"], self._fernet
            )
        super().__init__(**data)

    def _attach(self):
        """Attaches the database to Relta. This is called at client initialization.

        Raises:
            NotImplementedError: If the data source type is not supported
        """
        if self.type == DataSourceType.POSTGRES:
            self.conn.sql("INSTALL POSTGRES")
            self._connect_postgres()
        elif self.type == DataSourceType.MYSQL:
            self._connect_mysql()
        elif self.type == DataSourceType.CSV:
            self._connect_csv()
        elif self.type == DataSourceType.DUCKDB:
            self._connect_duckdb()
        else:
            raise NotImplementedError

    def _connect_duckdb(self):
        """Private method used to connect to a DuckDB datasource"""
        try:
            self.conn.sql(f"ATTACH '{self.decrypted_uri}' AS {self.name}")

            self.conn.sql(
                f"ATTACH '{self._config.transient_data_dir_path}/{self.name}.duckdb' AS transient_{self.name}"
            )

        except duckdb.BinderException as e:
            logfire.error(e)
            logfire.error(
                f"Unable to attach to DuckDB database with connection URI {self.decrypted_uri}. This could be because the connection_uri is wrong or database with same name already exists in Relta."
            )
            raise duckdb.BinderException

    def _connect_postgres(self):
        """Private method used to connect to a Postgres data source"""
        parsed_uri = urlparse(self.decrypted_uri)
        try:
            self.conn.sql(
                f"ATTACH 'dbname={parsed_uri.path.lstrip('/')} {f'host={parsed_uri.hostname}' if parsed_uri.hostname else ''} {f'user={parsed_uri.username}' if parsed_uri.username else ''} {f'password={parsed_uri.password}' if parsed_uri.password else ''} {f'port={parsed_uri.port}' if parsed_uri.port else ''}' AS {self.name} (TYPE POSTGRES, READ_ONLY)"
            )

            self.conn.sql(
                f"ATTACH '{self._config.transient_data_dir_path}/{self.name}.duckdb' AS transient_{self.name}"
            )
        except duckdb.BinderException as e:
            logfire.error(e)
            logfire.error(
                f"Unable to attach to Postgres database with name {self.name}. This could be because the connection_uri is wrong or database with same name already exists in Relta."
            )
            raise duckdb.BinderException

    def _connect_mysql(self):
        raise NotImplementedError

    def _connect_csv(self, dtypes: Optional[dict[str, str]] = None):
        """Private method used to connect to a CSV data source and create a table in Relta

        Args:
            dtypes (dict, optional): Map of column names to datatypes, overrides a column's auto-detected type.
                The datatypes should be [DuckDB datatypes](https://duckdb.org/docs/sql/data_types/overview).

        Raises:
            duckdb.CatalogException: raised if a table with the same name already exists in Relta
        """
        try:
            self.conn.sql(
                f"ATTACH '{self._config.transient_data_dir_path}/{self.name}.duckdb' AS transient_{self.name}"
            )
            self._load_csv(dtypes)
        except duckdb.CatalogException as e:
            logfire.error(e)
            logfire.error(
                f"Table with name {self.name} already exists in Relta. Please choose a different name or consider refreshing data using rehydrate()"
            )
            raise duckdb.CatalogException

    def _connect_parquet(self):
        raise NotImplementedError

    def connect(self, dtypes: Optional[dict[str, str]] = None):
        """Creates a connection to the given data source. This allows Relta to query the underlying data (e.g. read schema) but does not copy data into Relta.

        Args:
            dtypes (dict, optional): Map of column names to datatypes, overrides a column's auto-detected type.
                *Only for CSVs*. The datatypes should be [DuckDB datatypes](https://duckdb.org/docs/sql/data_types/overview).

        Raises:
            duckdb.CatalogException: If a table with the same name is already connected to Relta
        """
        if self.type == DataSourceType.CSV:
            self._connect_csv(dtypes)
        elif self.type == DataSourceType.PARQUET:
            self._connect_parquet()
        elif self.type == DataSourceType.POSTGRES:
            self._connect_postgres()
        elif self.type == DataSourceType.MYSQL:
            self._connect_mysql()
        elif self.type == DataSourceType.DUCKDB:
            self._connect_duckdb()

    def _disconnect(self):
        """Disconnects the data source from Relta

        Raises:
            duckdb.CatalogException: If the underlying data source does not exist in Relta
        """
        self.conn.sql(
            "ATTACH IF NOT EXISTS ':memory:' AS memory_db"
        )  # this is to guard in case the DB we are deleting is the default database
        try:
            if (
                self.type == DataSourceType.POSTGRES
                or self.type == DataSourceType.DUCKDB
            ):
                self.conn.sql("USE memory_db")
                self.conn.sql(f"DETACH {self.name}")
                os.remove(f"{self._config.transient_data_dir_path}/{self.name}.duckdb")
            elif self.type == DataSourceType.CSV:
                self.conn.sql("USE memory_db")
                self.conn.sql(f"DETACH  transient_{self.name}")
                os.remove(f"{self._config.transient_data_dir_path}/{self.name}.duckdb")

        except duckdb.CatalogException as e:
            logfire.error(e)
            logfire.error(
                f"Table with name {self.name} does not exist in Relta. Please check the name and try again"
            )
            raise duckdb.CatalogException

    def load(self):
        """Updates the data in Relta from the underlying data source"""
        if self.type == DataSourceType.POSTGRES:
            self._load_postgres()
        elif self.type == DataSourceType.DUCKDB:
            self._load_postgres()
        elif self.type == DataSourceType.CSV:
            self._load_csv()

    def _load_csv(self, dtypes: Optional[dict[str, str]] = None):
        self.last_hydrated = datetime.now()
        self.conn.sql(f"USE transient_{self.name}")
        if dtypes:
            create_table_sql = f"CREATE OR REPLACE TABLE {self.name} AS SELECT * from read_csv('{self.decrypted_uri}', types = {dtypes})"
        else:
            create_table_sql = f"CREATE OR REPLACE TABLE {self.name} AS SELECT * from read_csv('{self.decrypted_uri}')"

        self.conn.sql(create_table_sql)
        self.last_hydrated = datetime.now()

    def _load_postgres(self):
        self.conn.sql("USE relta_data")
        self.conn.sql(
            f"ATTACH IF NOT EXISTS '{self._config.transient_data_dir_path}/transient_{self.name}.duckdb' As transient_{self.name}"
        )

        self.conn.sql(f"USE transient_{self.name}")
        for metric in self._semantic_layer.metrics:
            self.conn.sql(
                f"CREATE TABLE IF NOT EXISTS {metric.name} AS {metric.sql_to_underlying_datasource}"
            )
            # for column in metric.dimensions:
            #    self.conn.sql(f"CREATE OR REPLACE VIEW {metric.name} AS SELECT {column}, {metric.name} FROM {self.name}")

        self.last_hydrated = datetime.now()  # TODO: this needs to be written to the database, but that is a client operation... what to do about this?
        # TODO: when to detach? it should be after hydrating?

    @deprecated(reason="Use DataSource().semantic_layer property instead")
    def create_semantic_layer(self) -> SemanticLayer:
        """Returns the semantic model of the data source"""
        self._semantic_layer = SemanticLayer(self, self._config)
        return self._semantic_layer

    @deprecated(reason="Use DataSource().semantic_layer property instead")
    def get_semantic_layer(self) -> SemanticLayer:
        return self._semantic_layer

    def _get_ddl(self) -> str:
        """Returns the DDL of the data source"""

        if self.type == DataSourceType.POSTGRES or self.type == DataSourceType.DUCKDB:
            ddl_list = self.conn.sql(
                f"select sql from duckdb_tables() where database_name='{self.name}' and schema_name != 'information_schema' and schema_name != 'pg_catalog'"
            ).fetchall()
            ddl = "\n".join([ddl[0] for ddl in ddl_list])
        elif self.type == DataSourceType.CSV:
            ddl = self.conn.sql(
                f"select sql from duckdb_tables() where table_name='{self.name}'"
            ).fetchone()[0]  # self.conn.sql(f"DESCRIBE {self.name}")

        return ddl

    def _create_metrics(self):
        self.conn.sql("USE relta_data")
        for metric in self._semantic_layer.metrics:
            # fully_formed_sql = self._append_db_to_table_name(metric.sql_to_underlying_datasource, f'transient_{self.name}')
            self.conn.sql(
                f"CREATE OR REPLACE VIEW {metric.name} AS select * from transient_{self.name}.{metric.name}"
            )

    def _execute_datasource_sql(self, sql: str):
        """Run SQL against the underlying datasource"""
        if self.type == DataSourceType.CSV:
            self.conn.sql(f"USE transient_{self.name}")
            return self.conn.sql(sql)
        else:
            raise NotImplementedError

    def _execute_sql(self, sql: str):
        self.conn.sql("USE relta_data")
        return self.conn.sql(sql)

    def _get_transient_ddl(self):
        # self.conn.sql(f"USE transient_{self.name}")
        return self.conn.sql(
            f"SELECT * FROM duckdb_tables() where database_name='transient_{self.name}'"
        ).fetchall()

    @staticmethod
    def _append_db_to_table_name(original_sql: str, db_name: str) -> str:
        """In DuckDB we need fully formed table and column names with database name appended. This method creates those.

        Args:
            original_sql (str): the sql we will be modifying

        Returns:
            str: The SQL statement with db name appended to table names
        """
        fully_formed_sql = original_sql
        table_names = list(parse_one(fully_formed_sql).find_all(exp.Table))
        tables = [
            str(table).partition(" ")[0] for table in table_names
        ]  # this is bc sqlglot returns the table name as '{TABLE NAME} AS {ALIAS}'
        tables = set(tables)
        for table in tables:
            fully_formed_sql = re.sub(
                r"\b" + re.escape(table) + r"\b",
                f"{db_name}.{table}",
                fully_formed_sql,
            )

        return fully_formed_sql

    def _create_transient_tables(self, calculate_statistics: bool = True):
        """Creates the transient tables in DuckDB

        Args:
            calculate_statistics (bool, optional): Calculate statistics (i.e. low cardinality columns) for each metric. Defaults to True.
        """
        self.conn.sql(f"USE transient_{self.name}")

        if self.type == DataSourceType.POSTGRES or self.type == self.type.DUCKDB:
            for metric in self._semantic_layer.metrics:
                fully_formed_sql_to_underlying_source = self._append_db_to_table_name(
                    metric.sql_to_underlying_datasource, self.name
                )
                self.conn.sql(
                    f"CREATE OR REPLACE TABLE {metric.name} AS {fully_formed_sql_to_underlying_source}"
                )

        elif self.type == DataSourceType.CSV:
            for metric in self._semantic_layer.metrics:
                fully_formed_sql_to_underlying_source = self._append_db_to_table_name(
                    metric.sql_to_underlying_datasource, f"transient_{self.name}"
                )
                self.conn.sql(
                    f"CREATE OR REPLACE TABLE {metric.name} AS {fully_formed_sql_to_underlying_source}"
                )

        if calculate_statistics:
            # the following code identifies low cardinality columns
            for metric in self.semantic_layer.metrics:
                for dimension in metric.dimensions:
                    dimension.categories = []
                    cardinality = self.conn.sql(
                        f"SELECT approx_count_distinct({dimension.name}) from {metric.name}"
                    ).fetchone()[0]
                    if cardinality < 100 and not dimension.skip_categorical_load:
                        categories = [
                            value[0]
                            for value in self.conn.sql(
                                f"SELECT DISTINCT {dimension.name} FROM {metric.name}"
                            ).fetchall()
                        ]
                        dimension.categories = categories

    def deploy(self, statistics: bool = True):
        """
        Deploys the semantic layer to the data source.

        Args:
            statistics (bool, optional): Calculate statistics (i.e. low cardinality columns) for each metric. Defaults to True.
        """
        logfire.span(
            "deploying semantic layer {semantic_layer}",
            semantic_layer=self.semantic_layer,
        )
        self._drop_removed_metrics()
        self._create_transient_tables(statistics)
        self._create_metrics()
        logfire.info("semantic layer deployed")

    def _drop_removed_metrics(self):
        """Checks the current list of metrics against views and transient tables. Drop them if they are no longer in the semantic layer"""
        self.conn.sql("use relta_data")
        if self.type == DataSourceType.CSV:  # on CSV we copy the entire data as a table
            existing_metrics = self.conn.sql(
                f"SELECT table_name FROM duckdb_tables() where database_name='transient_{self.name}' and table_name!='{self.name}'"
            ).fetchall()
        else:
            existing_metrics = self.conn.sql(
                f"SELECT table_name FROM duckdb_tables() where database_name='transient_{self.name}'"
            ).fetchall()

        existing_metric_names = [metric[0] for metric in existing_metrics]
        metric_names_to_keep = [metric.name for metric in self._semantic_layer.metrics]

        for metric_name in existing_metric_names:
            if metric_name not in metric_names_to_keep:
                self.conn.sql(
                    f"DROP TABLE IF EXISTS transient_{self.name}.{metric_name}"
                )
                self.conn.sql(f"DROP VIEW IF EXISTS {metric_name}")

    @staticmethod
    def _infer_name_if_none(type: DataSourceType, connection_uri: str) -> str:
        if type is DataSourceType.CSV or type is DataSourceType.DUCKDB:
            name = (
                connection_uri.split("/")[-1]
                .split(".")[0]
                .replace(" ", "_")
                .replace("-", "_")
            )
        elif type is DataSourceType.POSTGRES or DataSourceType.DUCKDB:
            name = connection_uri.split("/")[-1]

        return name

    @staticmethod
    def _encrypt_uri(connection_uri: str, key) -> str:
        """Encrypt a connection URI"""
        encrypted = key.encrypt(connection_uri.encode())
        return b64encode(encrypted).decode()

    @staticmethod
    def _decrypt_uri(connection_uri: str, key) -> str:
        """Decrypt the stored connection URI"""
        encrypted = b64decode(connection_uri.encode())
        return key.decrypt(encrypted).decode()

decrypted_uri: str property

Access the decrypted connection URI

semantic_layer property

The semantic layer of the DataSource. Populated by Client

connect(dtypes=None)

Creates a connection to the given data source. This allows Relta to query the underlying data (e.g. read schema) but does not copy data into Relta.

Parameters:

dtypes (dict, optional): Map of column names to datatypes, overrides a column's auto-detected type. Only for CSVs. The datatypes should be DuckDB datatypes. Defaults to None.

Raises:

CatalogException: If a table with the same name is already connected to Relta.

Source code in src/relta/datasource/datasource.py
def connect(self, dtypes: Optional[dict[str, str]] = None):
    """Creates a connection to the given data source. This allows Relta to query the underlying data (e.g. read schema) but does not copy data into Relta.

    Args:
        dtypes (dict, optional): Map of column names to datatypes, overrides a column's auto-detected type.
            *Only for CSVs*. The datatypes should be [DuckDB datatypes](https://duckdb.org/docs/sql/data_types/overview).

    Raises:
        duckdb.CatalogException: If a table with the same name is already connected to Relta
    """
    if self.type == DataSourceType.CSV:
        self._connect_csv(dtypes)
    elif self.type == DataSourceType.PARQUET:
        self._connect_parquet()
    elif self.type == DataSourceType.POSTGRES:
        self._connect_postgres()
    elif self.type == DataSourceType.MYSQL:
        self._connect_mysql()
    elif self.type == DataSourceType.DUCKDB:
        self._connect_duckdb()

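If you need explicit column types for a CSV source, connect accepts a dtypes mapping. A minimal usage sketch -- the column names and DuckDB types here are hypothetical, not taken from the invoices example:

import relta
rc = relta.Client()
source = rc.get_datasource(name="week_invoice")
# Override CSV auto-detection for specific columns; "amount" and
# "invoice_date" are illustrative column names, not from the example data.
source.connect(dtypes={"amount": "DOUBLE", "invoice_date": "DATE"})
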
create_semantic_layer()

Returns the semantic model of the data source

Source code in src/relta/datasource/datasource.py
@deprecated(reason="Use DataSource().semantic_layer property instead")
def create_semantic_layer(self) -> SemanticLayer:
    """Returns the semantic model of the data source"""
    self._semantic_layer = SemanticLayer(self, self._config)
    return self._semantic_layer

deploy(statistics=True)

Deploys the semantic layer to the data source.

Parameters:

statistics (bool, optional): Calculate statistics (i.e. low cardinality columns) for each metric. Defaults to True.

Source code in src/relta/datasource/datasource.py
def deploy(self, statistics: bool = True):
    """
    Deploys the semantic layer to the data source.

    Args:
        statistics (bool, optional): Calculate statistics (i.e. low cardinality columns) for each metric. Defaults to True.
    """
    logfire.span(
        "deploying semantic layer {semantic_layer}",
        semantic_layer=self.semantic_layer,
    )
    self._drop_removed_metrics()
    self._create_transient_tables(statistics)
    self._create_metrics()
    logfire.info("semantic layer deployed")

load()

Updates the data in Relta from the underlying data source

Source code in src/relta/datasource/datasource.py
def load(self):
    """Updates the data in Relta from the underlying data source"""
    if self.type == DataSourceType.POSTGRES:
        self._load_postgres()
    elif self.type == DataSourceType.DUCKDB:
        self._load_postgres()
    elif self.type == DataSourceType.CSV:
        self._load_csv()

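A minimal sketch of refreshing the local data after the external datasource has changed (this re-reads the source and updates last_hydrated):

import relta
rc = relta.Client()
source = rc.get_datasource(name="week_invoice")
# Pull fresh data from the external datasource into the local copy.
source.load()
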
set_encryption_key(key) classmethod

Set the encryption key for all DataSource instances

Source code in src/relta/datasource/datasource.py
@classmethod
def set_encryption_key(cls, key: bytes):
    """Set the encryption key for all DataSource instances"""
    cls._encryption_key = key
    cls._fernet = Fernet(key)
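
A hedged sketch of supplying a key using the cryptography package's Fernet. In practice you would load a previously persisted key (for example from an environment variable), since the same key is needed to decrypt connection URIs stored earlier; generating a fresh key here is only for illustration.

from cryptography.fernet import Fernet

from relta.datasource.datasource import DataSource  # import path assumed from the source listing above

# Generate a key once and store it securely; reuse the same key across runs so
# previously encrypted connection URIs can still be decrypted.
key = Fernet.generate_key()
DataSource.set_encryption_key(key)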